diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 447164a9b..4697acf20 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1.21-bullseye +FROM golang:1.23-bullseye RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ && apt-get -y install --no-install-recommends\ diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index ef883d31a..97aad75ad 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -7,7 +7,7 @@ "features": { "ghcr.io/devcontainers/features/docker-in-docker:2": {}, "ghcr.io/devcontainers/features/go:1": { - "version": "1.21" + "version": "1.23" } }, "workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}", diff --git a/.github/workflows/golang-test-darwin.yml b/.github/workflows/golang-test-darwin.yml index 88db8c5e8..2dbeb106a 100644 --- a/.github/workflows/golang-test-darwin.yml +++ b/.github/workflows/golang-test-darwin.yml @@ -21,6 +21,7 @@ jobs: uses: actions/setup-go@v5 with: go-version: "1.23.x" + cache: false - name: Checkout code uses: actions/checkout@v4 @@ -28,8 +29,9 @@ jobs: uses: actions/cache@v4 with: path: ~/go/pkg/mod - key: macos-go-${{ hashFiles('**/go.sum') }} + key: macos-gotest-${{ hashFiles('**/go.sum') }} restore-keys: | + macos-gotest- macos-go- - name: Install libpcap @@ -42,4 +44,4 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./... + run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management) diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index ef6672002..f1e7c299d 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -11,31 +11,115 @@ concurrency: cancel-in-progress: true jobs: + build-cache: + runs-on: ubuntu-22.04 + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: "1.23.x" + cache: false + + - name: Get Go environment + run: | + echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV + echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV + + - name: Cache Go modules + uses: actions/cache@v4 + id: cache + with: + path: | + ${{ env.cache }} + ${{ env.modcache }} + key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} + + + - name: Install dependencies + if: steps.cache.outputs.cache-hit != 'true' + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev + + - name: Install 32-bit libpcap + if: steps.cache.outputs.cache-hit != 'true' + run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 + + - name: Build client + if: steps.cache.outputs.cache-hit != 'true' + working-directory: client + run: CGO_ENABLED=1 go build . + + - name: Build client 386 + if: steps.cache.outputs.cache-hit != 'true' + working-directory: client + run: CGO_ENABLED=1 GOARCH=386 go build -o client-386 . + + - name: Build management + if: steps.cache.outputs.cache-hit != 'true' + working-directory: management + run: CGO_ENABLED=1 go build . 
+ + - name: Build management 386 + if: steps.cache.outputs.cache-hit != 'true' + working-directory: management + run: CGO_ENABLED=1 GOARCH=386 go build -o management-386 . + + - name: Build signal + if: steps.cache.outputs.cache-hit != 'true' + working-directory: signal + run: CGO_ENABLED=1 go build . + + - name: Build signal 386 + if: steps.cache.outputs.cache-hit != 'true' + working-directory: signal + run: CGO_ENABLED=1 GOARCH=386 go build -o signal-386 . + + - name: Build relay + if: steps.cache.outputs.cache-hit != 'true' + working-directory: relay + run: CGO_ENABLED=1 go build . + + - name: Build relay 386 + if: steps.cache.outputs.cache-hit != 'true' + working-directory: relay + run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 . + test: + needs: [build-cache] strategy: fail-fast: false matrix: arch: [ '386','amd64' ] - store: [ 'sqlite', 'postgres'] runs-on: ubuntu-22.04 steps: - name: Install Go uses: actions/setup-go@v5 with: go-version: "1.23.x" - - - - name: Cache Go modules - uses: actions/cache@v4 - with: - path: ~/go/pkg/mod - key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go- + cache: false - name: Checkout code uses: actions/checkout@v4 + - name: Get Go environment + run: | + echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV + echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV + + - name: Cache Go modules + uses: actions/cache/restore@v4 + with: + path: | + ${{ env.cache }} + ${{ env.modcache }} + key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-gotest-cache- + - name: Install dependencies run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev @@ -50,27 +134,265 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./... + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -exec 'sudo' -timeout 10m -p 1 $(go list ./... 
| grep -v /management) + + test_management: + needs: [ build-cache ] + strategy: + fail-fast: false + matrix: + arch: [ '386','amd64' ] + store: [ 'sqlite', 'postgres', 'mysql' ] + runs-on: ubuntu-22.04 + steps: + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: "1.23.x" + cache: false + + - name: Checkout code + uses: actions/checkout@v4 + + - name: Get Go environment + run: | + echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV + echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV + + - name: Cache Go modules + uses: actions/cache/restore@v4 + with: + path: | + ${{ env.cache }} + ${{ env.modcache }} + key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-gotest-cache- + + - name: Install dependencies + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev + + - name: Install 32-bit libpcap + if: matrix.arch == '386' + run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 + + - name: Install modules + run: go mod tidy + + - name: check git status + run: git --no-pager diff --exit-code + + - name: Login to Docker hub + if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref) + uses: docker/login-action@v1 + with: + username: ${{ secrets.DOCKER_USER }} + password: ${{ secrets.DOCKER_TOKEN }} + + - name: download mysql image + if: matrix.store == 'mysql' + run: docker pull mlsmaycon/warmed-mysql:8 + + - name: Test + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management) + + benchmark: + needs: [ build-cache ] + strategy: + fail-fast: false + matrix: + arch: [ '386','amd64' ] + store: [ 'sqlite', 'postgres', 'mysql' ] + runs-on: ubuntu-22.04 + steps: + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: "1.23.x" + cache: false + + - name: Checkout code + uses: actions/checkout@v4 + + - name: Get Go environment + run: | + echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV + echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV + + - name: Cache Go modules + uses: actions/cache/restore@v4 + with: + path: | + ${{ env.cache }} + ${{ env.modcache }} + key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-gotest-cache- + + - name: Install dependencies + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev + + - name: Install 32-bit libpcap + if: matrix.arch == '386' + run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 + + - name: Install modules + run: go mod tidy + + - name: check git status + run: git --no-pager diff --exit-code + + - name: Login to Docker hub + if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref) + uses: docker/login-action@v1 + with: + username: ${{ secrets.DOCKER_USER }} + password: ${{ secrets.DOCKER_TOKEN }} + + - name: download mysql image + if: matrix.store == 'mysql' + run: docker pull mlsmaycon/warmed-mysql:8 + + - name: Test + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m ./... 
+ + api_benchmark: + needs: [ build-cache ] + strategy: + fail-fast: false + matrix: + arch: [ '386','amd64' ] + store: [ 'sqlite', 'postgres', 'mysql' ] + runs-on: ubuntu-22.04 + steps: + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: "1.23.x" + cache: false + + - name: Checkout code + uses: actions/checkout@v4 + + - name: Get Go environment + run: | + echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV + echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV + + - name: Cache Go modules + uses: actions/cache/restore@v4 + with: + path: | + ${{ env.cache }} + ${{ env.modcache }} + key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-gotest-cache- + + - name: Install dependencies + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev + + - name: Install 32-bit libpcap + if: matrix.arch == '386' + run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 + + - name: Install modules + run: go mod tidy + + - name: check git status + run: git --no-pager diff --exit-code + + - name: Login to Docker hub + if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref) + uses: docker/login-action@v1 + with: + username: ${{ secrets.DOCKER_USER }} + password: ${{ secrets.DOCKER_TOKEN }} + + - name: download mysql image + if: matrix.store == 'mysql' + run: docker pull mlsmaycon/warmed-mysql:8 + + - name: Test + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -tags=benchmark -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m $(go list -tags=benchmark ./... | grep /management) + + api_integration_test: + needs: [ build-cache ] + strategy: + fail-fast: false + matrix: + arch: [ '386','amd64' ] + store: [ 'sqlite', 'postgres'] + runs-on: ubuntu-22.04 + steps: + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: "1.23.x" + cache: false + + - name: Checkout code + uses: actions/checkout@v4 + + - name: Get Go environment + run: | + echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV + echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV + + - name: Cache Go modules + uses: actions/cache/restore@v4 + with: + path: | + ${{ env.cache }} + ${{ env.modcache }} + key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-gotest-cache- + + - name: Install dependencies + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev + + - name: Install 32-bit libpcap + if: matrix.arch == '386' + run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 + + - name: Install modules + run: go mod tidy + + - name: check git status + run: git --no-pager diff --exit-code + + - name: Test + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=integration -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m $(go list -tags=integration ./... 
| grep /management) test_client_on_docker: + needs: [ build-cache ] runs-on: ubuntu-20.04 steps: - name: Install Go uses: actions/setup-go@v5 with: go-version: "1.23.x" - - - name: Cache Go modules - uses: actions/cache@v4 - with: - path: ~/go/pkg/mod - key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go- + cache: false - name: Checkout code uses: actions/checkout@v4 + - name: Get Go environment + run: | + echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV + echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV + + - name: Cache Go modules + uses: actions/cache/restore@v4 + with: + path: | + ${{ env.cache }} + ${{ env.modcache }} + key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-gotest-cache- + - name: Install dependencies run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml index d378bec3f..3a3c47052 100644 --- a/.github/workflows/golang-test-windows.yml +++ b/.github/workflows/golang-test-windows.yml @@ -24,6 +24,23 @@ jobs: id: go with: go-version: "1.23.x" + cache: false + + - name: Get Go environment + run: | + echo "cache=$(go env GOCACHE)" >> $env:GITHUB_ENV + echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV + + - name: Cache Go modules + uses: actions/cache@v4 + with: + path: | + ${{ env.cache }} + ${{ env.modcache }} + key: ${{ runner.os }}-gotest-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-gotest- + ${{ runner.os }}-go- - name: Download wintun uses: carlosperate/download-file-action@v2 @@ -42,11 +59,13 @@ jobs: - run: choco install -y sysinternals --ignore-checksums - run: choco install -y mingw - - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=C:\Users\runneradmin\go\pkg\mod - - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build + - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }} + - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }} + - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy + - run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV - name: test - run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ./... 
> test-out.txt 2>&1" + run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1" - name: test output if: ${{ always() }} run: Get-Content test-out.txt diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index dacb1922b..6705a34ec 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -19,7 +19,7 @@ jobs: - name: codespell uses: codespell-project/actions-codespell@v2 with: - ignore_words_list: erro,clienta,hastable,iif,groupd + ignore_words_list: erro,clienta,hastable,iif,groupd,testin skip: go.mod,go.sum only_warn: 1 golangci: @@ -46,7 +46,7 @@ jobs: if: matrix.os == 'ubuntu-latest' run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev - name: golangci-lint - uses: golangci/golangci-lint-action@v3 + uses: golangci/golangci-lint-action@v4 with: version: latest - args: --timeout=12m + args: --timeout=12m --out-format colored-line-number diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 14e383a27..183cdb02c 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -9,7 +9,7 @@ on: pull_request: env: - SIGN_PIPE_VER: "v0.0.16" + SIGN_PIPE_VER: "v0.0.17" GORELEASER_VER: "v2.3.2" PRODUCT_NAME: "NetBird" COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)" diff --git a/.github/workflows/test-infrastructure-files.yml b/.github/workflows/test-infrastructure-files.yml index da3ec746a..5a3c6c22e 100644 --- a/.github/workflows/test-infrastructure-files.yml +++ b/.github/workflows/test-infrastructure-files.yml @@ -20,7 +20,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - store: [ 'sqlite', 'postgres' ] + store: [ 'sqlite', 'postgres', 'mysql' ] services: postgres: image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }} @@ -34,6 +34,19 @@ jobs: --health-timeout 5s ports: - 5432:5432 + mysql: + image: ${{ (matrix.store == 'mysql') && 'mysql' || '' }} + env: + MYSQL_USER: netbird + MYSQL_PASSWORD: mysql + MYSQL_ROOT_PASSWORD: mysqlroot + MYSQL_DATABASE: netbird + options: >- + --health-cmd "mysqladmin ping --silent" + --health-interval 10s + --health-timeout 5s + ports: + - 3306:3306 steps: - name: Set Database Connection String run: | @@ -42,6 +55,11 @@ jobs: else echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN==" >> $GITHUB_ENV fi + if [ "${{ matrix.store }}" == "mysql" ]; then + echo "NETBIRD_STORE_ENGINE_MYSQL_DSN=netbird:mysql@tcp($(hostname -I | awk '{print $1}'):3306)/netbird" >> $GITHUB_ENV + else + echo "NETBIRD_STORE_ENGINE_MYSQL_DSN==" >> $GITHUB_ENV + fi - name: Install jq run: sudo apt-get install -y jq @@ -84,6 +102,7 @@ jobs: CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified" CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }} NETBIRD_STORE_ENGINE_POSTGRES_DSN: ${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }} + NETBIRD_STORE_ENGINE_MYSQL_DSN: ${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }} CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false - name: check values @@ -112,6 +131,7 @@ jobs: CI_NETBIRD_SIGNAL_PORT: 12345 CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }} NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$' + NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$' CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4" @@ -149,6 +169,7 @@ jobs: 
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES" grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000" grep "external-ip" turnserver.conf | grep $CI_NETBIRD_TURN_EXTERNAL_IP + grep "NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN" docker-compose.yml grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN" # check relay values grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml diff --git a/.goreleaser.yaml b/.goreleaser.yaml index e718b3fcd..d6479763e 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -179,6 +179,51 @@ dockers: - "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.version={{.Version}}" - "--label=maintainer=dev@netbird.io" + + - image_templates: + - netbirdio/netbird:{{ .Version }}-rootless-amd64 + ids: + - netbird + goarch: amd64 + use: buildx + dockerfile: client/Dockerfile-rootless + build_flag_templates: + - "--platform=linux/amd64" + - "--label=org.opencontainers.image.created={{.Date}}" + - "--label=org.opencontainers.image.title={{.ProjectName}}" + - "--label=org.opencontainers.image.version={{.Version}}" + - "--label=org.opencontainers.image.revision={{.FullCommit}}" + - "--label=maintainer=dev@netbird.io" + - image_templates: + - netbirdio/netbird:{{ .Version }}-rootless-arm64v8 + ids: + - netbird + goarch: arm64 + use: buildx + dockerfile: client/Dockerfile-rootless + build_flag_templates: + - "--platform=linux/arm64" + - "--label=org.opencontainers.image.created={{.Date}}" + - "--label=org.opencontainers.image.title={{.ProjectName}}" + - "--label=org.opencontainers.image.version={{.Version}}" + - "--label=org.opencontainers.image.revision={{.FullCommit}}" + - "--label=maintainer=dev@netbird.io" + - image_templates: + - netbirdio/netbird:{{ .Version }}-rootless-arm + ids: + - netbird + goarch: arm + goarm: 6 + use: buildx + dockerfile: client/Dockerfile-rootless + build_flag_templates: + - "--platform=linux/arm" + - "--label=org.opencontainers.image.created={{.Date}}" + - "--label=org.opencontainers.image.title={{.ProjectName}}" + - "--label=org.opencontainers.image.version={{.Version}}" + - "--label=org.opencontainers.image.revision={{.FullCommit}}" + - "--label=maintainer=dev@netbird.io" + - image_templates: - netbirdio/relay:{{ .Version }}-amd64 ids: @@ -377,6 +422,18 @@ docker_manifests: - netbirdio/netbird:{{ .Version }}-arm - netbirdio/netbird:{{ .Version }}-amd64 + - name_template: netbirdio/netbird:{{ .Version }}-rootless + image_templates: + - netbirdio/netbird:{{ .Version }}-rootless-arm64v8 + - netbirdio/netbird:{{ .Version }}-rootless-arm + - netbirdio/netbird:{{ .Version }}-rootless-amd64 + + - name_template: netbirdio/netbird:rootless-latest + image_templates: + - netbirdio/netbird:{{ .Version }}-rootless-arm64v8 + - netbirdio/netbird:{{ .Version }}-rootless-arm + - netbirdio/netbird:{{ .Version }}-rootless-amd64 + - name_template: netbirdio/relay:{{ .Version }} image_templates: - netbirdio/relay:{{ .Version }}-arm64v8 diff --git a/README.md b/README.md index 270c9ad87..0537710e9 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,3 @@ -

-  :hatching_chick: New Release! Device Posture Checks.
-    Learn more
-

@@ -17,8 +10,12 @@

@@ -30,7 +27,7 @@
See Documentation
-    Join our Slack channel
+    Join our Slack channel
diff --git a/client/Dockerfile b/client/Dockerfile index b9f7c1355..2f5ff14ae 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -1,4 +1,4 @@ -FROM alpine:3.20 +FROM alpine:3.21.0 RUN apk add --no-cache ca-certificates iptables ip6tables ENV NB_FOREGROUND_MODE=true ENTRYPOINT [ "/usr/local/bin/netbird","up"] diff --git a/client/Dockerfile-rootless b/client/Dockerfile-rootless new file mode 100644 index 000000000..62bcaf964 --- /dev/null +++ b/client/Dockerfile-rootless @@ -0,0 +1,16 @@ +FROM alpine:3.21.0 + +COPY netbird /usr/local/bin/netbird + +RUN apk add --no-cache ca-certificates \ + && adduser -D -h /var/lib/netbird netbird +WORKDIR /var/lib/netbird +USER netbird:netbird + +ENV NB_FOREGROUND_MODE=true +ENV NB_USE_NETSTACK_MODE=true +ENV NB_CONFIG=config.json +ENV NB_DAEMON_ADDR=unix://netbird.sock +ENV NB_DISABLE_DNS=true + +ENTRYPOINT [ "/usr/local/bin/netbird", "up" ] diff --git a/client/anonymize/anonymize.go b/client/anonymize/anonymize.go index 7ebe0442d..89552724a 100644 --- a/client/anonymize/anonymize.go +++ b/client/anonymize/anonymize.go @@ -12,6 +12,8 @@ import ( "strings" ) +const anonTLD = ".domain" + type Anonymizer struct { ipAnonymizer map[netip.Addr]netip.Addr domainAnonymizer map[string]string @@ -19,6 +21,8 @@ type Anonymizer struct { currentAnonIPv6 netip.Addr startAnonIPv4 netip.Addr startAnonIPv6 netip.Addr + + domainKeyRegex *regexp.Regexp } func DefaultAddresses() (netip.Addr, netip.Addr) { @@ -34,6 +38,8 @@ func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer { currentAnonIPv6: startIPv6, startAnonIPv4: startIPv4, startAnonIPv6: startIPv6, + + domainKeyRegex: regexp.MustCompile(`\bdomain=([^\s,:"]+)`), } } @@ -83,29 +89,39 @@ func (a *Anonymizer) AnonymizeIPString(ip string) string { } func (a *Anonymizer) AnonymizeDomain(domain string) string { - if strings.HasSuffix(domain, "netbird.io") || - strings.HasSuffix(domain, "netbird.selfhosted") || - strings.HasSuffix(domain, "netbird.cloud") || - strings.HasSuffix(domain, "netbird.stage") || - strings.HasSuffix(domain, ".domain") { + baseDomain := domain + hasDot := strings.HasSuffix(domain, ".") + if hasDot { + baseDomain = domain[:len(domain)-1] + } + + if strings.HasSuffix(baseDomain, "netbird.io") || + strings.HasSuffix(baseDomain, "netbird.selfhosted") || + strings.HasSuffix(baseDomain, "netbird.cloud") || + strings.HasSuffix(baseDomain, "netbird.stage") || + strings.HasSuffix(baseDomain, anonTLD) { return domain } - parts := strings.Split(domain, ".") + parts := strings.Split(baseDomain, ".") if len(parts) < 2 { return domain } - baseDomain := parts[len(parts)-2] + "." + parts[len(parts)-1] + baseForLookup := parts[len(parts)-2] + "." + parts[len(parts)-1] - anonymized, ok := a.domainAnonymizer[baseDomain] + anonymized, ok := a.domainAnonymizer[baseForLookup] if !ok { - anonymizedBase := "anon-" + generateRandomString(5) + ".domain" - a.domainAnonymizer[baseDomain] = anonymizedBase + anonymizedBase := "anon-" + generateRandomString(5) + anonTLD + a.domainAnonymizer[baseForLookup] = anonymizedBase anonymized = anonymizedBase } - return strings.Replace(domain, baseDomain, anonymized, 1) + result := strings.Replace(baseDomain, baseForLookup, anonymized, 1) + if hasDot { + result += "." + } + return result } func (a *Anonymizer) AnonymizeURI(uri string) string { @@ -152,27 +168,22 @@ func (a *Anonymizer) AnonymizeString(str string) string { return str } -// AnonymizeSchemeURI finds and anonymizes URIs with stun, stuns, turn, and turns schemes. 
+// AnonymizeSchemeURI finds and anonymizes URIs with ws, wss, rel, rels, stun, stuns, turn, and turns schemes. func (a *Anonymizer) AnonymizeSchemeURI(text string) string { - re := regexp.MustCompile(`(?i)\b(stuns?:|turns?:|https?://)\S+\b`) + re := regexp.MustCompile(`(?i)\b(wss?://|rels?://|stuns?:|turns?:|https?://)\S+\b`) return re.ReplaceAllStringFunc(text, a.AnonymizeURI) } -// AnonymizeDNSLogLine anonymizes domain names in DNS log entries by replacing them with a random string. func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string { - domainPattern := `dns\.Question{Name:"([^"]+)",` - domainRegex := regexp.MustCompile(domainPattern) - - return domainRegex.ReplaceAllStringFunc(logEntry, func(match string) string { - parts := strings.Split(match, `"`) + return a.domainKeyRegex.ReplaceAllStringFunc(logEntry, func(match string) string { + parts := strings.SplitN(match, "=", 2) if len(parts) >= 2 { domain := parts[1] - if strings.HasSuffix(domain, ".domain") { + if strings.HasSuffix(domain, anonTLD) { return match } - randomDomain := generateRandomString(10) + ".domain" - return strings.Replace(match, domain, randomDomain, 1) + return "domain=" + a.AnonymizeDomain(domain) } return match }) diff --git a/client/anonymize/anonymize_test.go b/client/anonymize/anonymize_test.go index e660749ec..ff2e48869 100644 --- a/client/anonymize/anonymize_test.go +++ b/client/anonymize/anonymize_test.go @@ -46,11 +46,59 @@ func TestAnonymizeIP(t *testing.T) { func TestAnonymizeDNSLogLine(t *testing.T) { anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{}) - testLog := `2024-04-23T20:01:11+02:00 TRAC client/internal/dns/local.go:25: received question: dns.Question{Name:"example.com", Qtype:0x1c, Qclass:0x1}` + tests := []struct { + name string + input string + original string + expect string + }{ + { + name: "Basic domain with trailing content", + input: "received DNS request for DNS forwarder: domain=example.com: something happened with code=123", + original: "example.com", + expect: `received DNS request for DNS forwarder: domain=anon-[a-zA-Z0-9]+\.domain: something happened with code=123`, + }, + { + name: "Domain with trailing dot", + input: "domain=example.com. processing request with status=pending", + original: "example.com", + expect: `domain=anon-[a-zA-Z0-9]+\.domain\. processing request with status=pending`, + }, + { + name: "Multiple domains in log", + input: "forward domain=first.com status=ok, redirect to domain=second.com port=443", + original: "first.com", // testing just one is sufficient as AnonymizeDomain is tested separately + expect: `forward domain=anon-[a-zA-Z0-9]+\.domain status=ok, redirect to domain=anon-[a-zA-Z0-9]+\.domain port=443`, + }, + { + name: "Already anonymized domain", + input: "got request domain=anon-xyz123.domain from=client1 to=server2", + original: "", // nothing should be anonymized + expect: `got request domain=anon-xyz123\.domain from=client1 to=server2`, + }, + { + name: "Subdomain with trailing dot", + input: "domain=sub.example.com. next_hop=10.0.0.1 proto=udp", + original: "example.com", + expect: `domain=sub\.anon-[a-zA-Z0-9]+\.domain\. next_hop=10\.0\.0\.1 proto=udp`, + }, + { + name: "Handler chain pattern log", + input: "pattern: domain=example.com. original: domain=*.example.com. wildcard=true priority=100", + original: "example.com", + expect: `pattern: domain=anon-[a-zA-Z0-9]+\.domain\. original: domain=\*\.anon-[a-zA-Z0-9]+\.domain\. 
wildcard=true priority=100`, + }, + } - result := anonymizer.AnonymizeDNSLogLine(testLog) - require.NotEqual(t, testLog, result) - assert.NotContains(t, result, "example.com") + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := anonymizer.AnonymizeDNSLogLine(tc.input) + if tc.original != "" { + assert.NotContains(t, result, tc.original) + } + assert.Regexp(t, tc.expect, result) + }) + } } func TestAnonymizeDomain(t *testing.T) { @@ -67,18 +115,36 @@ func TestAnonymizeDomain(t *testing.T) { `^anon-[a-zA-Z0-9]+\.domain$`, true, }, + { + "Domain with Trailing Dot", + "example.com.", + `^anon-[a-zA-Z0-9]+\.domain.$`, + true, + }, { "Subdomain", "sub.example.com", `^sub\.anon-[a-zA-Z0-9]+\.domain$`, true, }, + { + "Subdomain with Trailing Dot", + "sub.example.com.", + `^sub\.anon-[a-zA-Z0-9]+\.domain.$`, + true, + }, { "Protected Domain", "netbird.io", `^netbird\.io$`, false, }, + { + "Protected Domain with Trailing Dot", + "netbird.io.", + `^netbird\.io.$`, + false, + }, } for _, tc := range tests { @@ -140,8 +206,16 @@ func TestAnonymizeSchemeURI(t *testing.T) { expect string }{ {"STUN URI in text", "Connection made via stun:example.com", `Connection made via stun:anon-[a-zA-Z0-9]+\.domain`}, + {"STUNS URI in message", "Secure connection to stuns:example.com:443", `Secure connection to stuns:anon-[a-zA-Z0-9]+\.domain:443`}, {"TURN URI in log", "Failed attempt turn:some.example.com:3478?transport=tcp: retrying", `Failed attempt turn:some.anon-[a-zA-Z0-9]+\.domain:3478\?transport=tcp: retrying`}, + {"TURNS URI in message", "Secure connection to turns:example.com:5349", `Secure connection to turns:anon-[a-zA-Z0-9]+\.domain:5349`}, + {"HTTP URI in text", "Visit http://example.com for more", `Visit http://anon-[a-zA-Z0-9]+\.domain for more`}, + {"HTTPS URI in CAPS", "Visit HTTPS://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`}, {"HTTPS URI in message", "Visit https://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`}, + {"WS URI in log", "Connection established to ws://example.com:8080", `Connection established to ws://anon-[a-zA-Z0-9]+\.domain:8080`}, + {"WSS URI in message", "Secure connection to wss://example.com", `Secure connection to wss://anon-[a-zA-Z0-9]+\.domain`}, + {"Rel URI in text", "Relaying to rel://example.com", `Relaying to rel://anon-[a-zA-Z0-9]+\.domain`}, + {"Rels URI in message", "Relaying to rels://example.com", `Relaying to rels://anon-[a-zA-Z0-9]+\.domain`}, } for _, tc := range tests { diff --git a/client/cmd/debug.go b/client/cmd/debug.go index 9abd2039d..c7ab87b47 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -3,6 +3,7 @@ package cmd import ( "context" "fmt" + "strings" "time" log "github.com/sirupsen/logrus" @@ -61,6 +62,15 @@ var forCmd = &cobra.Command{ RunE: runForDuration, } +var persistenceCmd = &cobra.Command{ + Use: "persistence [on|off]", + Short: "Set network map memory persistence", + Long: `Configure whether the latest network map should persist in memory. 
When enabled, the last known network map will be kept in memory.`, + Example: " netbird debug persistence on", + Args: cobra.ExactArgs(1), + RunE: setNetworkMapPersistence, +} + func debugBundle(cmd *cobra.Command, _ []string) error { conn, err := getClient(cmd) if err != nil { @@ -171,6 +181,13 @@ func runForDuration(cmd *cobra.Command, args []string) error { time.Sleep(1 * time.Second) + // Enable network map persistence before bringing the service up + if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{ + Enabled: true, + }); err != nil { + return fmt.Errorf("failed to enable network map persistence: %v", status.Convert(err).Message()) + } + if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil { return fmt.Errorf("failed to up: %v", status.Convert(err).Message()) } @@ -200,6 +217,13 @@ func runForDuration(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message()) } + // Disable network map persistence after creating the debug bundle + if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{ + Enabled: false, + }); err != nil { + return fmt.Errorf("failed to disable network map persistence: %v", status.Convert(err).Message()) + } + if stateWasDown { if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { return fmt.Errorf("failed to down: %v", status.Convert(err).Message()) @@ -219,6 +243,34 @@ func runForDuration(cmd *cobra.Command, args []string) error { return nil } +func setNetworkMapPersistence(cmd *cobra.Command, args []string) error { + conn, err := getClient(cmd) + if err != nil { + return err + } + defer func() { + if err := conn.Close(); err != nil { + log.Errorf(errCloseConnection, err) + } + }() + + persistence := strings.ToLower(args[0]) + if persistence != "on" && persistence != "off" { + return fmt.Errorf("invalid persistence value: %s. Use 'on' or 'off'", args[0]) + } + + client := proto.NewDaemonServiceClient(conn) + _, err = client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{ + Enabled: persistence == "on", + }) + if err != nil { + return fmt.Errorf("failed to set network map persistence: %v", status.Convert(err).Message()) + } + + cmd.Printf("Network map persistence set to: %s\n", persistence) + return nil +} + func getStatusOutput(cmd *cobra.Command) string { var statusOutputString string statusResp, err := getStatus(cmd.Context()) diff --git a/client/cmd/networks.go b/client/cmd/networks.go new file mode 100644 index 000000000..7b9724bc5 --- /dev/null +++ b/client/cmd/networks.go @@ -0,0 +1,173 @@ +package cmd + +import ( + "fmt" + "strings" + + "github.com/spf13/cobra" + "google.golang.org/grpc/status" + + "github.com/netbirdio/netbird/client/proto" +) + +var appendFlag bool + +var networksCMD = &cobra.Command{ + Use: "networks", + Aliases: []string{"routes"}, + Short: "Manage networks", + Long: `Commands to list, select, or deselect networks. 
Replaces the "routes" command.`, +} + +var routesListCmd = &cobra.Command{ + Use: "list", + Aliases: []string{"ls"}, + Short: "List networks", + Example: " netbird networks list", + Long: "List all available network routes.", + RunE: networksList, +} + +var routesSelectCmd = &cobra.Command{ + Use: "select network...|all", + Short: "Select network", + Long: "Select a list of networks by identifiers or 'all' to clear all selections and to accept all (including new) networks.\nDefault mode is replace, use -a to append to already selected networks.", + Example: " netbird networks select all\n netbird networks select route1 route2\n netbird routes select -a route3", + Args: cobra.MinimumNArgs(1), + RunE: networksSelect, +} + +var routesDeselectCmd = &cobra.Command{ + Use: "deselect network...|all", + Short: "Deselect networks", + Long: "Deselect previously selected networks by identifiers or 'all' to disable accepting any networks.", + Example: " netbird networks deselect all\n netbird networks deselect route1 route2", + Args: cobra.MinimumNArgs(1), + RunE: networksDeselect, +} + +func init() { + routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current network selection instead of replacing") +} + +func networksList(cmd *cobra.Command, _ []string) error { + conn, err := getClient(cmd) + if err != nil { + return err + } + defer conn.Close() + + client := proto.NewDaemonServiceClient(conn) + resp, err := client.ListNetworks(cmd.Context(), &proto.ListNetworksRequest{}) + if err != nil { + return fmt.Errorf("failed to list network: %v", status.Convert(err).Message()) + } + + if len(resp.Routes) == 0 { + cmd.Println("No networks available.") + return nil + } + + printNetworks(cmd, resp) + + return nil +} + +func printNetworks(cmd *cobra.Command, resp *proto.ListNetworksResponse) { + cmd.Println("Available Networks:") + for _, route := range resp.Routes { + printNetwork(cmd, route) + } +} + +func printNetwork(cmd *cobra.Command, route *proto.Network) { + selectedStatus := getSelectedStatus(route) + domains := route.GetDomains() + + if len(domains) > 0 { + printDomainRoute(cmd, route, domains, selectedStatus) + } else { + printNetworkRoute(cmd, route, selectedStatus) + } +} + +func getSelectedStatus(route *proto.Network) string { + if route.GetSelected() { + return "Selected" + } + return "Not Selected" +} + +func printDomainRoute(cmd *cobra.Command, route *proto.Network, domains []string, selectedStatus string) { + cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus) + resolvedIPs := route.GetResolvedIPs() + + if len(resolvedIPs) > 0 { + printResolvedIPs(cmd, domains, resolvedIPs) + } else { + cmd.Printf(" Resolved IPs: -\n") + } +} + +func printNetworkRoute(cmd *cobra.Command, route *proto.Network, selectedStatus string) { + cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetRange(), selectedStatus) +} + +func printResolvedIPs(cmd *cobra.Command, _ []string, resolvedIPs map[string]*proto.IPList) { + cmd.Printf(" Resolved IPs:\n") + for resolvedDomain, ipList := range resolvedIPs { + cmd.Printf(" [%s]: %s\n", resolvedDomain, strings.Join(ipList.GetIps(), ", ")) + } +} + +func networksSelect(cmd *cobra.Command, args []string) error { + conn, err := getClient(cmd) + if err != nil { + return err + } + defer conn.Close() + + client := proto.NewDaemonServiceClient(conn) + req := &proto.SelectNetworksRequest{ + NetworkIDs: args, + } + + if len(args) == 1 && args[0] == "all" { + 
req.All = true + } else if appendFlag { + req.Append = true + } + + if _, err := client.SelectNetworks(cmd.Context(), req); err != nil { + return fmt.Errorf("failed to select networks: %v", status.Convert(err).Message()) + } + + cmd.Println("Networks selected successfully.") + + return nil +} + +func networksDeselect(cmd *cobra.Command, args []string) error { + conn, err := getClient(cmd) + if err != nil { + return err + } + defer conn.Close() + + client := proto.NewDaemonServiceClient(conn) + req := &proto.SelectNetworksRequest{ + NetworkIDs: args, + } + + if len(args) == 1 && args[0] == "all" { + req.All = true + } + + if _, err := client.DeselectNetworks(cmd.Context(), req); err != nil { + return fmt.Errorf("failed to deselect networks: %v", status.Convert(err).Message()) + } + + cmd.Println("Networks deselected successfully.") + + return nil +} diff --git a/client/cmd/pprof.go b/client/cmd/pprof.go new file mode 100644 index 000000000..37efd35f0 --- /dev/null +++ b/client/cmd/pprof.go @@ -0,0 +1,33 @@ +//go:build pprof +// +build pprof + +package cmd + +import ( + "net/http" + _ "net/http/pprof" + "os" + + log "github.com/sirupsen/logrus" +) + +func init() { + addr := pprofAddr() + go pprof(addr) +} + +func pprofAddr() string { + listenAddr := os.Getenv("NB_PPROF_ADDR") + if listenAddr == "" { + return "localhost:6969" + } + + return listenAddr +} + +func pprof(listenAddr string) { + log.Infof("listening pprof on: %s\n", listenAddr) + if err := http.ListenAndServe(listenAddr, nil); err != nil { + log.Fatalf("Failed to start pprof: %v", err) + } +} diff --git a/client/cmd/root.go b/client/cmd/root.go index 8dae6e273..0305bacc8 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -142,19 +142,20 @@ func init() { rootCmd.AddCommand(loginCmd) rootCmd.AddCommand(versionCmd) rootCmd.AddCommand(sshCmd) - rootCmd.AddCommand(routesCmd) + rootCmd.AddCommand(networksCMD) rootCmd.AddCommand(debugCmd) serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service - routesCmd.AddCommand(routesListCmd) - routesCmd.AddCommand(routesSelectCmd, routesDeselectCmd) + networksCMD.AddCommand(routesListCmd) + networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd) debugCmd.AddCommand(debugBundleCmd) debugCmd.AddCommand(logCmd) logCmd.AddCommand(logLevelCmd) debugCmd.AddCommand(forCmd) + debugCmd.AddCommand(persistenceCmd) upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil, `Sets external IPs maps between local addresses and interfaces.`+ diff --git a/client/cmd/route.go b/client/cmd/route.go deleted file mode 100644 index c8881822b..000000000 --- a/client/cmd/route.go +++ /dev/null @@ -1,174 +0,0 @@ -package cmd - -import ( - "fmt" - "strings" - - "github.com/spf13/cobra" - "google.golang.org/grpc/status" - - "github.com/netbirdio/netbird/client/proto" -) - -var appendFlag bool - -var routesCmd = &cobra.Command{ - Use: "routes", - Short: "Manage network routes", - Long: `Commands to list, select, or deselect network routes.`, -} - -var routesListCmd = &cobra.Command{ - Use: "list", - Aliases: []string{"ls"}, - Short: "List routes", - Example: " netbird routes list", - Long: "List all available network routes.", - RunE: routesList, -} - -var routesSelectCmd = &cobra.Command{ - Use: "select route...|all", - Short: "Select routes", - Long: "Select a list of routes by identifiers or 'all' to clear all selections and to 
accept all (including new) routes.\nDefault mode is replace, use -a to append to already selected routes.", - Example: " netbird routes select all\n netbird routes select route1 route2\n netbird routes select -a route3", - Args: cobra.MinimumNArgs(1), - RunE: routesSelect, -} - -var routesDeselectCmd = &cobra.Command{ - Use: "deselect route...|all", - Short: "Deselect routes", - Long: "Deselect previously selected routes by identifiers or 'all' to disable accepting any routes.", - Example: " netbird routes deselect all\n netbird routes deselect route1 route2", - Args: cobra.MinimumNArgs(1), - RunE: routesDeselect, -} - -func init() { - routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current route selection instead of replacing") -} - -func routesList(cmd *cobra.Command, _ []string) error { - conn, err := getClient(cmd) - if err != nil { - return err - } - defer conn.Close() - - client := proto.NewDaemonServiceClient(conn) - resp, err := client.ListRoutes(cmd.Context(), &proto.ListRoutesRequest{}) - if err != nil { - return fmt.Errorf("failed to list routes: %v", status.Convert(err).Message()) - } - - if len(resp.Routes) == 0 { - cmd.Println("No routes available.") - return nil - } - - printRoutes(cmd, resp) - - return nil -} - -func printRoutes(cmd *cobra.Command, resp *proto.ListRoutesResponse) { - cmd.Println("Available Routes:") - for _, route := range resp.Routes { - printRoute(cmd, route) - } -} - -func printRoute(cmd *cobra.Command, route *proto.Route) { - selectedStatus := getSelectedStatus(route) - domains := route.GetDomains() - - if len(domains) > 0 { - printDomainRoute(cmd, route, domains, selectedStatus) - } else { - printNetworkRoute(cmd, route, selectedStatus) - } -} - -func getSelectedStatus(route *proto.Route) string { - if route.GetSelected() { - return "Selected" - } - return "Not Selected" -} - -func printDomainRoute(cmd *cobra.Command, route *proto.Route, domains []string, selectedStatus string) { - cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus) - resolvedIPs := route.GetResolvedIPs() - - if len(resolvedIPs) > 0 { - printResolvedIPs(cmd, domains, resolvedIPs) - } else { - cmd.Printf(" Resolved IPs: -\n") - } -} - -func printNetworkRoute(cmd *cobra.Command, route *proto.Route, selectedStatus string) { - cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus) -} - -func printResolvedIPs(cmd *cobra.Command, domains []string, resolvedIPs map[string]*proto.IPList) { - cmd.Printf(" Resolved IPs:\n") - for _, domain := range domains { - if ipList, exists := resolvedIPs[domain]; exists { - cmd.Printf(" [%s]: %s\n", domain, strings.Join(ipList.GetIps(), ", ")) - } - } -} - -func routesSelect(cmd *cobra.Command, args []string) error { - conn, err := getClient(cmd) - if err != nil { - return err - } - defer conn.Close() - - client := proto.NewDaemonServiceClient(conn) - req := &proto.SelectRoutesRequest{ - RouteIDs: args, - } - - if len(args) == 1 && args[0] == "all" { - req.All = true - } else if appendFlag { - req.Append = true - } - - if _, err := client.SelectRoutes(cmd.Context(), req); err != nil { - return fmt.Errorf("failed to select routes: %v", status.Convert(err).Message()) - } - - cmd.Println("Routes selected successfully.") - - return nil -} - -func routesDeselect(cmd *cobra.Command, args []string) error { - conn, err := getClient(cmd) - if err != nil { - return err - } - defer conn.Close() - - client := 
proto.NewDaemonServiceClient(conn) - req := &proto.SelectRoutesRequest{ - RouteIDs: args, - } - - if len(args) == 1 && args[0] == "all" { - req.All = true - } - - if _, err := client.DeselectRoutes(cmd.Context(), req); err != nil { - return fmt.Errorf("failed to deselect routes: %v", status.Convert(err).Message()) - } - - cmd.Println("Routes deselected successfully.") - - return nil -} diff --git a/client/cmd/state.go b/client/cmd/state.go new file mode 100644 index 000000000..21a5508f4 --- /dev/null +++ b/client/cmd/state.go @@ -0,0 +1,181 @@ +package cmd + +import ( + "fmt" + + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "google.golang.org/grpc/status" + + "github.com/netbirdio/netbird/client/proto" +) + +var ( + allFlag bool +) + +var stateCmd = &cobra.Command{ + Use: "state", + Short: "Manage daemon state", + Long: "Provides commands for managing and inspecting the Netbird daemon state.", +} + +var stateListCmd = &cobra.Command{ + Use: "list", + Aliases: []string{"ls"}, + Short: "List all stored states", + Long: "Lists all registered states with their status and basic information.", + Example: " netbird state list", + RunE: stateList, +} + +var stateCleanCmd = &cobra.Command{ + Use: "clean [state-name]", + Short: "Clean stored states", + Long: `Clean specific state or all states. The daemon must not be running. +This will perform cleanup operations and remove the state.`, + Example: ` netbird state clean dns_state + netbird state clean --all`, + RunE: stateClean, + PreRunE: func(cmd *cobra.Command, args []string) error { + // Check mutual exclusivity between --all flag and state-name argument + if allFlag && len(args) > 0 { + return fmt.Errorf("cannot specify both --all flag and state name") + } + if !allFlag && len(args) != 1 { + return fmt.Errorf("requires a state name argument or --all flag") + } + return nil + }, +} + +var stateDeleteCmd = &cobra.Command{ + Use: "delete [state-name]", + Short: "Delete stored states", + Long: `Delete specific state or all states from storage. The daemon must not be running. 
+This will remove the state without performing any cleanup operations.`, + Example: ` netbird state delete dns_state + netbird state delete --all`, + RunE: stateDelete, + PreRunE: func(cmd *cobra.Command, args []string) error { + // Check mutual exclusivity between --all flag and state-name argument + if allFlag && len(args) > 0 { + return fmt.Errorf("cannot specify both --all flag and state name") + } + if !allFlag && len(args) != 1 { + return fmt.Errorf("requires a state name argument or --all flag") + } + return nil + }, +} + +func init() { + rootCmd.AddCommand(stateCmd) + stateCmd.AddCommand(stateListCmd, stateCleanCmd, stateDeleteCmd) + + stateCleanCmd.Flags().BoolVarP(&allFlag, "all", "a", false, "Clean all states") + stateDeleteCmd.Flags().BoolVarP(&allFlag, "all", "a", false, "Delete all states") +} + +func stateList(cmd *cobra.Command, _ []string) error { + conn, err := getClient(cmd) + if err != nil { + return err + } + defer func() { + if err := conn.Close(); err != nil { + log.Errorf(errCloseConnection, err) + } + }() + + client := proto.NewDaemonServiceClient(conn) + resp, err := client.ListStates(cmd.Context(), &proto.ListStatesRequest{}) + if err != nil { + return fmt.Errorf("failed to list states: %v", status.Convert(err).Message()) + } + + cmd.Printf("\nStored states:\n\n") + for _, state := range resp.States { + cmd.Printf("- %s\n", state.Name) + } + + return nil +} + +func stateClean(cmd *cobra.Command, args []string) error { + var stateName string + if !allFlag { + stateName = args[0] + } + + conn, err := getClient(cmd) + if err != nil { + return err + } + defer func() { + if err := conn.Close(); err != nil { + log.Errorf(errCloseConnection, err) + } + }() + + client := proto.NewDaemonServiceClient(conn) + resp, err := client.CleanState(cmd.Context(), &proto.CleanStateRequest{ + StateName: stateName, + All: allFlag, + }) + if err != nil { + return fmt.Errorf("failed to clean state: %v", status.Convert(err).Message()) + } + + if resp.CleanedStates == 0 { + cmd.Println("No states were cleaned") + return nil + } + + if allFlag { + cmd.Printf("Successfully cleaned %d states\n", resp.CleanedStates) + } else { + cmd.Printf("Successfully cleaned state %q\n", stateName) + } + + return nil +} + +func stateDelete(cmd *cobra.Command, args []string) error { + var stateName string + if !allFlag { + stateName = args[0] + } + + conn, err := getClient(cmd) + if err != nil { + return err + } + defer func() { + if err := conn.Close(); err != nil { + log.Errorf(errCloseConnection, err) + } + }() + + client := proto.NewDaemonServiceClient(conn) + resp, err := client.DeleteState(cmd.Context(), &proto.DeleteStateRequest{ + StateName: stateName, + All: allFlag, + }) + if err != nil { + return fmt.Errorf("failed to delete state: %v", status.Convert(err).Message()) + } + + if resp.DeletedStates == 0 { + cmd.Println("No states were deleted") + return nil + } + + if allFlag { + cmd.Printf("Successfully deleted %d states\n", resp.DeletedStates) + } else { + cmd.Printf("Successfully deleted state %q\n", stateName) + } + + return nil +} diff --git a/client/cmd/status.go b/client/cmd/status.go index 6db52a677..fa4bff77b 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -40,6 +40,7 @@ type peerStateDetailOutput struct { Latency time.Duration `json:"latency" yaml:"latency"` RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"` Routes []string `json:"routes" yaml:"routes"` + Networks []string `json:"networks" yaml:"networks"` } type peersStateOutput struct { @@ -98,6 
+99,7 @@ type statusOutputOverview struct { RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"` RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"` Routes []string `json:"routes" yaml:"routes"` + Networks []string `json:"networks" yaml:"networks"` NSServerGroups []nsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"` } @@ -282,7 +284,8 @@ func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverv FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(), RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(), RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(), - Routes: pbFullStatus.GetLocalPeerState().GetRoutes(), + Routes: pbFullStatus.GetLocalPeerState().GetNetworks(), + Networks: pbFullStatus.GetLocalPeerState().GetNetworks(), NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()), } @@ -390,7 +393,8 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput { TransferSent: transferSent, Latency: pbPeerState.GetLatency().AsDuration(), RosenpassEnabled: pbPeerState.GetRosenpassEnabled(), - Routes: pbPeerState.GetRoutes(), + Routes: pbPeerState.GetNetworks(), + Networks: pbPeerState.GetNetworks(), } peersStateDetail = append(peersStateDetail, peerState) @@ -491,10 +495,10 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total) } - routes := "-" - if len(overview.Routes) > 0 { - sort.Strings(overview.Routes) - routes = strings.Join(overview.Routes, ", ") + networks := "-" + if len(overview.Networks) > 0 { + sort.Strings(overview.Networks) + networks = strings.Join(overview.Networks, ", ") } var dnsServersString string @@ -556,6 +560,7 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays "Interface type: %s\n"+ "Quantum resistance: %s\n"+ "Routes: %s\n"+ + "Networks: %s\n"+ "Peers count: %s\n", fmt.Sprintf("%s/%s%s", goos, goarch, goarm), overview.DaemonVersion, @@ -568,7 +573,8 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays interfaceIP, interfaceTypeString, rosenpassEnabledStatus, - routes, + networks, + networks, peersCountString, ) return summary @@ -631,10 +637,10 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo } } - routes := "-" - if len(peerState.Routes) > 0 { - sort.Strings(peerState.Routes) - routes = strings.Join(peerState.Routes, ", ") + networks := "-" + if len(peerState.Networks) > 0 { + sort.Strings(peerState.Networks) + networks = strings.Join(peerState.Networks, ", ") } peerString := fmt.Sprintf( @@ -652,6 +658,7 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo " Transfer status (received/sent) %s/%s\n"+ " Quantum resistance: %s\n"+ " Routes: %s\n"+ + " Networks: %s\n"+ " Latency: %s\n", peerState.FQDN, peerState.IP, @@ -668,7 +675,8 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo toIEC(peerState.TransferReceived), toIEC(peerState.TransferSent), rosenpassEnabledStatus, - routes, + networks, + networks, peerState.Latency.String(), ) @@ -810,6 +818,14 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) { peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress) + for i, route := range peer.Networks { + peer.Networks[i] = a.AnonymizeIPString(route) + } + + for i, route := range peer.Networks { + peer.Networks[i] = 
a.AnonymizeRoute(route) + } + for i, route := range peer.Routes { peer.Routes[i] = a.AnonymizeIPString(route) } @@ -850,6 +866,10 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview) } } + for i, route := range overview.Networks { + overview.Networks[i] = a.AnonymizeRoute(route) + } + for i, route := range overview.Routes { overview.Routes[i] = a.AnonymizeRoute(route) } diff --git a/client/cmd/status_test.go b/client/cmd/status_test.go index ca43df8a5..1f1e95726 100644 --- a/client/cmd/status_test.go +++ b/client/cmd/status_test.go @@ -44,7 +44,7 @@ var resp = &proto.StatusResponse{ LastWireguardHandshake: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 2, 0, time.UTC)), BytesRx: 200, BytesTx: 100, - Routes: []string{ + Networks: []string{ "10.1.0.0/24", }, Latency: durationpb.New(time.Duration(10000000)), @@ -93,7 +93,7 @@ var resp = &proto.StatusResponse{ PubKey: "Some-Pub-Key", KernelInterface: true, Fqdn: "some-localhost.awesome-domain.com", - Routes: []string{ + Networks: []string{ "10.10.0.0/24", }, }, @@ -149,6 +149,9 @@ var overview = statusOutputOverview{ Routes: []string{ "10.1.0.0/24", }, + Networks: []string{ + "10.1.0.0/24", + }, Latency: time.Duration(10000000), }, { @@ -230,6 +233,9 @@ var overview = statusOutputOverview{ Routes: []string{ "10.10.0.0/24", }, + Networks: []string{ + "10.10.0.0/24", + }, } func TestConversionFromFullStatusToOutputOverview(t *testing.T) { @@ -295,6 +301,9 @@ func TestParsingToJSON(t *testing.T) { "quantumResistance": false, "routes": [ "10.1.0.0/24" + ], + "networks": [ + "10.1.0.0/24" ] }, { @@ -318,7 +327,8 @@ func TestParsingToJSON(t *testing.T) { "transferSent": 1000, "latency": 10000000, "quantumResistance": false, - "routes": null + "routes": null, + "networks": null } ] }, @@ -359,6 +369,9 @@ func TestParsingToJSON(t *testing.T) { "routes": [ "10.10.0.0/24" ], + "networks": [ + "10.10.0.0/24" + ], "dnsServers": [ { "servers": [ @@ -418,6 +431,8 @@ func TestParsingToYAML(t *testing.T) { quantumResistance: false routes: - 10.1.0.0/24 + networks: + - 10.1.0.0/24 - fqdn: peer-2.awesome-domain.com netbirdIp: 192.168.178.102 publicKey: Pubkey2 @@ -437,6 +452,7 @@ func TestParsingToYAML(t *testing.T) { latency: 10ms quantumResistance: false routes: [] + networks: [] cliVersion: development daemonVersion: 0.14.1 management: @@ -465,6 +481,8 @@ quantumResistance: false quantumResistancePermissive: false routes: - 10.10.0.0/24 +networks: + - 10.10.0.0/24 dnsServers: - servers: - 8.8.8.8:53 @@ -509,6 +527,7 @@ func TestParsingToDetail(t *testing.T) { Transfer status (received/sent) 200 B/100 B Quantum resistance: false Routes: 10.1.0.0/24 + Networks: 10.1.0.0/24 Latency: 10ms peer-2.awesome-domain.com: @@ -525,6 +544,7 @@ func TestParsingToDetail(t *testing.T) { Transfer status (received/sent) 2.0 KiB/1000 B Quantum resistance: false Routes: - + Networks: - Latency: 10ms OS: %s/%s @@ -543,6 +563,7 @@ NetBird IP: 192.168.178.100/16 Interface type: Kernel Quantum resistance: false Routes: 10.10.0.0/24 +Networks: 10.10.0.0/24 Peers count: 2/2 Connected `, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion) @@ -564,6 +585,7 @@ NetBird IP: 192.168.178.100/16 Interface type: Kernel Quantum resistance: false Routes: 10.10.0.0/24 +Networks: 10.10.0.0/24 Peers count: 2/2 Connected ` diff --git a/client/cmd/system.go b/client/cmd/system.go new file mode 100644 index 000000000..f628867a7 --- /dev/null +++ b/client/cmd/system.go @@ -0,0 +1,31 @@ 
+package cmd
+
+// Flag constants for system configuration
+const (
+	disableClientRoutesFlag = "disable-client-routes"
+	disableServerRoutesFlag = "disable-server-routes"
+	disableDNSFlag          = "disable-dns"
+	disableFirewallFlag     = "disable-firewall"
+)
+
+var (
+	disableClientRoutes bool
+	disableServerRoutes bool
+	disableDNS          bool
+	disableFirewall     bool
+)
+
+func init() {
+	// Add system flags to upCmd
+	upCmd.PersistentFlags().BoolVar(&disableClientRoutes, disableClientRoutesFlag, false,
+		"Disable client routes. If enabled, the client won't process client routes received from the management service.")
+
+	upCmd.PersistentFlags().BoolVar(&disableServerRoutes, disableServerRoutesFlag, false,
+		"Disable server routes. If enabled, the client won't act as a router for server routes received from the management service.")
+
+	upCmd.PersistentFlags().BoolVar(&disableDNS, disableDNSFlag, false,
+		"Disable DNS. If enabled, the client won't configure DNS settings.")
+
+	upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false,
+		"Disable firewall configuration. If enabled, the client won't modify firewall rules.")
+}
diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go
index d998f9ea9..e3e644357 100644
--- a/client/cmd/testutil_test.go
+++ b/client/cmd/testutil_test.go
@@ -10,6 +10,8 @@ import (
 	"go.opentelemetry.io/otel"
 
 	"github.com/netbirdio/netbird/management/server/activity"
+	"github.com/netbirdio/netbird/management/server/settings"
+	"github.com/netbirdio/netbird/management/server/store"
 	"github.com/netbirdio/netbird/management/server/telemetry"
 
 	"github.com/netbirdio/netbird/util"
@@ -71,7 +73,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
 		t.Fatal(err)
 	}
 	s := grpc.NewServer()
-	store, cleanUp, err := mgmt.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir())
+	store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir())
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -93,7 +95,7 @@
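// Sketch of the adjusted test wiring, assumed from the hunks around this point and not
// authoritative: the management server constructor now also takes a settings manager
// built from the store, and the test store helper moved from the mgmt package to the
// store package. Variable names here are illustrative:
//
//	testStore, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir())
//	settingsManager := settings.NewManager(testStore)
//	mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsManager,
//		peersUpdateManager, secretsManager, nil, nil)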
} secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) - mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil) + mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil) if err != nil { t.Fatal(err) } diff --git a/client/cmd/up.go b/client/cmd/up.go index 05ecce9e0..cd5521371 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -147,6 +147,19 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error { ic.DNSRouteInterval = &dnsRouteInterval } + if cmd.Flag(disableClientRoutesFlag).Changed { + ic.DisableClientRoutes = &disableClientRoutes + } + if cmd.Flag(disableServerRoutesFlag).Changed { + ic.DisableServerRoutes = &disableServerRoutes + } + if cmd.Flag(disableDNSFlag).Changed { + ic.DisableDNS = &disableDNS + } + if cmd.Flag(disableFirewallFlag).Changed { + ic.DisableFirewall = &disableFirewall + } + providedSetupKey, err := getSetupKey() if err != nil { return err @@ -264,6 +277,19 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error { loginRequest.DnsRouteInterval = durationpb.New(dnsRouteInterval) } + if cmd.Flag(disableClientRoutesFlag).Changed { + loginRequest.DisableClientRoutes = &disableClientRoutes + } + if cmd.Flag(disableServerRoutesFlag).Changed { + loginRequest.DisableServerRoutes = &disableServerRoutes + } + if cmd.Flag(disableDNSFlag).Changed { + loginRequest.DisableDns = &disableDNS + } + if cmd.Flag(disableFirewallFlag).Changed { + loginRequest.DisableFirewall = &disableFirewall + } + var loginErr error var loginResp *proto.LoginResponse diff --git a/client/configs/configs.go b/client/configs/configs.go new file mode 100644 index 000000000..8f9c3ba28 --- /dev/null +++ b/client/configs/configs.go @@ -0,0 +1,24 @@ +package configs + +import ( + "os" + "path/filepath" + "runtime" +) + +var StateDir string + +func init() { + StateDir = os.Getenv("NB_STATE_DIR") + if StateDir != "" { + return + } + switch runtime.GOOS { + case "windows": + StateDir = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird") + case "darwin", "linux": + StateDir = "/var/lib/netbird" + case "freebsd", "openbsd", "netbsd", "dragonfly": + StateDir = "/var/db/netbird" + } +} diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index 1c0527ebc..d774f4538 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -332,18 +332,12 @@ func (m *aclManager) createDefaultChains() error { // The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule. 
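// Annotation for the hunk below (interface and chain names illustrative): all OUTPUT-chain
// seeding is removed, so egress on the WireGuard interface is no longer filtered by this
// iptables backend. INPUT and FORWARD keep the default DROP, the jumps to the netbird rule
// chains, and the conntrack accept built from the `established` arguments that are still
// appended below; the surviving seed entries correspond roughly to:
//
//	-A INPUT   -i wt0 -j DROP
//	-A INPUT   -i wt0 -j <input rules chain>
//	-A INPUT   -i wt0 -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT
//	-A FORWARD -i wt0 -j DROP
//	-A FORWARD -i wt0 -j <routing chain>
//	-A FORWARD -o wt0 -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT
//
// The nftables backend gets the matching output-filter removal below, and the userspace
// filter gains its own conntrack package in this diff for stateful handling of return traffic.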
func (m *aclManager) seedInitialEntries() { - established := getConntrackEstablished() m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"}) m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules}) m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...)) - m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"}) - m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", chainNameOutputRules}) - m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"}) - m.appendToEntries("OUTPUT", append([]string{"-o", m.wgIface.Name()}, established...)) - m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"}) m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName}) m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...)) diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index a59bd2c60..da8e2c08f 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -83,9 +83,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { } // persist early to ensure cleanup of chains - if err := stateManager.PersistState(context.Background()); err != nil { - log.Errorf("failed to persist state: %v", err) - } + go func() { + if err := stateManager.PersistState(context.Background()); err != nil { + log.Errorf("failed to persist state: %v", err) + } + }() return nil } @@ -195,7 +197,7 @@ func (m *Manager) AllowNetbird() error { } _, err := m.AddPeerFiltering( - net.ParseIP("0.0.0.0"), + net.IP{0, 0, 0, 0}, "all", nil, nil, @@ -205,19 +207,9 @@ func (m *Manager) AllowNetbird() error { "", ) if err != nil { - return fmt.Errorf("failed to allow netbird interface traffic: %w", err) + return fmt.Errorf("allow netbird interface traffic: %w", err) } - _, err = m.AddPeerFiltering( - net.ParseIP("0.0.0.0"), - "all", - nil, - nil, - firewall.RuleDirectionOUT, - firewall.ActionAccept, - "", - "", - ) - return err + return nil } // Flush doesn't need to be implemented for this manager diff --git a/client/firewall/iptables/rulestore_linux.go b/client/firewall/iptables/rulestore_linux.go index bfd08bee2..004c512a4 100644 --- a/client/firewall/iptables/rulestore_linux.go +++ b/client/firewall/iptables/rulestore_linux.go @@ -37,6 +37,11 @@ func (s *ipList) UnmarshalJSON(data []byte) error { return err } s.ips = temp.IPs + + if temp.IPs == nil { + temp.IPs = make(map[string]struct{}) + } + return nil } @@ -89,5 +94,10 @@ func (s *ipsetStore) UnmarshalJSON(data []byte) error { return err } s.ipsets = temp.IPSets + + if temp.IPSets == nil { + temp.IPSets = make(map[string]*ipList) + } + return nil } diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index abe890fb9..852cfec8d 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -5,7 +5,6 @@ import ( "encoding/binary" "fmt" "net" - "net/netip" "strconv" "strings" "time" @@ -28,7 +27,6 @@ const ( // filter chains contains the rules that jump to the rules chains chainNameInputFilter = "netbird-acl-input-filter" - chainNameOutputFilter = "netbird-acl-output-filter" chainNameForwardFilter = "netbird-acl-forward-filter" chainNamePrerouting = "netbird-rt-prerouting" @@ -441,18 +439,6 @@ func (m *AclManager) createDefaultChains() (err error) { return err } - 
// netbird-acl-output-filter - // type filter hook output priority filter; policy accept; - chain = m.createFilterChainWithHook(chainNameOutputFilter, nftables.ChainHookOutput) - m.addFwdAllow(chain, expr.MetaKeyOIFNAME) - m.addJumpRule(chain, m.chainOutputRules.Name, expr.MetaKeyOIFNAME) // to netbird-acl-output-rules - m.addDropExpressions(chain, expr.MetaKeyOIFNAME) - err = m.rConn.Flush() - if err != nil { - log.Debugf("failed to create chain (%s): %s", chainNameOutputFilter, err) - return err - } - // netbird-acl-forward-filter chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward) m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd @@ -619,45 +605,6 @@ func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.Met return nil } -func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) { - ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4()) - dstOp := expr.CmpOpNeq - expressions := []expr.Any{ - &expr.Meta{Key: iifname, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - &expr.Payload{ - DestRegister: 2, - Base: expr.PayloadBaseNetworkHeader, - Offset: 16, - Len: 4, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 4, - Xor: []byte{0x0, 0x0, 0x0, 0x0}, - Mask: m.wgIface.Address().Network.Mask, - }, - &expr.Cmp{ - Op: dstOp, - Register: 2, - Data: ip.Unmap().AsSlice(), - }, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, - } - _ = m.rConn.AddRule(&nftables.Rule{ - Table: chain.Table, - Chain: chain, - Exprs: expressions, - }) -} - func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) { expressions := []expr.Any{ &expr.Meta{Key: ifaceKey, Register: 1}, diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index ea8912f27..8e1aa0d80 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -99,9 +99,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { } // persist early - if err := stateManager.PersistState(context.Background()); err != nil { - log.Errorf("failed to persist state: %v", err) - } + go func() { + if err := stateManager.PersistState(context.Background()); err != nil { + log.Errorf("failed to persist state: %v", err) + } + }() return nil } @@ -197,7 +199,7 @@ func (m *Manager) AllowNetbird() error { var chain *nftables.Chain for _, c := range chains { - if c.Table.Name == tableNameFilter && c.Name == chainNameForward { + if c.Table.Name == tableNameFilter && c.Name == chainNameInput { chain = c break } @@ -274,7 +276,7 @@ func (m *Manager) resetNetbirdInputRules() error { func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) { for _, c := range chains { - if c.Table.Name == "filter" && c.Name == "INPUT" { + if c.Table.Name == tableNameFilter && c.Name == chainNameInput { rules, err := m.rConn.GetRules(c.Table, c) if err != nil { log.Errorf("get rules for chain %q: %v", c.Name, err) @@ -349,7 +351,9 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) { Register: 1, Data: ifname(m.wgIface.Name()), }, - &expr.Verdict{}, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, }, UserData: []byte(allowNetbirdInputRuleID), } diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 77f4f0306..33fdc4b3d 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ 
b/client/firewall/nftables/manager_linux_test.go @@ -1,9 +1,11 @@ package nftables import ( + "bytes" "fmt" "net" "net/netip" + "os/exec" "testing" "time" @@ -225,3 +227,105 @@ func TestNFtablesCreatePerformance(t *testing.T) { }) } } + +func runIptablesSave(t *testing.T) (string, string) { + t.Helper() + var stdout, stderr bytes.Buffer + cmd := exec.Command("iptables-save") + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + err := cmd.Run() + require.NoError(t, err, "iptables-save failed to run") + + return stdout.String(), stderr.String() +} + +func verifyIptablesOutput(t *testing.T, stdout, stderr string) { + t.Helper() + // Check for any incompatibility warnings + require.NotContains(t, + stderr, + "incompatible", + "iptables-save produced compatibility warning. Full stderr: %s", + stderr, + ) + + // Verify standard tables are present + expectedTables := []string{ + "*filter", + "*nat", + "*mangle", + } + + for _, table := range expectedTables { + require.Contains(t, + stdout, + table, + "iptables-save output missing expected table: %s\nFull stdout: %s", + table, + stdout, + ) + } +} + +func TestNftablesManagerCompatibilityWithIptables(t *testing.T) { + if check() != NFTABLES { + t.Skip("nftables not supported on this system") + } + + if _, err := exec.LookPath("iptables-save"); err != nil { + t.Skipf("iptables-save not available on this system: %v", err) + } + + // First ensure iptables-nft tables exist by running iptables-save + stdout, stderr := runIptablesSave(t) + verifyIptablesOutput(t, stdout, stderr) + + manager, err := Create(ifaceMock) + require.NoError(t, err, "failed to create manager") + require.NoError(t, manager.Init(nil)) + + t.Cleanup(func() { + err := manager.Reset(nil) + require.NoError(t, err, "failed to reset manager state") + + // Verify iptables output after reset + stdout, stderr := runIptablesSave(t) + verifyIptablesOutput(t, stdout, stderr) + }) + + ip := net.ParseIP("100.96.0.1") + _, err = manager.AddPeerFiltering( + ip, + fw.ProtocolTCP, + nil, + &fw.Port{Values: []int{80}}, + fw.RuleDirectionIN, + fw.ActionAccept, + "", + "test rule", + ) + require.NoError(t, err, "failed to add peer filtering rule") + + _, err = manager.AddRouteFiltering( + []netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")}, + netip.MustParsePrefix("10.1.0.0/24"), + fw.ProtocolTCP, + nil, + &fw.Port{Values: []int{443}}, + fw.ActionAccept, + ) + require.NoError(t, err, "failed to add route filtering rule") + + pair := fw.RouterPair{ + Source: netip.MustParsePrefix("192.168.1.0/24"), + Destination: netip.MustParsePrefix("10.0.0.0/24"), + Masquerade: true, + } + err = manager.AddNatRule(pair) + require.NoError(t, err, "failed to add NAT rule") + + stdout, stderr = runIptablesSave(t) + verifyIptablesOutput(t, stdout, stderr) +} diff --git a/client/firewall/nftables/state.go b/client/firewall/nftables/state.go deleted file mode 100644 index 7027fe987..000000000 --- a/client/firewall/nftables/state.go +++ /dev/null @@ -1 +0,0 @@ -package nftables diff --git a/client/firewall/uspfilter/allow_netbird.go b/client/firewall/uspfilter/allow_netbird.go index cefc81a3c..cc0792255 100644 --- a/client/firewall/uspfilter/allow_netbird.go +++ b/client/firewall/uspfilter/allow_netbird.go @@ -2,7 +2,10 @@ package uspfilter -import "github.com/netbirdio/netbird/client/internal/statemanager" +import ( + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" + "github.com/netbirdio/netbird/client/internal/statemanager" +) // Reset firewall to the default state func (m *Manager) 
Reset(stateManager *statemanager.Manager) error { @@ -12,6 +15,21 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error { m.outgoingRules = make(map[string]RuleSet) m.incomingRules = make(map[string]RuleSet) + if m.udpTracker != nil { + m.udpTracker.Close() + m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) + } + + if m.icmpTracker != nil { + m.icmpTracker.Close() + m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) + } + + if m.tcpTracker != nil { + m.tcpTracker.Close() + m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) + } + if m.nativeFirewall != nil { return m.nativeFirewall.Reset(stateManager) } diff --git a/client/firewall/uspfilter/allow_netbird_windows.go b/client/firewall/uspfilter/allow_netbird_windows.go index d3732301e..0d55d6268 100644 --- a/client/firewall/uspfilter/allow_netbird_windows.go +++ b/client/firewall/uspfilter/allow_netbird_windows.go @@ -7,6 +7,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -26,6 +27,21 @@ func (m *Manager) Reset(*statemanager.Manager) error { m.outgoingRules = make(map[string]RuleSet) m.incomingRules = make(map[string]RuleSet) + if m.udpTracker != nil { + m.udpTracker.Close() + m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) + } + + if m.icmpTracker != nil { + m.icmpTracker.Close() + m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) + } + + if m.tcpTracker != nil { + m.tcpTracker.Close() + m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) + } + if !isWindowsFirewallReachable() { return nil } diff --git a/client/firewall/uspfilter/conntrack/common.go b/client/firewall/uspfilter/conntrack/common.go new file mode 100644 index 000000000..e459bc75a --- /dev/null +++ b/client/firewall/uspfilter/conntrack/common.go @@ -0,0 +1,137 @@ +// common.go +package conntrack + +import ( + "net" + "sync" + "sync/atomic" + "time" +) + +// BaseConnTrack provides common fields and locking for all connection types +type BaseConnTrack struct { + SourceIP net.IP + DestIP net.IP + SourcePort uint16 + DestPort uint16 + lastSeen atomic.Int64 // Unix nano for atomic access + established atomic.Bool +} + +// these small methods will be inlined by the compiler + +// UpdateLastSeen safely updates the last seen timestamp +func (b *BaseConnTrack) UpdateLastSeen() { + b.lastSeen.Store(time.Now().UnixNano()) +} + +// IsEstablished safely checks if connection is established +func (b *BaseConnTrack) IsEstablished() bool { + return b.established.Load() +} + +// SetEstablished safely sets the established state +func (b *BaseConnTrack) SetEstablished(state bool) { + b.established.Store(state) +} + +// GetLastSeen safely gets the last seen timestamp +func (b *BaseConnTrack) GetLastSeen() time.Time { + return time.Unix(0, b.lastSeen.Load()) +} + +// timeoutExceeded checks if the connection has exceeded the given timeout +func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool { + lastSeen := time.Unix(0, b.lastSeen.Load()) + return time.Since(lastSeen) > timeout +} + +// IPAddr is a fixed-size IP address to avoid allocations +type IPAddr [16]byte + +// MakeIPAddr creates an IPAddr from net.IP +func MakeIPAddr(ip net.IP) (addr IPAddr) { + // Optimization: check for v4 first as it's more common + if ip4 := ip.To4(); ip4 != nil { + copy(addr[12:], ip4) + } else { + copy(addr[:], ip.To16()) + } + return addr +} + +// 
ConnKey uniquely identifies a connection +type ConnKey struct { + SrcIP IPAddr + DstIP IPAddr + SrcPort uint16 + DstPort uint16 +} + +// makeConnKey creates a connection key +func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey { + return ConnKey{ + SrcIP: MakeIPAddr(srcIP), + DstIP: MakeIPAddr(dstIP), + SrcPort: srcPort, + DstPort: dstPort, + } +} + +// ValidateIPs checks if IPs match without allocation +func ValidateIPs(connIP IPAddr, pktIP net.IP) bool { + if ip4 := pktIP.To4(); ip4 != nil { + // Compare IPv4 addresses (last 4 bytes) + for i := 0; i < 4; i++ { + if connIP[12+i] != ip4[i] { + return false + } + } + return true + } + // Compare full IPv6 addresses + ip6 := pktIP.To16() + for i := 0; i < 16; i++ { + if connIP[i] != ip6[i] { + return false + } + } + return true +} + +// PreallocatedIPs is a pool of IP byte slices to reduce allocations +type PreallocatedIPs struct { + sync.Pool +} + +// NewPreallocatedIPs creates a new IP pool +func NewPreallocatedIPs() *PreallocatedIPs { + return &PreallocatedIPs{ + Pool: sync.Pool{ + New: func() interface{} { + ip := make(net.IP, 16) + return &ip + }, + }, + } +} + +// Get retrieves an IP from the pool +func (p *PreallocatedIPs) Get() net.IP { + return *p.Pool.Get().(*net.IP) +} + +// Put returns an IP to the pool +func (p *PreallocatedIPs) Put(ip net.IP) { + p.Pool.Put(&ip) +} + +// copyIP copies an IP address efficiently +func copyIP(dst, src net.IP) { + if len(src) == 16 { + copy(dst, src) + } else { + // Handle IPv4 + copy(dst[12:], src.To4()) + } +} diff --git a/client/firewall/uspfilter/conntrack/common_test.go b/client/firewall/uspfilter/conntrack/common_test.go new file mode 100644 index 000000000..72d006def --- /dev/null +++ b/client/firewall/uspfilter/conntrack/common_test.go @@ -0,0 +1,115 @@ +package conntrack + +import ( + "net" + "testing" +) + +func BenchmarkIPOperations(b *testing.B) { + b.Run("MakeIPAddr", func(b *testing.B) { + ip := net.ParseIP("192.168.1.1") + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = MakeIPAddr(ip) + } + }) + + b.Run("ValidateIPs", func(b *testing.B) { + ip1 := net.ParseIP("192.168.1.1") + ip2 := net.ParseIP("192.168.1.1") + addr := MakeIPAddr(ip1) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ValidateIPs(addr, ip2) + } + }) + + b.Run("IPPool", func(b *testing.B) { + pool := NewPreallocatedIPs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ip := pool.Get() + pool.Put(ip) + } + }) + +} +func BenchmarkAtomicOperations(b *testing.B) { + conn := &BaseConnTrack{} + b.Run("UpdateLastSeen", func(b *testing.B) { + for i := 0; i < b.N; i++ { + conn.UpdateLastSeen() + } + }) + + b.Run("IsEstablished", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = conn.IsEstablished() + } + }) + + b.Run("SetEstablished", func(b *testing.B) { + for i := 0; i < b.N; i++ { + conn.SetEstablished(i%2 == 0) + } + }) + + b.Run("GetLastSeen", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = conn.GetLastSeen() + } + }) +} + +// Memory pressure tests +func BenchmarkMemoryPressure(b *testing.B) { + b.Run("TCPHighLoad", func(b *testing.B) { + tracker := NewTCPTracker(DefaultTCPTimeout) + defer tracker.Close() + + // Generate different IPs + srcIPs := make([]net.IP, 100) + dstIPs := make([]net.IP, 100) + for i := 0; i < 100; i++ { + srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256)) + dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + srcIdx := i % len(srcIPs) + dstIdx := (i + 1) % len(dstIPs) + 
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn) + + // Simulate some valid inbound packets + if i%3 == 0 { + tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck) + } + } + }) + + b.Run("UDPHighLoad", func(b *testing.B) { + tracker := NewUDPTracker(DefaultUDPTimeout) + defer tracker.Close() + + // Generate different IPs + srcIPs := make([]net.IP, 100) + dstIPs := make([]net.IP, 100) + for i := 0; i < 100; i++ { + srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256)) + dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + srcIdx := i % len(srcIPs) + dstIdx := (i + 1) % len(dstIPs) + tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80) + + // Simulate some valid inbound packets + if i%3 == 0 { + tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535)) + } + } + }) +} diff --git a/client/firewall/uspfilter/conntrack/icmp.go b/client/firewall/uspfilter/conntrack/icmp.go new file mode 100644 index 000000000..e0a971678 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/icmp.go @@ -0,0 +1,170 @@ +package conntrack + +import ( + "net" + "sync" + "time" + + "github.com/google/gopacket/layers" +) + +const ( + // DefaultICMPTimeout is the default timeout for ICMP connections + DefaultICMPTimeout = 30 * time.Second + // ICMPCleanupInterval is how often we check for stale ICMP connections + ICMPCleanupInterval = 15 * time.Second +) + +// ICMPConnKey uniquely identifies an ICMP connection +type ICMPConnKey struct { + // Supports both IPv4 and IPv6 + SrcIP [16]byte + DstIP [16]byte + Sequence uint16 // ICMP sequence number + ID uint16 // ICMP identifier +} + +// ICMPConnTrack represents an ICMP connection state +type ICMPConnTrack struct { + BaseConnTrack + Sequence uint16 + ID uint16 +} + +// ICMPTracker manages ICMP connection states +type ICMPTracker struct { + connections map[ICMPConnKey]*ICMPConnTrack + timeout time.Duration + cleanupTicker *time.Ticker + mutex sync.RWMutex + done chan struct{} + ipPool *PreallocatedIPs +} + +// NewICMPTracker creates a new ICMP connection tracker +func NewICMPTracker(timeout time.Duration) *ICMPTracker { + if timeout == 0 { + timeout = DefaultICMPTimeout + } + + tracker := &ICMPTracker{ + connections: make(map[ICMPConnKey]*ICMPConnTrack), + timeout: timeout, + cleanupTicker: time.NewTicker(ICMPCleanupInterval), + done: make(chan struct{}), + ipPool: NewPreallocatedIPs(), + } + + go tracker.cleanupRoutine() + return tracker +} + +// TrackOutbound records an outbound ICMP Echo Request +func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) { + key := makeICMPKey(srcIP, dstIP, id, seq) + now := time.Now().UnixNano() + + t.mutex.Lock() + conn, exists := t.connections[key] + if !exists { + srcIPCopy := t.ipPool.Get() + dstIPCopy := t.ipPool.Get() + copyIP(srcIPCopy, srcIP) + copyIP(dstIPCopy, dstIP) + + conn = &ICMPConnTrack{ + BaseConnTrack: BaseConnTrack{ + SourceIP: srcIPCopy, + DestIP: dstIPCopy, + }, + ID: id, + Sequence: seq, + } + conn.lastSeen.Store(now) + conn.established.Store(true) + t.connections[key] = conn + } + t.mutex.Unlock() + + conn.lastSeen.Store(now) +} + +// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request +func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool { + switch icmpType { + case uint8(layers.ICMPv4TypeDestinationUnreachable), + 
uint8(layers.ICMPv4TypeTimeExceeded): + return true + case uint8(layers.ICMPv4TypeEchoReply): + // continue processing + default: + return false + } + + key := makeICMPKey(dstIP, srcIP, id, seq) + + t.mutex.RLock() + conn, exists := t.connections[key] + t.mutex.RUnlock() + + if !exists { + return false + } + + if conn.timeoutExceeded(t.timeout) { + return false + } + + return conn.IsEstablished() && + ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) && + ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) && + conn.ID == id && + conn.Sequence == seq +} + +func (t *ICMPTracker) cleanupRoutine() { + for { + select { + case <-t.cleanupTicker.C: + t.cleanup() + case <-t.done: + return + } + } +} +func (t *ICMPTracker) cleanup() { + t.mutex.Lock() + defer t.mutex.Unlock() + + for key, conn := range t.connections { + if conn.timeoutExceeded(t.timeout) { + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) + delete(t.connections, key) + } + } +} + +// Close stops the cleanup routine and releases resources +func (t *ICMPTracker) Close() { + t.cleanupTicker.Stop() + close(t.done) + + t.mutex.Lock() + for _, conn := range t.connections { + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) + } + t.connections = nil + t.mutex.Unlock() +} + +// makeICMPKey creates an ICMP connection key +func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey { + return ICMPConnKey{ + SrcIP: MakeIPAddr(srcIP), + DstIP: MakeIPAddr(dstIP), + ID: id, + Sequence: seq, + } +} diff --git a/client/firewall/uspfilter/conntrack/icmp_test.go b/client/firewall/uspfilter/conntrack/icmp_test.go new file mode 100644 index 000000000..21176e719 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/icmp_test.go @@ -0,0 +1,39 @@ +package conntrack + +import ( + "net" + "testing" +) + +func BenchmarkICMPTracker(b *testing.B) { + b.Run("TrackOutbound", func(b *testing.B) { + tracker := NewICMPTracker(DefaultICMPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535)) + } + }) + + b.Run("IsValidInbound", func(b *testing.B) { + tracker := NewICMPTracker(DefaultICMPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + // Pre-populate some connections + for i := 0; i < 1000; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0) + } + }) +} diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go new file mode 100644 index 000000000..a7968dc73 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -0,0 +1,352 @@ +package conntrack + +// TODO: Send RST packets for invalid/timed-out connections + +import ( + "net" + "sync" + "time" +) + +const ( + // MSL (Maximum Segment Lifetime) is typically 2 minutes + MSL = 2 * time.Minute + // TimeWaitTimeout (TIME-WAIT) should last 2*MSL + TimeWaitTimeout = 2 * MSL +) + +const ( + TCPSyn uint8 = 0x02 + TCPAck uint8 = 0x10 + TCPFin uint8 = 0x01 + TCPRst uint8 = 0x04 + TCPPush uint8 = 0x08 + TCPUrg uint8 = 0x20 +) + +const ( + // DefaultTCPTimeout is the default timeout for established TCP connections + DefaultTCPTimeout = 3 * time.Hour + // TCPHandshakeTimeout is timeout for TCP handshake completion + TCPHandshakeTimeout = 60 * time.Second + // 
TCPCleanupInterval is how often we check for stale connections + TCPCleanupInterval = 5 * time.Minute +) + +// TCPState represents the state of a TCP connection +type TCPState int + +const ( + TCPStateNew TCPState = iota + TCPStateSynSent + TCPStateSynReceived + TCPStateEstablished + TCPStateFinWait1 + TCPStateFinWait2 + TCPStateClosing + TCPStateTimeWait + TCPStateCloseWait + TCPStateLastAck + TCPStateClosed +) + +// TCPConnKey uniquely identifies a TCP connection +type TCPConnKey struct { + SrcIP [16]byte + DstIP [16]byte + SrcPort uint16 + DstPort uint16 +} + +// TCPConnTrack represents a TCP connection state +type TCPConnTrack struct { + BaseConnTrack + State TCPState + sync.RWMutex +} + +// TCPTracker manages TCP connection states +type TCPTracker struct { + connections map[ConnKey]*TCPConnTrack + mutex sync.RWMutex + cleanupTicker *time.Ticker + done chan struct{} + timeout time.Duration + ipPool *PreallocatedIPs +} + +// NewTCPTracker creates a new TCP connection tracker +func NewTCPTracker(timeout time.Duration) *TCPTracker { + tracker := &TCPTracker{ + connections: make(map[ConnKey]*TCPConnTrack), + cleanupTicker: time.NewTicker(TCPCleanupInterval), + done: make(chan struct{}), + timeout: timeout, + ipPool: NewPreallocatedIPs(), + } + + go tracker.cleanupRoutine() + return tracker +} + +// TrackOutbound processes an outbound TCP packet and updates connection state +func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) { + // Create key before lock + key := makeConnKey(srcIP, dstIP, srcPort, dstPort) + now := time.Now().UnixNano() + + t.mutex.Lock() + conn, exists := t.connections[key] + if !exists { + // Use preallocated IPs + srcIPCopy := t.ipPool.Get() + dstIPCopy := t.ipPool.Get() + copyIP(srcIPCopy, srcIP) + copyIP(dstIPCopy, dstIP) + + conn = &TCPConnTrack{ + BaseConnTrack: BaseConnTrack{ + SourceIP: srcIPCopy, + DestIP: dstIPCopy, + SourcePort: srcPort, + DestPort: dstPort, + }, + State: TCPStateNew, + } + conn.lastSeen.Store(now) + conn.established.Store(false) + t.connections[key] = conn + } + t.mutex.Unlock() + + // Lock individual connection for state update + conn.Lock() + t.updateState(conn, flags, true) + conn.Unlock() + conn.lastSeen.Store(now) +} + +// IsValidInbound checks if an inbound TCP packet matches a tracked connection +func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool { + if !isValidFlagCombination(flags) { + return false + } + + key := makeConnKey(dstIP, srcIP, dstPort, srcPort) + + t.mutex.RLock() + conn, exists := t.connections[key] + t.mutex.RUnlock() + + if !exists { + return false + } + + // Handle RST packets + if flags&TCPRst != 0 { + conn.Lock() + if conn.IsEstablished() || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived { + conn.State = TCPStateClosed + conn.SetEstablished(false) + conn.Unlock() + return true + } + conn.Unlock() + return false + } + + conn.Lock() + t.updateState(conn, flags, false) + conn.UpdateLastSeen() + isEstablished := conn.IsEstablished() + isValidState := t.isValidStateForFlags(conn.State, flags) + conn.Unlock() + + return isEstablished || isValidState +} + +// updateState updates the TCP connection state based on flags +func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound bool) { + // Handle RST flag specially - it always causes transition to closed + if flags&TCPRst != 0 { + conn.State = TCPStateClosed + conn.SetEstablished(false) + return + } + + switch 
conn.State { + case TCPStateNew: + if flags&TCPSyn != 0 && flags&TCPAck == 0 { + conn.State = TCPStateSynSent + } + + case TCPStateSynSent: + if flags&TCPSyn != 0 && flags&TCPAck != 0 { + if isOutbound { + conn.State = TCPStateSynReceived + } else { + // Simultaneous open + conn.State = TCPStateEstablished + conn.SetEstablished(true) + } + } + + case TCPStateSynReceived: + if flags&TCPAck != 0 && flags&TCPSyn == 0 { + conn.State = TCPStateEstablished + conn.SetEstablished(true) + } + + case TCPStateEstablished: + if flags&TCPFin != 0 { + if isOutbound { + conn.State = TCPStateFinWait1 + } else { + conn.State = TCPStateCloseWait + } + conn.SetEstablished(false) + } + + case TCPStateFinWait1: + switch { + case flags&TCPFin != 0 && flags&TCPAck != 0: + // Simultaneous close - both sides sent FIN + conn.State = TCPStateClosing + case flags&TCPFin != 0: + conn.State = TCPStateFinWait2 + case flags&TCPAck != 0: + conn.State = TCPStateFinWait2 + } + + case TCPStateFinWait2: + if flags&TCPFin != 0 { + conn.State = TCPStateTimeWait + } + + case TCPStateClosing: + if flags&TCPAck != 0 { + conn.State = TCPStateTimeWait + // Keep established = false from previous state + } + + case TCPStateCloseWait: + if flags&TCPFin != 0 { + conn.State = TCPStateLastAck + } + + case TCPStateLastAck: + if flags&TCPAck != 0 { + conn.State = TCPStateClosed + } + + case TCPStateTimeWait: + // Stay in TIME-WAIT for 2MSL before transitioning to closed + // This is handled by the cleanup routine + } +} + +// isValidStateForFlags checks if the TCP flags are valid for the current connection state +func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool { + if !isValidFlagCombination(flags) { + return false + } + + switch state { + case TCPStateNew: + return flags&TCPSyn != 0 && flags&TCPAck == 0 + case TCPStateSynSent: + return flags&TCPSyn != 0 && flags&TCPAck != 0 + case TCPStateSynReceived: + return flags&TCPAck != 0 + case TCPStateEstablished: + if flags&TCPRst != 0 { + return true + } + return flags&TCPAck != 0 + case TCPStateFinWait1: + return flags&TCPFin != 0 || flags&TCPAck != 0 + case TCPStateFinWait2: + return flags&TCPFin != 0 || flags&TCPAck != 0 + case TCPStateClosing: + // In CLOSING state, we should accept the final ACK + return flags&TCPAck != 0 + case TCPStateTimeWait: + // In TIME_WAIT, we might see retransmissions + return flags&TCPAck != 0 + case TCPStateCloseWait: + return flags&TCPFin != 0 || flags&TCPAck != 0 + case TCPStateLastAck: + return flags&TCPAck != 0 + case TCPStateClosed: + // Accept retransmitted ACKs in closed state + // This is important because the final ACK might be lost + // and the peer will retransmit their FIN-ACK + return flags&TCPAck != 0 + } + return false +} + +func (t *TCPTracker) cleanupRoutine() { + for { + select { + case <-t.cleanupTicker.C: + t.cleanup() + case <-t.done: + return + } + } +} + +func (t *TCPTracker) cleanup() { + t.mutex.Lock() + defer t.mutex.Unlock() + + for key, conn := range t.connections { + var timeout time.Duration + switch { + case conn.State == TCPStateTimeWait: + timeout = TimeWaitTimeout + case conn.IsEstablished(): + timeout = t.timeout + default: + timeout = TCPHandshakeTimeout + } + + lastSeen := conn.GetLastSeen() + if time.Since(lastSeen) > timeout { + // Return IPs to pool + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) + delete(t.connections, key) + } + } +} + +// Close stops the cleanup routine and releases resources +func (t *TCPTracker) Close() { + t.cleanupTicker.Stop() + close(t.done) + + // Clean up all 
remaining IPs + t.mutex.Lock() + for _, conn := range t.connections { + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) + } + t.connections = nil + t.mutex.Unlock() +} + +func isValidFlagCombination(flags uint8) bool { + // Invalid: SYN+FIN + if flags&TCPSyn != 0 && flags&TCPFin != 0 { + return false + } + + // Invalid: RST with SYN or FIN + if flags&TCPRst != 0 && (flags&TCPSyn != 0 || flags&TCPFin != 0) { + return false + } + + return true +} diff --git a/client/firewall/uspfilter/conntrack/tcp_test.go b/client/firewall/uspfilter/conntrack/tcp_test.go new file mode 100644 index 000000000..6c8f82423 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/tcp_test.go @@ -0,0 +1,308 @@ +package conntrack + +import ( + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestTCPStateMachine(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("100.64.0.1") + dstIP := net.ParseIP("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + + t.Run("Security Tests", func(t *testing.T) { + tests := []struct { + name string + flags uint8 + wantDrop bool + desc string + }{ + { + name: "Block unsolicited SYN-ACK", + flags: TCPSyn | TCPAck, + wantDrop: true, + desc: "Should block SYN-ACK without prior SYN", + }, + { + name: "Block invalid SYN-FIN", + flags: TCPSyn | TCPFin, + wantDrop: true, + desc: "Should block invalid SYN-FIN combination", + }, + { + name: "Block unsolicited RST", + flags: TCPRst, + wantDrop: true, + desc: "Should block RST without connection", + }, + { + name: "Block unsolicited ACK", + flags: TCPAck, + wantDrop: true, + desc: "Should block ACK without connection", + }, + { + name: "Block data without connection", + flags: TCPAck | TCPPush, + wantDrop: true, + desc: "Should block data without established connection", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags) + require.Equal(t, !tt.wantDrop, isValid, tt.desc) + }) + } + }) + + t.Run("Connection Flow Tests", func(t *testing.T) { + tests := []struct { + name string + test func(*testing.T) + desc string + }{ + { + name: "Normal Handshake", + test: func(t *testing.T) { + t.Helper() + + // Send initial SYN + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) + + // Receive SYN-ACK + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) + require.True(t, valid, "SYN-ACK should be allowed") + + // Send ACK + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) + + // Test data transfer + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck) + require.True(t, valid, "Data should be allowed after handshake") + }, + }, + { + name: "Normal Close", + test: func(t *testing.T) { + t.Helper() + + // First establish connection + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + // Send FIN + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck) + + // Receive ACK for FIN + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck) + require.True(t, valid, "ACK for FIN should be allowed") + + // Receive FIN from other side + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck) + require.True(t, valid, "FIN should be allowed") + + // Send final ACK + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) + }, + }, + { + name: "RST During Connection", + test: func(t *testing.T) { + t.Helper() 
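// Expected state transitions for this case, derived from updateState in tcp.go above
// (a reading of the code as shown, not a statement of intent):
//
//	TCPStateNew         --outbound SYN-->      TCPStateSynSent
//	TCPStateSynSent     --inbound SYN+ACK-->   TCPStateEstablished (established = true)
//	TCPStateEstablished --inbound RST-->       TCPStateClosed      (established = false)
//
// The outbound ACK that completes the handshake leaves the state unchanged, and
// IsValidInbound accepts the RST because the entry is established; the closed entry is
// then left for the cleanup routine to expire.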
+ + // First establish connection + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + // Receive RST + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) + require.True(t, valid, "RST should be allowed for established connection") + + // Connection is logically dead but we don't enforce blocking subsequent packets + // The connection will be cleaned up by timeout + }, + }, + { + name: "Simultaneous Close", + test: func(t *testing.T) { + t.Helper() + + // First establish connection + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + // Both sides send FIN+ACK + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck) + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck) + require.True(t, valid, "Simultaneous FIN should be allowed") + + // Both sides send final ACK + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck) + require.True(t, valid, "Final ACKs should be allowed") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Helper() + + tracker = NewTCPTracker(DefaultTCPTimeout) + tt.test(t) + }) + } + }) +} + +func TestRSTHandling(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("100.64.0.1") + dstIP := net.ParseIP("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + + tests := []struct { + name string + setupState func() + sendRST func() + wantValid bool + desc string + }{ + { + name: "RST in established", + setupState: func() { + // Establish connection first + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) + tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) + }, + sendRST: func() { + tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) + }, + wantValid: true, + desc: "Should accept RST for established connection", + }, + { + name: "RST without connection", + setupState: func() {}, + sendRST: func() { + tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) + }, + wantValid: false, + desc: "Should reject RST without connection", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setupState() + tt.sendRST() + + // Verify connection state is as expected + key := makeConnKey(srcIP, dstIP, srcPort, dstPort) + conn := tracker.connections[key] + if tt.wantValid { + require.NotNil(t, conn) + require.Equal(t, TCPStateClosed, conn.State) + require.False(t, conn.IsEstablished()) + } + }) + } +} + +// Helper to establish a TCP connection +func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) { + t.Helper() + + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) + + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) + require.True(t, valid, "SYN-ACK should be allowed") + + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) +} + +func BenchmarkTCPTracker(b *testing.B) { + b.Run("TrackOutbound", func(b *testing.B) { + tracker := NewTCPTracker(DefaultTCPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn) + } + }) + + b.Run("IsValidInbound", func(b *testing.B) { + tracker := 
NewTCPTracker(DefaultTCPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + // Pre-populate some connections + for i := 0; i < 1000; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck) + } + }) + + b.Run("ConcurrentAccess", func(b *testing.B) { + tracker := NewTCPTracker(DefaultTCPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + if i%2 == 0 { + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn) + } else { + tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck) + } + i++ + } + }) + }) +} + +// Benchmark connection cleanup +func BenchmarkCleanup(b *testing.B) { + b.Run("TCPCleanup", func(b *testing.B) { + tracker := NewTCPTracker(100 * time.Millisecond) // Short timeout for testing + defer tracker.Close() + + // Pre-populate with expired connections + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + for i := 0; i < 10000; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn) + } + + // Wait for connections to expire + time.Sleep(200 * time.Millisecond) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.cleanup() + } + }) +} diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go new file mode 100644 index 000000000..a969a4e84 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -0,0 +1,158 @@ +package conntrack + +import ( + "net" + "sync" + "time" +) + +const ( + // DefaultUDPTimeout is the default timeout for UDP connections + DefaultUDPTimeout = 30 * time.Second + // UDPCleanupInterval is how often we check for stale connections + UDPCleanupInterval = 15 * time.Second +) + +// UDPConnTrack represents a UDP connection state +type UDPConnTrack struct { + BaseConnTrack +} + +// UDPTracker manages UDP connection states +type UDPTracker struct { + connections map[ConnKey]*UDPConnTrack + timeout time.Duration + cleanupTicker *time.Ticker + mutex sync.RWMutex + done chan struct{} + ipPool *PreallocatedIPs +} + +// NewUDPTracker creates a new UDP connection tracker +func NewUDPTracker(timeout time.Duration) *UDPTracker { + if timeout == 0 { + timeout = DefaultUDPTimeout + } + + tracker := &UDPTracker{ + connections: make(map[ConnKey]*UDPConnTrack), + timeout: timeout, + cleanupTicker: time.NewTicker(UDPCleanupInterval), + done: make(chan struct{}), + ipPool: NewPreallocatedIPs(), + } + + go tracker.cleanupRoutine() + return tracker +} + +// TrackOutbound records an outbound UDP connection +func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) { + key := makeConnKey(srcIP, dstIP, srcPort, dstPort) + now := time.Now().UnixNano() + + t.mutex.Lock() + conn, exists := t.connections[key] + if !exists { + srcIPCopy := t.ipPool.Get() + dstIPCopy := t.ipPool.Get() + copyIP(srcIPCopy, srcIP) + copyIP(dstIPCopy, dstIP) + + conn = &UDPConnTrack{ + BaseConnTrack: BaseConnTrack{ + SourceIP: srcIPCopy, + DestIP: dstIPCopy, + SourcePort: srcPort, + DestPort: dstPort, + }, + } + conn.lastSeen.Store(now) + conn.established.Store(true) + t.connections[key] = conn + } + t.mutex.Unlock() + + conn.lastSeen.Store(now) +} + +// IsValidInbound checks if an inbound packet matches a tracked connection +func (t *UDPTracker) 
IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool { + key := makeConnKey(dstIP, srcIP, dstPort, srcPort) + + t.mutex.RLock() + conn, exists := t.connections[key] + t.mutex.RUnlock() + + if !exists { + return false + } + + if conn.timeoutExceeded(t.timeout) { + return false + } + + return conn.IsEstablished() && + ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) && + ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) && + conn.DestPort == srcPort && + conn.SourcePort == dstPort +} + +// cleanupRoutine periodically removes stale connections +func (t *UDPTracker) cleanupRoutine() { + for { + select { + case <-t.cleanupTicker.C: + t.cleanup() + case <-t.done: + return + } + } +} + +func (t *UDPTracker) cleanup() { + t.mutex.Lock() + defer t.mutex.Unlock() + + for key, conn := range t.connections { + if conn.timeoutExceeded(t.timeout) { + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) + delete(t.connections, key) + } + } +} + +// Close stops the cleanup routine and releases resources +func (t *UDPTracker) Close() { + t.cleanupTicker.Stop() + close(t.done) + + t.mutex.Lock() + for _, conn := range t.connections { + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) + } + t.connections = nil + t.mutex.Unlock() +} + +// GetConnection safely retrieves a connection state +func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) (*UDPConnTrack, bool) { + t.mutex.RLock() + defer t.mutex.RUnlock() + + key := makeConnKey(srcIP, dstIP, srcPort, dstPort) + conn, exists := t.connections[key] + if !exists { + return nil, false + } + + return conn, true +} + +// Timeout returns the configured timeout duration for the tracker +func (t *UDPTracker) Timeout() time.Duration { + return t.timeout +} diff --git a/client/firewall/uspfilter/conntrack/udp_test.go b/client/firewall/uspfilter/conntrack/udp_test.go new file mode 100644 index 000000000..671721890 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/udp_test.go @@ -0,0 +1,243 @@ +package conntrack + +import ( + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewUDPTracker(t *testing.T) { + tests := []struct { + name string + timeout time.Duration + wantTimeout time.Duration + }{ + { + name: "with custom timeout", + timeout: 1 * time.Minute, + wantTimeout: 1 * time.Minute, + }, + { + name: "with zero timeout uses default", + timeout: 0, + wantTimeout: DefaultUDPTimeout, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tracker := NewUDPTracker(tt.timeout) + assert.NotNil(t, tracker) + assert.Equal(t, tt.wantTimeout, tracker.timeout) + assert.NotNil(t, tracker.connections) + assert.NotNil(t, tracker.cleanupTicker) + assert.NotNil(t, tracker.done) + }) + } +} + +func TestUDPTracker_TrackOutbound(t *testing.T) { + tracker := NewUDPTracker(DefaultUDPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.2") + dstIP := net.ParseIP("192.168.1.3") + srcPort := uint16(12345) + dstPort := uint16(53) + + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort) + + // Verify connection was tracked + key := makeConnKey(srcIP, dstIP, srcPort, dstPort) + conn, exists := tracker.connections[key] + require.True(t, exists) + assert.True(t, conn.SourceIP.Equal(srcIP)) + assert.True(t, conn.DestIP.Equal(dstIP)) + assert.Equal(t, srcPort, conn.SourcePort) + assert.Equal(t, dstPort, conn.DestPort) + assert.True(t, conn.IsEstablished()) + assert.WithinDuration(t, time.Now(), 
conn.GetLastSeen(), 1*time.Second) +} + +func TestUDPTracker_IsValidInbound(t *testing.T) { + tracker := NewUDPTracker(1 * time.Second) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.2") + dstIP := net.ParseIP("192.168.1.3") + srcPort := uint16(12345) + dstPort := uint16(53) + + // Track outbound connection + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort) + + tests := []struct { + name string + srcIP net.IP + dstIP net.IP + srcPort uint16 + dstPort uint16 + sleep time.Duration + want bool + }{ + { + name: "valid inbound response", + srcIP: dstIP, // Original destination is now source + dstIP: srcIP, // Original source is now destination + srcPort: dstPort, // Original destination port is now source + dstPort: srcPort, // Original source port is now destination + sleep: 0, + want: true, + }, + { + name: "invalid source IP", + srcIP: net.ParseIP("192.168.1.4"), + dstIP: srcIP, + srcPort: dstPort, + dstPort: srcPort, + sleep: 0, + want: false, + }, + { + name: "invalid destination IP", + srcIP: dstIP, + dstIP: net.ParseIP("192.168.1.4"), + srcPort: dstPort, + dstPort: srcPort, + sleep: 0, + want: false, + }, + { + name: "invalid source port", + srcIP: dstIP, + dstIP: srcIP, + srcPort: 54321, + dstPort: srcPort, + sleep: 0, + want: false, + }, + { + name: "invalid destination port", + srcIP: dstIP, + dstIP: srcIP, + srcPort: dstPort, + dstPort: 54321, + sleep: 0, + want: false, + }, + { + name: "expired connection", + srcIP: dstIP, + dstIP: srcIP, + srcPort: dstPort, + dstPort: srcPort, + sleep: 2 * time.Second, // Longer than tracker timeout + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.sleep > 0 { + time.Sleep(tt.sleep) + } + got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestUDPTracker_Cleanup(t *testing.T) { + // Use shorter intervals for testing + timeout := 50 * time.Millisecond + cleanupInterval := 25 * time.Millisecond + + // Create tracker with custom cleanup interval + tracker := &UDPTracker{ + connections: make(map[ConnKey]*UDPConnTrack), + timeout: timeout, + cleanupTicker: time.NewTicker(cleanupInterval), + done: make(chan struct{}), + ipPool: NewPreallocatedIPs(), + } + + // Start cleanup routine + go tracker.cleanupRoutine() + + // Add some connections + connections := []struct { + srcIP net.IP + dstIP net.IP + srcPort uint16 + dstPort uint16 + }{ + { + srcIP: net.ParseIP("192.168.1.2"), + dstIP: net.ParseIP("192.168.1.3"), + srcPort: 12345, + dstPort: 53, + }, + { + srcIP: net.ParseIP("192.168.1.4"), + dstIP: net.ParseIP("192.168.1.5"), + srcPort: 12346, + dstPort: 53, + }, + } + + for _, conn := range connections { + tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort) + } + + // Verify initial connections + assert.Len(t, tracker.connections, 2) + + // Wait for connection timeout and cleanup interval + time.Sleep(timeout + 2*cleanupInterval) + + tracker.mutex.RLock() + connCount := len(tracker.connections) + tracker.mutex.RUnlock() + + // Verify connections were cleaned up + assert.Equal(t, 0, connCount, "Expected all connections to be cleaned up") + + // Properly close the tracker + tracker.Close() +} + +func BenchmarkUDPTracker(b *testing.B) { + b.Run("TrackOutbound", func(b *testing.B) { + tracker := NewUDPTracker(DefaultUDPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + 
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80)
+		}
+	})
+
+	b.Run("IsValidInbound", func(b *testing.B) {
+		tracker := NewUDPTracker(DefaultUDPTimeout)
+		defer tracker.Close()
+
+		srcIP := net.ParseIP("192.168.1.1")
+		dstIP := net.ParseIP("192.168.1.2")
+
+		// Pre-populate some connections
+		for i := 0; i < 1000; i++ {
+			tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80)
+		}
+
+		b.ResetTimer()
+		for i := 0; i < b.N; i++ {
+			tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000))
+		}
+	})
+}
diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go
index af5dc6733..ebe04caee 100644
--- a/client/firewall/uspfilter/uspfilter.go
+++ b/client/firewall/uspfilter/uspfilter.go
@@ -4,6 +4,8 @@ import (
 	"fmt"
 	"net"
 	"net/netip"
+	"os"
+	"strconv"
 	"sync"
 
 	"github.com/google/gopacket"
@@ -12,6 +14,7 @@ import (
 	log "github.com/sirupsen/logrus"
 
 	firewall "github.com/netbirdio/netbird/client/firewall/manager"
+	"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
 	"github.com/netbirdio/netbird/client/iface"
 	"github.com/netbirdio/netbird/client/iface/device"
 	"github.com/netbirdio/netbird/client/internal/statemanager"
@@ -19,6 +22,8 @@ import (
 
 const layerTypeAll = 0
 
+const EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
+
 var (
 	errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall")
 )
@@ -42,6 +47,11 @@ type Manager struct {
 	nativeFirewall firewall.Manager
 
 	mutex sync.RWMutex
+
+	stateful    bool
+	udpTracker  *conntrack.UDPTracker
+	icmpTracker *conntrack.ICMPTracker
+	tcpTracker  *conntrack.TCPTracker
 }
 
 // decoder for packages
@@ -73,6 +83,8 @@ func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager
 }
 
 func create(iface IFaceMapper) (*Manager, error) {
+	disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack))
+
 	m := &Manager{
 		decoders: sync.Pool{
 			New: func() any {
@@ -90,6 +102,16 @@ func create(iface IFaceMapper) (*Manager, error) {
 		outgoingRules: make(map[string]RuleSet),
 		incomingRules: make(map[string]RuleSet),
 		wgIface:       iface,
+		stateful:      !disableConntrack,
+	}
+
+	// Only initialize trackers if stateful mode is enabled
+	if disableConntrack {
+		log.Info("conntrack is disabled")
+	} else {
+		m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
+		m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
+		m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
 	}
 
 	if err := iface.SetFilter(m); err != nil {
@@ -239,7 +261,7 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
 // SetLegacyManagement doesn't need to be implemented for this manager
 func (m *Manager) SetLegacyManagement(isLegacy bool) error {
 	if m.nativeFirewall == nil {
-		return errRouteNotSupported
+		return nil
 	}
 	return m.nativeFirewall.SetLegacyManagement(isLegacy)
 }
@@ -249,16 +271,16 @@ func (m *Manager) Flush() error { return nil }
 
 // DropOutgoing filter outgoing packets
 func (m *Manager) DropOutgoing(packetData []byte) bool {
-	return m.dropFilter(packetData, m.outgoingRules, false)
+	return m.processOutgoingHooks(packetData)
 }
 
 // DropIncoming filter incoming packets
 func (m *Manager) DropIncoming(packetData []byte) bool {
-	return m.dropFilter(packetData, m.incomingRules, true)
+	return m.dropFilter(packetData, m.incomingRules)
 }
 
-// dropFilter implements same logic for booth direction of the traffic
+// processOutgoingHooks processes UDP hooks for outgoing packets and tracks TCP/UDP/ICMP
+func (m *Manager) processOutgoingHooks(packetData []byte) bool {
 	m.mutex.RLock()
 	defer m.mutex.RUnlock()
 
@@ -266,61 +288,215 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isInco
 	defer m.decoders.Put(d)
 
 	if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
-		log.Tracef("couldn't decode layer, err: %s", err)
-		return true
+		return false
 	}
 
 	if len(d.decoded) < 2 {
-		log.Tracef("not enough levels in network packet")
+		return false
+	}
+
+	srcIP, dstIP := m.extractIPs(d)
+	if srcIP == nil {
+		return false
+	}
+
+	// Always process UDP hooks
+	if d.decoded[1] == layers.LayerTypeUDP {
+		// Track UDP state only if enabled
+		if m.stateful {
+			m.trackUDPOutbound(d, srcIP, dstIP)
+		}
+		return m.checkUDPHooks(d, dstIP, packetData)
+	}
+
+	// Track other protocols only if stateful mode is enabled
+	if m.stateful {
+		switch d.decoded[1] {
+		case layers.LayerTypeTCP:
+			m.trackTCPOutbound(d, srcIP, dstIP)
+		case layers.LayerTypeICMPv4:
+			m.trackICMPOutbound(d, srcIP, dstIP)
+		}
+	}
+
+	return false
+}
+
+func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) {
+	switch d.decoded[0] {
+	case layers.LayerTypeIPv4:
+		return d.ip4.SrcIP, d.ip4.DstIP
+	case layers.LayerTypeIPv6:
+		return d.ip6.SrcIP, d.ip6.DstIP
+	default:
+		return nil, nil
+	}
+}
+
+func (m *Manager) trackTCPOutbound(d *decoder, srcIP, dstIP net.IP) {
+	flags := getTCPFlags(&d.tcp)
+	m.tcpTracker.TrackOutbound(
+		srcIP,
+		dstIP,
+		uint16(d.tcp.SrcPort),
+		uint16(d.tcp.DstPort),
+		flags,
+	)
+}
+
+func getTCPFlags(tcp *layers.TCP) uint8 {
+	var flags uint8
+	if tcp.SYN {
+		flags |= conntrack.TCPSyn
+	}
+	if tcp.ACK {
+		flags |= conntrack.TCPAck
+	}
+	if tcp.FIN {
+		flags |= conntrack.TCPFin
+	}
+	if tcp.RST {
+		flags |= conntrack.TCPRst
+	}
+	if tcp.PSH {
+		flags |= conntrack.TCPPush
+	}
+	if tcp.URG {
+		flags |= conntrack.TCPUrg
+	}
+	return flags
+}
+
+func (m *Manager) trackUDPOutbound(d *decoder, srcIP, dstIP net.IP) {
+	m.udpTracker.TrackOutbound(
+		srcIP,
+		dstIP,
+		uint16(d.udp.SrcPort),
+		uint16(d.udp.DstPort),
+	)
+}
+
+func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool {
+	for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} {
+		if rules, exists := m.outgoingRules[ipKey]; exists {
+			for _, rule := range rules {
+				if rule.udpHook != nil && (rule.dPort == 0 || rule.dPort == uint16(d.udp.DstPort)) {
+					return rule.udpHook(packetData)
+				}
+			}
+		}
+	}
+	return false
+}
+
+func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) {
+	if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest {
+		m.icmpTracker.TrackOutbound(
+			srcIP,
+			dstIP,
+			d.icmp4.Id,
+			d.icmp4.Seq,
+		)
+	}
+}
+
+// dropFilter implements filtering logic for incoming packets
+func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
+	// TODO: Disable router if --disable-server-router is set
+
+	m.mutex.RLock()
+	defer m.mutex.RUnlock()
+
+	d := m.decoders.Get().(*decoder)
+	defer m.decoders.Put(d)
+
+	if !m.isValidPacket(d, packetData) {
 		return true
 	}
 
-	ipLayer := d.decoded[0]
-
-	switch ipLayer {
-	case layers.LayerTypeIPv4:
-		if !m.wgNetwork.Contains(d.ip4.SrcIP) || !m.wgNetwork.Contains(d.ip4.DstIP) {
-			return false
-		}
-	case layers.LayerTypeIPv6:
-		if !m.wgNetwork.Contains(d.ip6.SrcIP) || !m.wgNetwork.Contains(d.ip6.DstIP) {
-			return false
-		}
-	default:
+	srcIP, dstIP := m.extractIPs(d)
+	if srcIP == nil {
 		log.Errorf("unknown layer: %v", d.decoded[0])
 		return true
 	}
 
-	var ip net.IP
-	switch ipLayer {
-	case
layers.LayerTypeIPv4: - if isIncomingPacket { - ip = d.ip4.SrcIP - } else { - ip = d.ip4.DstIP - } - case layers.LayerTypeIPv6: - if isIncomingPacket { - ip = d.ip6.SrcIP - } else { - ip = d.ip6.DstIP - } + if !m.isWireguardTraffic(srcIP, dstIP) { + return false } - filter, ok := validateRule(ip, packetData, rules[ip.String()], d) - if ok { - return filter + // Check connection state only if enabled + if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) { + return false } - filter, ok = validateRule(ip, packetData, rules["0.0.0.0"], d) - if ok { - return filter + + return m.applyRules(srcIP, packetData, rules, d) +} + +func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool { + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + log.Tracef("couldn't decode layer, err: %s", err) + return false } - filter, ok = validateRule(ip, packetData, rules["::"], d) - if ok { + + if len(d.decoded) < 2 { + log.Tracef("not enough levels in network packet") + return false + } + return true +} + +func (m *Manager) isWireguardTraffic(srcIP, dstIP net.IP) bool { + return m.wgNetwork.Contains(srcIP) && m.wgNetwork.Contains(dstIP) +} + +func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool { + switch d.decoded[1] { + case layers.LayerTypeTCP: + return m.tcpTracker.IsValidInbound( + srcIP, + dstIP, + uint16(d.tcp.SrcPort), + uint16(d.tcp.DstPort), + getTCPFlags(&d.tcp), + ) + + case layers.LayerTypeUDP: + return m.udpTracker.IsValidInbound( + srcIP, + dstIP, + uint16(d.udp.SrcPort), + uint16(d.udp.DstPort), + ) + + case layers.LayerTypeICMPv4: + return m.icmpTracker.IsValidInbound( + srcIP, + dstIP, + d.icmp4.Id, + d.icmp4.Seq, + d.icmp4.TypeCode.Type(), + ) + + // TODO: ICMPv6 + } + + return false +} + +func (m *Manager) applyRules(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool { + if filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok { return filter } - // default policy is DROP ALL + if filter, ok := validateRule(srcIP, packetData, rules["0.0.0.0"], d); ok { + return filter + } + + if filter, ok := validateRule(srcIP, packetData, rules["::"], d); ok { + return filter + } + + // Default policy: DROP ALL return true } diff --git a/client/firewall/uspfilter/uspfilter_bench_test.go b/client/firewall/uspfilter/uspfilter_bench_test.go new file mode 100644 index 000000000..3c661e71c --- /dev/null +++ b/client/firewall/uspfilter/uspfilter_bench_test.go @@ -0,0 +1,998 @@ +package uspfilter + +import ( + "fmt" + "math/rand" + "net" + "os" + "strings" + "testing" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/stretchr/testify/require" + + fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" + "github.com/netbirdio/netbird/client/iface/device" +) + +// generateRandomIPs generates n different random IPs in the 100.64.0.0/10 range +func generateRandomIPs(n int) []net.IP { + ips := make([]net.IP, n) + seen := make(map[string]bool) + + for i := 0; i < n; { + ip := make(net.IP, 4) + ip[0] = 100 + ip[1] = byte(64 + rand.Intn(63)) // 64-126 + ip[2] = byte(rand.Intn(256)) + ip[3] = byte(1 + rand.Intn(254)) // avoid .0 and .255 + + key := ip.String() + if !seen[key] { + ips[i] = ip + seen[key] = true + i++ + } + } + return ips +} + +func generatePacket(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort uint16, protocol layers.IPProtocol) []byte { + b.Helper() + + ipv4 := &layers.IPv4{ + TTL: 64, + 
Version: 4, + SrcIP: srcIP, + DstIP: dstIP, + Protocol: protocol, + } + + var transportLayer gopacket.SerializableLayer + switch protocol { + case layers.IPProtocolTCP: + tcp := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + SYN: true, + } + require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv4)) + transportLayer = tcp + case layers.IPProtocolUDP: + udp := &layers.UDP{ + SrcPort: layers.UDPPort(srcPort), + DstPort: layers.UDPPort(dstPort), + } + require.NoError(b, udp.SetNetworkLayerForChecksum(ipv4)) + transportLayer = udp + } + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} + err := gopacket.SerializeLayers(buf, opts, ipv4, transportLayer, gopacket.Payload("test")) + require.NoError(b, err) + return buf.Bytes() +} + +// BenchmarkCoreFiltering focuses on the essential performance comparisons between +// stateful and stateless filtering approaches +func BenchmarkCoreFiltering(b *testing.B) { + scenarios := []struct { + name string + stateful bool + setupFunc func(*Manager) + desc string + }{ + { + name: "stateless_single_allow_all", + stateful: false, + setupFunc: func(m *Manager) { + // Single rule allowing all traffic + _, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, + fw.RuleDirectionIN, fw.ActionAccept, "", "allow all") + require.NoError(b, err) + }, + desc: "Baseline: Single 'allow all' rule without connection tracking", + }, + { + name: "stateful_no_rules", + stateful: true, + setupFunc: func(m *Manager) { + // No explicit rules - rely purely on connection tracking + }, + desc: "Pure connection tracking without any rules", + }, + { + name: "stateless_explicit_return", + stateful: false, + setupFunc: func(m *Manager) { + // Add explicit rules matching return traffic pattern + for i := 0; i < 1000; i++ { // Simulate realistic ruleset size + ip := generateRandomIPs(1)[0] + _, err := m.AddPeerFiltering(ip, fw.ProtocolTCP, + &fw.Port{Values: []int{1024 + i}}, + &fw.Port{Values: []int{80}}, + fw.RuleDirectionIN, fw.ActionAccept, "", "explicit return") + require.NoError(b, err) + } + }, + desc: "Explicit rules matching return traffic patterns without state", + }, + { + name: "stateful_with_established", + stateful: true, + setupFunc: func(m *Manager) { + // Add some basic rules but rely on state for established connections + _, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil, + fw.RuleDirectionIN, fw.ActionDrop, "", "default drop") + require.NoError(b, err) + }, + desc: "Connection tracking with established connections", + }, + } + + // Test both TCP and UDP + protocols := []struct { + name string + proto layers.IPProtocol + }{ + {"TCP", layers.IPProtocolTCP}, + {"UDP", layers.IPProtocolUDP}, + } + + for _, sc := range scenarios { + for _, proto := range protocols { + b.Run(fmt.Sprintf("%s_%s", sc.name, proto.name), func(b *testing.B) { + // Configure stateful/stateless mode + if !sc.stateful { + require.NoError(b, os.Setenv("NB_DISABLE_CONNTRACK", "1")) + } else { + require.NoError(b, os.Setenv("NB_CONNTRACK_TIMEOUT", "1m")) + } + + // Create manager and basic setup + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + defer b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + manager.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + + // Apply scenario-specific setup + sc.setupFunc(manager) + + // Generate test 
packets + srcIP := generateRandomIPs(1)[0] + dstIP := generateRandomIPs(1)[0] + srcPort := uint16(1024 + b.N%60000) + dstPort := uint16(80) + + outbound := generatePacket(b, srcIP, dstIP, srcPort, dstPort, proto.proto) + inbound := generatePacket(b, dstIP, srcIP, dstPort, srcPort, proto.proto) + + // For stateful scenarios, establish the connection + if sc.stateful { + manager.processOutgoingHooks(outbound) + } + + // Measure inbound packet processing + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.dropFilter(inbound, manager.incomingRules) + } + }) + } + } +} + +// BenchmarkStateScaling measures how performance scales with connection table size +func BenchmarkStateScaling(b *testing.B) { + connCounts := []int{100, 1000, 10000, 100000} + + for _, count := range connCounts { + b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) { + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + manager.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + + // Pre-populate connection table + srcIPs := generateRandomIPs(count) + dstIPs := generateRandomIPs(count) + for i := 0; i < count; i++ { + outbound := generatePacket(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, layers.IPProtocolTCP) + manager.processOutgoingHooks(outbound) + } + + // Test packet + testOut := generatePacket(b, srcIPs[0], dstIPs[0], 1024, 80, layers.IPProtocolTCP) + testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP) + + // First establish our test connection + manager.processOutgoingHooks(testOut) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.dropFilter(testIn, manager.incomingRules) + } + }) + } +} + +// BenchmarkEstablishmentOverhead measures the overhead of connection establishment +func BenchmarkEstablishmentOverhead(b *testing.B) { + scenarios := []struct { + name string + established bool + }{ + {"established", true}, + {"new", false}, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + manager.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + + srcIP := generateRandomIPs(1)[0] + dstIP := generateRandomIPs(1)[0] + outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP) + inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP) + + if sc.established { + manager.processOutgoingHooks(outbound) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.dropFilter(inbound, manager.incomingRules) + } + }) + } +} + +// BenchmarkRoutedNetworkReturn compares approaches for handling routed network return traffic +func BenchmarkRoutedNetworkReturn(b *testing.B) { + scenarios := []struct { + name string + proto layers.IPProtocol + state string // "new", "established", "post_handshake" (TCP only) + setupFunc func(*Manager) + genPackets func(net.IP, net.IP) ([]byte, []byte) // generates appropriate packets for the scenario + desc string + }{ + { + name: "allow_non_wg_tcp_new", + proto: layers.IPProtocolTCP, + state: "new", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + b.Setenv("NB_DISABLE_CONNTRACK", "1") + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) 
{ + return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP), + generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP) + }, + desc: "Allow non-WG: TCP new connection", + }, + { + name: "allow_non_wg_tcp_established", + proto: layers.IPProtocolTCP, + state: "established", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + b.Setenv("NB_DISABLE_CONNTRACK", "1") + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + // Generate packets with ACK flag for established connection + return generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)), + generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPAck)) + }, + desc: "Allow non-WG: TCP established connection", + }, + { + name: "allow_non_wg_udp_new", + proto: layers.IPProtocolUDP, + state: "new", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + b.Setenv("NB_DISABLE_CONNTRACK", "1") + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP), + generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP) + }, + desc: "Allow non-WG: UDP new connection", + }, + { + name: "allow_non_wg_udp_established", + proto: layers.IPProtocolUDP, + state: "established", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + b.Setenv("NB_DISABLE_CONNTRACK", "1") + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP), + generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP) + }, + desc: "Allow non-WG: UDP established connection", + }, + { + name: "stateful_tcp_new", + proto: layers.IPProtocolTCP, + state: "new", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("0.0.0.0"), + Mask: net.CIDRMask(0, 32), + } + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP), + generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP) + }, + desc: "Stateful: TCP new connection", + }, + { + name: "stateful_tcp_established", + proto: layers.IPProtocolTCP, + state: "established", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("0.0.0.0"), + Mask: net.CIDRMask(0, 32), + } + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + // Generate established TCP packets (ACK flag) + return generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)), + generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPAck)) + }, + desc: "Stateful: TCP established connection", + }, + { + name: "stateful_tcp_post_handshake", + proto: layers.IPProtocolTCP, + state: "post_handshake", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("0.0.0.0"), + Mask: net.CIDRMask(0, 32), + } + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + // Generate packets with PSH+ACK flags for data transfer + return generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPPush|conntrack.TCPAck)), + generateTCPPacketWithFlags(b, dstIP, srcIP, 
80, 1024, uint16(conntrack.TCPPush|conntrack.TCPAck)) + }, + desc: "Stateful: TCP post-handshake data transfer", + }, + { + name: "stateful_udp_new", + proto: layers.IPProtocolUDP, + state: "new", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("0.0.0.0"), + Mask: net.CIDRMask(0, 32), + } + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP), + generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP) + }, + desc: "Stateful: UDP new connection", + }, + { + name: "stateful_udp_established", + proto: layers.IPProtocolUDP, + state: "established", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("0.0.0.0"), + Mask: net.CIDRMask(0, 32), + } + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP), + generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP) + }, + desc: "Stateful: UDP established connection", + }, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + // Setup scenario + sc.setupFunc(manager) + + // Use IPs outside WG range for routed network simulation + srcIP := net.ParseIP("192.168.1.2") + dstIP := net.ParseIP("8.8.8.8") + outbound, inbound := sc.genPackets(srcIP, dstIP) + + // For stateful cases and established connections + if !strings.Contains(sc.name, "allow_non_wg") || + (strings.Contains(sc.state, "established") || sc.state == "post_handshake") { + manager.processOutgoingHooks(outbound) + + // For TCP post-handshake, simulate full handshake + if sc.state == "post_handshake" { + // SYN + syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn)) + manager.processOutgoingHooks(syn) + // SYN-ACK + synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck)) + manager.dropFilter(synack, manager.incomingRules) + // ACK + ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)) + manager.processOutgoingHooks(ack) + } + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.dropFilter(inbound, manager.incomingRules) + } + }) + } +} + +var scenarios = []struct { + name string + stateful bool // Whether conntrack is enabled + rules bool // Whether to add return traffic rules + routed bool // Whether to test routed network traffic + connCount int // Number of concurrent connections + desc string +}{ + { + name: "stateless_with_rules_100conns", + stateful: false, + rules: true, + routed: false, + connCount: 100, + desc: "Pure stateless with return traffic rules, 100 conns", + }, + { + name: "stateless_with_rules_1000conns", + stateful: false, + rules: true, + routed: false, + connCount: 1000, + desc: "Pure stateless with return traffic rules, 1000 conns", + }, + { + name: "stateful_no_rules_100conns", + stateful: true, + rules: false, + routed: false, + connCount: 100, + desc: "Pure stateful tracking without rules, 100 conns", + }, + { + name: "stateful_no_rules_1000conns", + stateful: true, + rules: false, + routed: false, + connCount: 1000, + desc: "Pure stateful tracking without rules, 1000 conns", + }, + { + name: 
"stateful_with_rules_100conns", + stateful: true, + rules: true, + routed: false, + connCount: 100, + desc: "Combined stateful + rules (current implementation), 100 conns", + }, + { + name: "stateful_with_rules_1000conns", + stateful: true, + rules: true, + routed: false, + connCount: 1000, + desc: "Combined stateful + rules (current implementation), 1000 conns", + }, + { + name: "routed_network_100conns", + stateful: true, + rules: false, + routed: true, + connCount: 100, + desc: "Routed network traffic (non-WG), 100 conns", + }, + { + name: "routed_network_1000conns", + stateful: true, + rules: false, + routed: true, + connCount: 1000, + desc: "Routed network traffic (non-WG), 1000 conns", + }, +} + +// BenchmarkLongLivedConnections tests performance with realistic TCP traffic patterns +func BenchmarkLongLivedConnections(b *testing.B) { + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + // Configure stateful/stateless mode + if !sc.stateful { + b.Setenv("NB_DISABLE_CONNTRACK", "1") + } else { + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + } + + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + defer b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + manager.SetNetwork(&net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + }) + + // Setup initial state based on scenario + if sc.rules { + // Single rule to allow all return traffic from port 80 + _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, + &fw.Port{Values: []int{80}}, + nil, + fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") + require.NoError(b, err) + } + + // Generate IPs for connections + srcIPs := make([]net.IP, sc.connCount) + dstIPs := make([]net.IP, sc.connCount) + + for i := 0; i < sc.connCount; i++ { + if sc.routed { + srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4() + dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4() + } else { + srcIPs[i] = generateRandomIPs(1)[0] + dstIPs[i] = generateRandomIPs(1)[0] + } + } + + // Create established connections + for i := 0; i < sc.connCount; i++ { + // Initial SYN + syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPSyn)) + manager.processOutgoingHooks(syn) + + // SYN-ACK + synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) + manager.dropFilter(synack, manager.incomingRules) + + // ACK + ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPAck)) + manager.processOutgoingHooks(ack) + } + + // Prepare test packets simulating bidirectional traffic + inPackets := make([][]byte, sc.connCount) + outPackets := make([][]byte, sc.connCount) + for i := 0; i < sc.connCount; i++ { + // Server -> Client (inbound) + inPackets[i] = generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck)) + // Client -> Server (outbound) + outPackets[i] = generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + connIdx := i % sc.connCount + + // Simulate bidirectional traffic + // First outbound data + manager.processOutgoingHooks(outPackets[connIdx]) + // Then inbound response - this is what we're actually measuring + manager.dropFilter(inPackets[connIdx], manager.incomingRules) + } + 
}) + } +} + +// BenchmarkShortLivedConnections tests performance with many short-lived connections +func BenchmarkShortLivedConnections(b *testing.B) { + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + // Configure stateful/stateless mode + if !sc.stateful { + b.Setenv("NB_DISABLE_CONNTRACK", "1") + } else { + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + } + + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + defer b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + manager.SetNetwork(&net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + }) + + // Setup initial state based on scenario + if sc.rules { + // Single rule to allow all return traffic from port 80 + _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, + &fw.Port{Values: []int{80}}, + nil, + fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") + require.NoError(b, err) + } + + // Generate IPs for connections + srcIPs := make([]net.IP, sc.connCount) + dstIPs := make([]net.IP, sc.connCount) + + for i := 0; i < sc.connCount; i++ { + if sc.routed { + srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4() + dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4() + } else { + srcIPs[i] = generateRandomIPs(1)[0] + dstIPs[i] = generateRandomIPs(1)[0] + } + } + + // Create packet patterns for a complete HTTP-like short connection: + // 1. Initial handshake (SYN, SYN-ACK, ACK) + // 2. HTTP Request (PSH+ACK from client) + // 3. HTTP Response (PSH+ACK from server) + // 4. Connection teardown (FIN+ACK, ACK, FIN+ACK, ACK) + type connPackets struct { + syn []byte + synAck []byte + ack []byte + request []byte + response []byte + finClient []byte + ackServer []byte + finServer []byte + ackClient []byte + } + + // Generate all possible connection patterns + patterns := make([]connPackets, sc.connCount) + for i := 0; i < sc.connCount; i++ { + patterns[i] = connPackets{ + // Handshake + syn: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPSyn)), + synAck: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)), + ack: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPAck)), + + // Data transfer + request: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck)), + response: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck)), + + // Connection teardown + finClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPFin|conntrack.TCPAck)), + ackServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPAck)), + finServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPFin|conntrack.TCPAck)), + ackClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPAck)), + } + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Each iteration creates a new short-lived connection + connIdx := i % sc.connCount + p := patterns[connIdx] + + // Connection establishment + manager.processOutgoingHooks(p.syn) + manager.dropFilter(p.synAck, manager.incomingRules) + manager.processOutgoingHooks(p.ack) + + // Data transfer + 
manager.processOutgoingHooks(p.request) + manager.dropFilter(p.response, manager.incomingRules) + + // Connection teardown + manager.processOutgoingHooks(p.finClient) + manager.dropFilter(p.ackServer, manager.incomingRules) + manager.dropFilter(p.finServer, manager.incomingRules) + manager.processOutgoingHooks(p.ackClient) + } + }) + } +} + +// BenchmarkParallelLongLivedConnections tests performance with realistic TCP traffic patterns in parallel +func BenchmarkParallelLongLivedConnections(b *testing.B) { + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + // Configure stateful/stateless mode + if !sc.stateful { + b.Setenv("NB_DISABLE_CONNTRACK", "1") + } else { + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + } + + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + defer b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + manager.SetNetwork(&net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + }) + + // Setup initial state based on scenario + if sc.rules { + _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, + &fw.Port{Values: []int{80}}, + nil, + fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") + require.NoError(b, err) + } + + // Generate IPs for connections + srcIPs := make([]net.IP, sc.connCount) + dstIPs := make([]net.IP, sc.connCount) + + for i := 0; i < sc.connCount; i++ { + if sc.routed { + srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4() + dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4() + } else { + srcIPs[i] = generateRandomIPs(1)[0] + dstIPs[i] = generateRandomIPs(1)[0] + } + } + + // Create established connections + for i := 0; i < sc.connCount; i++ { + syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPSyn)) + manager.processOutgoingHooks(syn) + + synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) + manager.dropFilter(synack, manager.incomingRules) + + ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPAck)) + manager.processOutgoingHooks(ack) + } + + // Pre-generate test packets + inPackets := make([][]byte, sc.connCount) + outPackets := make([][]byte, sc.connCount) + for i := 0; i < sc.connCount; i++ { + inPackets[i] = generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck)) + outPackets[i] = generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck)) + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + // Each goroutine gets its own counter to distribute load + counter := 0 + for pb.Next() { + connIdx := counter % sc.connCount + counter++ + + // Simulate bidirectional traffic + manager.processOutgoingHooks(outPackets[connIdx]) + manager.dropFilter(inPackets[connIdx], manager.incomingRules) + } + }) + }) + } +} + +// BenchmarkParallelShortLivedConnections tests performance with many short-lived connections in parallel +func BenchmarkParallelShortLivedConnections(b *testing.B) { + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + // Configure stateful/stateless mode + if !sc.stateful { + b.Setenv("NB_DISABLE_CONNTRACK", "1") + } else { + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + } + + manager, _ := Create(&IFaceMock{ + SetFilterFunc: 
func(device.PacketFilter) error { return nil }, + }) + defer b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + manager.SetNetwork(&net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + }) + + if sc.rules { + _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, + &fw.Port{Values: []int{80}}, + nil, + fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") + require.NoError(b, err) + } + + // Generate IPs and pre-generate all packet patterns + srcIPs := make([]net.IP, sc.connCount) + dstIPs := make([]net.IP, sc.connCount) + for i := 0; i < sc.connCount; i++ { + if sc.routed { + srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4() + dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4() + } else { + srcIPs[i] = generateRandomIPs(1)[0] + dstIPs[i] = generateRandomIPs(1)[0] + } + } + + type connPackets struct { + syn []byte + synAck []byte + ack []byte + request []byte + response []byte + finClient []byte + ackServer []byte + finServer []byte + ackClient []byte + } + + patterns := make([]connPackets, sc.connCount) + for i := 0; i < sc.connCount; i++ { + patterns[i] = connPackets{ + syn: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPSyn)), + synAck: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)), + ack: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPAck)), + request: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck)), + response: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck)), + finClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPFin|conntrack.TCPAck)), + ackServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPAck)), + finServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPFin|conntrack.TCPAck)), + ackClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPAck)), + } + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + counter := 0 + for pb.Next() { + connIdx := counter % sc.connCount + counter++ + p := patterns[connIdx] + + // Full connection lifecycle + manager.processOutgoingHooks(p.syn) + manager.dropFilter(p.synAck, manager.incomingRules) + manager.processOutgoingHooks(p.ack) + + manager.processOutgoingHooks(p.request) + manager.dropFilter(p.response, manager.incomingRules) + + manager.processOutgoingHooks(p.finClient) + manager.dropFilter(p.ackServer, manager.incomingRules) + manager.dropFilter(p.finServer, manager.incomingRules) + manager.processOutgoingHooks(p.ackClient) + } + }) + }) + } +} + +// generateTCPPacketWithFlags creates a TCP packet with specific flags +func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort, flags uint16) []byte { + b.Helper() + + ipv4 := &layers.IPv4{ + TTL: 64, + Version: 4, + SrcIP: srcIP, + DstIP: dstIP, + Protocol: layers.IPProtocolTCP, + } + + tcp := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + } + + // Set TCP flags + tcp.SYN = (flags & uint16(conntrack.TCPSyn)) != 0 + tcp.ACK = (flags & uint16(conntrack.TCPAck)) != 0 + tcp.PSH = (flags & uint16(conntrack.TCPPush)) != 0 + tcp.RST = (flags & 
uint16(conntrack.TCPRst)) != 0 + tcp.FIN = (flags & uint16(conntrack.TCPFin)) != 0 + + require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv4)) + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} + require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test"))) + return buf.Bytes() +} diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index d7c93cb7f..d3563e6f2 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -3,6 +3,7 @@ package uspfilter import ( "fmt" "net" + "sync" "testing" "time" @@ -11,6 +12,7 @@ import ( "github.com/stretchr/testify/require" fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" ) @@ -185,10 +187,10 @@ func TestAddUDPPacketHook(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - manager := &Manager{ - incomingRules: map[string]RuleSet{}, - outgoingRules: map[string]RuleSet{}, - } + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + require.NoError(t, err) manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) @@ -313,7 +315,7 @@ func TestNotMatchByIP(t *testing.T) { t.Errorf("failed to set network layer for checksum: %v", err) return } - payload := gopacket.Payload([]byte("test")) + payload := gopacket.Payload("test") buf := gopacket.NewSerializeBuffer() opts := gopacket.SerializeOptions{ @@ -325,7 +327,7 @@ func TestNotMatchByIP(t *testing.T) { return } - if m.dropFilter(buf.Bytes(), m.outgoingRules, false) { + if m.dropFilter(buf.Bytes(), m.outgoingRules) { t.Errorf("expected packet to be accepted") return } @@ -348,6 +350,9 @@ func TestRemovePacketHook(t *testing.T) { if err != nil { t.Fatalf("Failed to create Manager: %s", err) } + defer func() { + require.NoError(t, manager.Reset(nil)) + }() // Add a UDP packet hook hookFunc := func(data []byte) bool { return true } @@ -384,6 +389,88 @@ func TestRemovePacketHook(t *testing.T) { } } +func TestProcessOutgoingHooks(t *testing.T) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + require.NoError(t, err) + + manager.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.10.0.0"), + Mask: net.CIDRMask(16, 32), + } + manager.udpTracker.Close() + manager.udpTracker = conntrack.NewUDPTracker(100 * time.Millisecond) + defer func() { + require.NoError(t, manager.Reset(nil)) + }() + + manager.decoders = sync.Pool{ + New: func() any { + d := &decoder{ + decoded: []gopacket.LayerType{}, + } + d.parser = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv4, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser.IgnoreUnsupported = true + return d + }, + } + + hookCalled := false + hookID := manager.AddUDPPacketHook( + false, + net.ParseIP("100.10.0.100"), + 53, + func([]byte) bool { + hookCalled = true + return true + }, + ) + require.NotEmpty(t, hookID) + + // Create test UDP packet + ipv4 := &layers.IPv4{ + TTL: 64, + Version: 4, + SrcIP: net.ParseIP("100.10.0.1"), + DstIP: net.ParseIP("100.10.0.100"), + Protocol: layers.IPProtocolUDP, + } + udp := &layers.UDP{ + SrcPort: 51334, + DstPort: 53, + } + + err = udp.SetNetworkLayerForChecksum(ipv4) + require.NoError(t, err) + payload := 
gopacket.Payload("test") + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + err = gopacket.SerializeLayers(buf, opts, ipv4, udp, payload) + require.NoError(t, err) + + // Test hook gets called + result := manager.processOutgoingHooks(buf.Bytes()) + require.True(t, result) + require.True(t, hookCalled) + + // Test non-UDP packet is ignored + ipv4.Protocol = layers.IPProtocolTCP + buf = gopacket.NewSerializeBuffer() + err = gopacket.SerializeLayers(buf, opts, ipv4) + require.NoError(t, err) + + result = manager.processOutgoingHooks(buf.Bytes()) + require.False(t, result) +} + func TestUSPFilterCreatePerformance(t *testing.T) { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { @@ -418,3 +505,213 @@ func TestUSPFilterCreatePerformance(t *testing.T) { }) } } + +func TestStatefulFirewall_UDPTracking(t *testing.T) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + require.NoError(t, err) + + manager.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.10.0.0"), + Mask: net.CIDRMask(16, 32), + } + + manager.udpTracker.Close() // Close the existing tracker + manager.udpTracker = conntrack.NewUDPTracker(200 * time.Millisecond) + manager.decoders = sync.Pool{ + New: func() any { + d := &decoder{ + decoded: []gopacket.LayerType{}, + } + d.parser = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv4, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser.IgnoreUnsupported = true + return d + }, + } + defer func() { + require.NoError(t, manager.Reset(nil)) + }() + + // Set up packet parameters + srcIP := net.ParseIP("100.10.0.1") + dstIP := net.ParseIP("100.10.0.100") + srcPort := uint16(51334) + dstPort := uint16(53) + + // Create outbound packet + outboundIPv4 := &layers.IPv4{ + TTL: 64, + Version: 4, + SrcIP: srcIP, + DstIP: dstIP, + Protocol: layers.IPProtocolUDP, + } + outboundUDP := &layers.UDP{ + SrcPort: layers.UDPPort(srcPort), + DstPort: layers.UDPPort(dstPort), + } + + err = outboundUDP.SetNetworkLayerForChecksum(outboundIPv4) + require.NoError(t, err) + + outboundBuf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + + err = gopacket.SerializeLayers(outboundBuf, opts, + outboundIPv4, + outboundUDP, + gopacket.Payload("test"), + ) + require.NoError(t, err) + + // Process outbound packet and verify connection tracking + drop := manager.DropOutgoing(outboundBuf.Bytes()) + require.False(t, drop, "Initial outbound packet should not be dropped") + + // Verify connection was tracked + conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort) + + require.True(t, exists, "Connection should be tracked after outbound packet") + require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(srcIP), conn.SourceIP), "Source IP should match") + require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(dstIP), conn.DestIP), "Destination IP should match") + require.Equal(t, srcPort, conn.SourcePort, "Source port should match") + require.Equal(t, dstPort, conn.DestPort, "Destination port should match") + + // Create valid inbound response packet + inboundIPv4 := &layers.IPv4{ + TTL: 64, + Version: 4, + SrcIP: dstIP, // Original destination is now source + DstIP: srcIP, // Original source is now destination + Protocol: 
layers.IPProtocolUDP, + } + inboundUDP := &layers.UDP{ + SrcPort: layers.UDPPort(dstPort), // Original destination port is now source + DstPort: layers.UDPPort(srcPort), // Original source port is now destination + } + + err = inboundUDP.SetNetworkLayerForChecksum(inboundIPv4) + require.NoError(t, err) + + inboundBuf := gopacket.NewSerializeBuffer() + err = gopacket.SerializeLayers(inboundBuf, opts, + inboundIPv4, + inboundUDP, + gopacket.Payload("response"), + ) + require.NoError(t, err) + // Test roundtrip response handling over time + checkPoints := []struct { + sleep time.Duration + shouldAllow bool + description string + }{ + { + sleep: 0, + shouldAllow: true, + description: "Immediate response should be allowed", + }, + { + sleep: 50 * time.Millisecond, + shouldAllow: true, + description: "Response within timeout should be allowed", + }, + { + sleep: 100 * time.Millisecond, + shouldAllow: true, + description: "Response at half timeout should be allowed", + }, + { + // tracker hasn't updated conn for 250ms -> greater than 200ms timeout + sleep: 250 * time.Millisecond, + shouldAllow: false, + description: "Response after timeout should be dropped", + }, + } + + for _, cp := range checkPoints { + time.Sleep(cp.sleep) + + drop = manager.dropFilter(inboundBuf.Bytes(), manager.incomingRules) + require.Equal(t, cp.shouldAllow, !drop, cp.description) + + // If the connection should still be valid, verify it exists + if cp.shouldAllow { + conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort) + require.True(t, exists, "Connection should still exist during valid window") + require.True(t, time.Since(conn.GetLastSeen()) < manager.udpTracker.Timeout(), + "LastSeen should be updated for valid responses") + } + } + + // Test invalid response packets (while connection is expired) + invalidCases := []struct { + name string + modifyFunc func(*layers.IPv4, *layers.UDP) + description string + }{ + { + name: "wrong source IP", + modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) { + ip.SrcIP = net.ParseIP("100.10.0.101") + }, + description: "Response from wrong IP should be dropped", + }, + { + name: "wrong destination IP", + modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) { + ip.DstIP = net.ParseIP("100.10.0.2") + }, + description: "Response to wrong IP should be dropped", + }, + { + name: "wrong source port", + modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) { + udp.SrcPort = 54 + }, + description: "Response from wrong port should be dropped", + }, + { + name: "wrong destination port", + modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) { + udp.DstPort = 51335 + }, + description: "Response to wrong port should be dropped", + }, + } + + // Create a new outbound connection for invalid tests + drop = manager.processOutgoingHooks(outboundBuf.Bytes()) + require.False(t, drop, "Second outbound packet should not be dropped") + + for _, tc := range invalidCases { + t.Run(tc.name, func(t *testing.T) { + testIPv4 := *inboundIPv4 + testUDP := *inboundUDP + + tc.modifyFunc(&testIPv4, &testUDP) + + err = testUDP.SetNetworkLayerForChecksum(&testIPv4) + require.NoError(t, err) + + testBuf := gopacket.NewSerializeBuffer() + err = gopacket.SerializeLayers(testBuf, opts, + &testIPv4, + &testUDP, + gopacket.Payload("response"), + ) + require.NoError(t, err) + + // Verify the invalid packet is dropped + drop = manager.dropFilter(testBuf.Bytes(), manager.incomingRules) + require.True(t, drop, tc.description) + }) + } +} diff --git a/client/iface/bind/control_android.go 
b/client/iface/bind/control_android.go new file mode 100644 index 000000000..b8a865e39 --- /dev/null +++ b/client/iface/bind/control_android.go @@ -0,0 +1,12 @@ +package bind + +import ( + wireguard "golang.zx2c4.com/wireguard/conn" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +func init() { + // ControlFns is not thread safe and should only be modified during init. + *wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket) +} diff --git a/client/iface/bind/udp_mux.go b/client/iface/bind/udp_mux.go index 12f7a8129..00a91f0ec 100644 --- a/client/iface/bind/udp_mux.go +++ b/client/iface/bind/udp_mux.go @@ -162,12 +162,13 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead") var networks []ice.NetworkType switch { - case addr.IP.To4() != nil: - networks = []ice.NetworkType{ice.NetworkTypeUDP4} case addr.IP.To16() != nil: networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6} + case addr.IP.To4() != nil: + networks = []ice.NetworkType{ice.NetworkTypeUDP4} + default: params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr()) } diff --git a/client/iface/device/kernel_module_linux.go b/client/iface/device/kernel_module_linux.go index 0d195779d..b28ddd36c 100644 --- a/client/iface/device/kernel_module_linux.go +++ b/client/iface/device/kernel_module_linux.go @@ -27,14 +27,14 @@ import ( type status int const ( - defaultModuleDir = "/lib/modules" - unknown status = iota - unloaded - unloading - loading - live - inuse - envDisableWireGuardKernel = "NB_WG_KERNEL_DISABLED" + unknown status = 1 + unloaded status = 2 + unloading status = 3 + loading status = 4 + live status = 5 + inuse status = 6 + defaultModuleDir = "/lib/modules" + envDisableWireGuardKernel = "NB_WG_KERNEL_DISABLED" ) type module struct { diff --git a/client/iface/netstack/env.go b/client/iface/netstack/env.go index c77e39fe0..09889a57e 100644 --- a/client/iface/netstack/env.go +++ b/client/iface/netstack/env.go @@ -15,6 +15,10 @@ func IsEnabled() bool { func ListenAddr() string { sPort := os.Getenv("NB_SOCKS5_LISTENER_PORT") + if sPort == "" { + return listenAddr(DefaultSocks5Port) + } + port, err := strconv.Atoi(sPort) if err != nil { log.Warnf("invalid socks5 listener port, unable to convert it to int, falling back to default: %d", DefaultSocks5Port) diff --git a/client/internal/config.go b/client/internal/config.go index ee54c6380..594bdc570 100644 --- a/client/internal/config.go +++ b/client/internal/config.go @@ -46,6 +46,7 @@ type ConfigInput struct { ManagementURL string AdminURL string ConfigPath string + StateFilePath string PreSharedKey *string ServerSSHAllowed *bool NATExternalIPs []string @@ -60,6 +61,11 @@ type ConfigInput struct { DNSRouteInterval *time.Duration ClientCertPath string ClientCertKeyPath string + + DisableClientRoutes *bool + DisableServerRoutes *bool + DisableDNS *bool + DisableFirewall *bool } // Config Configuration type @@ -77,6 +83,12 @@ type Config struct { RosenpassEnabled bool RosenpassPermissive bool ServerSSHAllowed *bool + + DisableClientRoutes bool + DisableServerRoutes bool + DisableDNS bool + DisableFirewall bool + // SSHKey is a private SSH key in a PEM format SSHKey string @@ -105,10 +117,10 @@ type Config struct { // DNSRouteInterval is the interval in which the DNS routes are updated DNSRouteInterval time.Duration - //Path to a certificate used for mTLS authentication + // Path to a 
certificate used for mTLS authentication ClientCertPath string - //Path to corresponding private key of ClientCertPath + // Path to corresponding private key of ClientCertPath ClientCertKeyPath string ClientCertKeyPair *tls.Certificate `json:"-"` @@ -116,7 +128,7 @@ type Config struct { // ReadConfig read config file and return with Config. If it is not exists create a new with default values func ReadConfig(configPath string) (*Config, error) { - if configFileIsExists(configPath) { + if fileExists(configPath) { err := util.EnforcePermission(configPath) if err != nil { log.Errorf("failed to enforce permission on config dir: %v", err) @@ -149,7 +161,7 @@ func ReadConfig(configPath string) (*Config, error) { // UpdateConfig update existing configuration according to input configuration and return with the configuration func UpdateConfig(input ConfigInput) (*Config, error) { - if !configFileIsExists(input.ConfigPath) { + if !fileExists(input.ConfigPath) { return nil, status.Errorf(codes.NotFound, "config file doesn't exist") } @@ -158,13 +170,13 @@ func UpdateConfig(input ConfigInput) (*Config, error) { // UpdateOrCreateConfig reads existing config or generates a new one func UpdateOrCreateConfig(input ConfigInput) (*Config, error) { - if !configFileIsExists(input.ConfigPath) { + if !fileExists(input.ConfigPath) { log.Infof("generating new config %s", input.ConfigPath) cfg, err := createNewConfig(input) if err != nil { return nil, err } - err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg) + err = util.WriteJsonWithRestrictedPermission(context.Background(), input.ConfigPath, cfg) return cfg, err } @@ -185,7 +197,7 @@ func CreateInMemoryConfig(input ConfigInput) (*Config, error) { // WriteOutConfig write put the prepared config to the given path func WriteOutConfig(path string, config *Config) error { - return util.WriteJson(path, config) + return util.WriteJson(context.Background(), path, config) } // createNewConfig creates a new config generating a new Wireguard key and saving to file @@ -215,7 +227,7 @@ func update(input ConfigInput) (*Config, error) { } if updated { - if err := util.WriteJson(input.ConfigPath, config); err != nil { + if err := util.WriteJson(context.Background(), input.ConfigPath, config); err != nil { return nil, err } } @@ -401,7 +413,46 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) { config.DNSRouteInterval = dynamic.DefaultInterval log.Infof("using default DNS route interval %s", config.DNSRouteInterval) updated = true + } + if input.DisableClientRoutes != nil && *input.DisableClientRoutes != config.DisableClientRoutes { + if *input.DisableClientRoutes { + log.Infof("disabling client routes") + } else { + log.Infof("enabling client routes") + } + config.DisableClientRoutes = *input.DisableClientRoutes + updated = true + } + + if input.DisableServerRoutes != nil && *input.DisableServerRoutes != config.DisableServerRoutes { + if *input.DisableServerRoutes { + log.Infof("disabling server routes") + } else { + log.Infof("enabling server routes") + } + config.DisableServerRoutes = *input.DisableServerRoutes + updated = true + } + + if input.DisableDNS != nil && *input.DisableDNS != config.DisableDNS { + if *input.DisableDNS { + log.Infof("disabling DNS configuration") + } else { + log.Infof("enabling DNS configuration") + } + config.DisableDNS = *input.DisableDNS + updated = true + } + + if input.DisableFirewall != nil && *input.DisableFirewall != config.DisableFirewall { + if *input.DisableFirewall { + log.Infof("disabling 
firewall configuration") + } else { + log.Infof("enabling firewall configuration") + } + config.DisableFirewall = *input.DisableFirewall + updated = true } if input.ClientCertKeyPath != "" { @@ -472,11 +523,19 @@ func isPreSharedKeyHidden(preSharedKey *string) bool { return false } -func configFileIsExists(path string) bool { +func fileExists(path string) bool { _, err := os.Stat(path) return !os.IsNotExist(err) } +func createFile(path string) error { + file, err := os.Create(path) + if err != nil { + return err + } + return file.Close() +} + // UpdateOldManagementURL checks whether client can switch to the new Management URL with port 443 and the management domain. // If it can switch, then it updates the config and returns a new one. Otherwise, it returns the provided config. // The check is performed only for the NetBird's managed version. diff --git a/client/internal/connect.go b/client/internal/connect.go index dff44f1d2..afd1f4454 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -40,6 +40,8 @@ type ConnectClient struct { statusRecorder *peer.Status engine *Engine engineMutex sync.Mutex + + persistNetworkMap bool } func NewConnectClient( @@ -89,6 +91,7 @@ func (c *ConnectClient) RunOniOS( fileDescriptor int32, networkChangeListener listener.NetworkChangeListener, dnsManager dns.IosDnsManager, + stateFilePath string, ) error { // Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension. debug.SetGCPercent(5) @@ -97,6 +100,7 @@ func (c *ConnectClient) RunOniOS( FileDescriptor: fileDescriptor, NetworkChangeListener: networkChangeListener, DnsManager: dnsManager, + StateFilePath: stateFilePath, } return c.run(mobileDependency, nil, nil) } @@ -157,7 +161,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold engineCtx, cancel := context.WithCancel(c.ctx) defer func() { - c.statusRecorder.MarkManagementDisconnected(state.err) + _, err := state.Status() + c.statusRecorder.MarkManagementDisconnected(err) c.statusRecorder.CleanLocalPeerState() cancel() }() @@ -231,6 +236,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold relayURLs, token := parseRelayInfo(loginResp) relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String()) + c.statusRecorder.SetRelayMgr(relayManager) if len(relayURLs) > 0 { if token != nil { if err := relayManager.UpdateToken(token); err != nil { @@ -241,9 +247,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold log.Infof("connecting to the Relay service(s): %s", strings.Join(relayURLs, ", ")) if err = relayManager.Serve(); err != nil { log.Error(err) - return wrapErr(err) } - c.statusRecorder.SetRelayMgr(relayManager) } peerConfig := loginResp.GetPeerConfig() @@ -258,7 +262,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold c.engineMutex.Lock() c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks) - + c.engine.SetNetworkMapPersistence(c.persistNetworkMap) c.engineMutex.Unlock() if err := c.engine.Start(); err != nil { @@ -336,6 +340,19 @@ func (c *ConnectClient) Engine() *Engine { return e } +// Status returns the current client status +func (c *ConnectClient) Status() StatusType { + if c == nil { + return StatusIdle + } + status, err := CtxGetState(c.ctx).Status() + if err != nil { + return StatusIdle + } + + return status +} + func (c 
*ConnectClient) Stop() error { if c == nil { return nil @@ -362,6 +379,21 @@ func (c *ConnectClient) isContextCancelled() bool { } } +// SetNetworkMapPersistence enables or disables network map persistence. +// When enabled, the last received network map will be stored and can be retrieved +// through the Engine's getLatestNetworkMap method. When disabled, any stored +// network map will be cleared. +func (c *ConnectClient) SetNetworkMapPersistence(enabled bool) { + c.engineMutex.Lock() + c.persistNetworkMap = enabled + c.engineMutex.Unlock() + + engine := c.Engine() + if engine != nil { + engine.SetNetworkMapPersistence(enabled) + } +} + // createEngineConfig converts configuration received from Management Service to EngineConfig func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) { nm := false @@ -383,6 +415,11 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe RosenpassPermissive: config.RosenpassPermissive, ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed), DNSRouteInterval: config.DNSRouteInterval, + + DisableClientRoutes: config.DisableClientRoutes, + DisableServerRoutes: config.DisableServerRoutes, + DisableDNS: config.DisableDNS, + DisableFirewall: config.DisableFirewall, } if config.PreSharedKey != "" { diff --git a/client/internal/dns/consts.go b/client/internal/dns/consts.go new file mode 100644 index 000000000..b333d0808 --- /dev/null +++ b/client/internal/dns/consts.go @@ -0,0 +1,18 @@ +//go:build !android + +package dns + +import ( + "github.com/netbirdio/netbird/client/configs" + "os" + "path/filepath" +) + +var fileUncleanShutdownResolvConfLocation string + +func init() { + fileUncleanShutdownResolvConfLocation = os.Getenv("NB_UNCLEAN_SHUTDOWN_RESOLV_FILE") + if fileUncleanShutdownResolvConfLocation == "" { + fileUncleanShutdownResolvConfLocation = filepath.Join(configs.StateDir, "resolv.conf") + } +} diff --git a/client/internal/dns/consts_freebsd.go b/client/internal/dns/consts_freebsd.go deleted file mode 100644 index 64c8fe5eb..000000000 --- a/client/internal/dns/consts_freebsd.go +++ /dev/null @@ -1,5 +0,0 @@ -package dns - -const ( - fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf" -) diff --git a/client/internal/dns/consts_linux.go b/client/internal/dns/consts_linux.go deleted file mode 100644 index 15614b0c5..000000000 --- a/client/internal/dns/consts_linux.go +++ /dev/null @@ -1,7 +0,0 @@ -//go:build !android - -package dns - -const ( - fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf" -) diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go new file mode 100644 index 000000000..5f63d1ab3 --- /dev/null +++ b/client/internal/dns/handler_chain.go @@ -0,0 +1,225 @@ +package dns + +import ( + "slices" + "strings" + "sync" + + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" +) + +const ( + PriorityDNSRoute = 100 + PriorityMatchDomain = 50 + PriorityDefault = 0 +) + +type SubdomainMatcher interface { + dns.Handler + MatchSubdomains() bool +} + +type HandlerEntry struct { + Handler dns.Handler + Priority int + Pattern string + OrigPattern string + IsWildcard bool + StopHandler handlerWithStop + MatchSubdomains bool +} + +// HandlerChain represents a prioritized chain of DNS handlers +type HandlerChain struct { + mu sync.RWMutex + handlers []HandlerEntry +} + +// ResponseWriterChain wraps a dns.ResponseWriter to track if handler wants to continue chain +type 
ResponseWriterChain struct { + dns.ResponseWriter + origPattern string + shouldContinue bool +} + +func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error { + // Check if this is a continue signal (NXDOMAIN with Zero bit set) + if m.Rcode == dns.RcodeNameError && m.MsgHdr.Zero { + w.shouldContinue = true + return nil + } + return w.ResponseWriter.WriteMsg(m) +} + +func NewHandlerChain() *HandlerChain { + return &HandlerChain{ + handlers: make([]HandlerEntry, 0), + } +} + +// GetOrigPattern returns the original pattern of the handler that wrote the response +func (w *ResponseWriterChain) GetOrigPattern() string { + return w.origPattern +} + +// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority +func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) { + c.mu.Lock() + defer c.mu.Unlock() + + pattern = strings.ToLower(dns.Fqdn(pattern)) + origPattern := pattern + isWildcard := strings.HasPrefix(pattern, "*.") + if isWildcard { + pattern = pattern[2:] + } + + // First remove any existing handler with same pattern (case-insensitive) and priority + for i := len(c.handlers) - 1; i >= 0; i-- { + if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority { + if c.handlers[i].StopHandler != nil { + c.handlers[i].StopHandler.stop() + } + c.handlers = append(c.handlers[:i], c.handlers[i+1:]...) + break + } + } + + // Check if handler implements SubdomainMatcher interface + matchSubdomains := false + if matcher, ok := handler.(SubdomainMatcher); ok { + matchSubdomains = matcher.MatchSubdomains() + } + + log.Debugf("adding handler pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d", + pattern, origPattern, isWildcard, matchSubdomains, priority) + + entry := HandlerEntry{ + Handler: handler, + Priority: priority, + Pattern: pattern, + OrigPattern: origPattern, + IsWildcard: isWildcard, + StopHandler: stopHandler, + MatchSubdomains: matchSubdomains, + } + + // Insert handler in priority order + pos := 0 + for i, h := range c.handlers { + if h.Priority < priority { + pos = i + break + } + pos = i + 1 + } + + c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...) +} + +// RemoveHandler removes a handler for the given pattern and priority +func (c *HandlerChain) RemoveHandler(pattern string, priority int) { + c.mu.Lock() + defer c.mu.Unlock() + + pattern = dns.Fqdn(pattern) + + // Find and remove handlers matching both original pattern (case-insensitive) and priority + for i := len(c.handlers) - 1; i >= 0; i-- { + entry := c.handlers[i] + if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority { + if entry.StopHandler != nil { + entry.StopHandler.stop() + } + c.handlers = append(c.handlers[:i], c.handlers[i+1:]...) 
+ return + } + } +} + +// HasHandlers returns true if there are any handlers remaining for the given pattern +func (c *HandlerChain) HasHandlers(pattern string) bool { + c.mu.RLock() + defer c.mu.RUnlock() + + pattern = strings.ToLower(dns.Fqdn(pattern)) + for _, entry := range c.handlers { + if strings.EqualFold(entry.Pattern, pattern) { + return true + } + } + return false +} + +func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + if len(r.Question) == 0 { + return + } + + qname := strings.ToLower(r.Question[0].Name) + log.Tracef("handling DNS request for domain=%s", qname) + + c.mu.RLock() + handlers := slices.Clone(c.handlers) + c.mu.RUnlock() + + if log.IsLevelEnabled(log.TraceLevel) { + log.Tracef("current handlers (%d):", len(handlers)) + for _, h := range handlers { + log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v priority=%d", + h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority) + } + } + + // Try handlers in priority order + for _, entry := range handlers { + var matched bool + switch { + case entry.Pattern == ".": + matched = true + case entry.IsWildcard: + parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".") + matched = len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern) + default: + // For non-wildcard patterns: + // If handler wants subdomain matching, allow suffix match + // Otherwise require exact match + if entry.MatchSubdomains { + matched = strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern) + } else { + matched = strings.EqualFold(qname, entry.Pattern) + } + } + + if !matched { + log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v matched=false", + qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard) + continue + } + + log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v", + qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains) + + chainWriter := &ResponseWriterChain{ + ResponseWriter: w, + origPattern: entry.OrigPattern, + } + entry.Handler.ServeDNS(chainWriter, r) + + // If handler wants to continue, try next handler + if chainWriter.shouldContinue { + log.Tracef("handler requested continue to next handler") + continue + } + return + } + + // No handler matched or all handlers passed + log.Tracef("no handler found for domain=%s", qname) + resp := &dns.Msg{} + resp.SetRcode(r, dns.RcodeNameError) + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed to write DNS response: %v", err) + } +} diff --git a/client/internal/dns/handler_chain_test.go b/client/internal/dns/handler_chain_test.go new file mode 100644 index 000000000..eb40c907f --- /dev/null +++ b/client/internal/dns/handler_chain_test.go @@ -0,0 +1,679 @@ +package dns_test + +import ( + "net" + "testing" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + nbdns "github.com/netbirdio/netbird/client/internal/dns" +) + +// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order +func TestHandlerChain_ServeDNS_Priorities(t *testing.T) { + chain := nbdns.NewHandlerChain() + + // Create mock handlers for different priorities + defaultHandler := &nbdns.MockHandler{} + matchDomainHandler := &nbdns.MockHandler{} + dnsRouteHandler := &nbdns.MockHandler{} + + // Setup handlers with different priorities + chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil) + chain.AddHandler("example.com.", 
matchDomainHandler, nbdns.PriorityMatchDomain, nil) + chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute, nil) + + // Create test request + r := new(dns.Msg) + r.SetQuestion("example.com.", dns.TypeA) + + // Create test writer + w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + + // Setup expectations - only highest priority handler should be called + dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once() + matchDomainHandler.On("ServeDNS", mock.Anything, r).Maybe() + defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() + + // Execute + chain.ServeDNS(w, r) + + // Verify all expectations were met + dnsRouteHandler.AssertExpectations(t) + matchDomainHandler.AssertExpectations(t) + defaultHandler.AssertExpectations(t) +} + +// TestHandlerChain_ServeDNS_DomainMatching tests various domain matching scenarios +func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) { + tests := []struct { + name string + handlerDomain string + queryDomain string + isWildcard bool + matchSubdomains bool + shouldMatch bool + }{ + { + name: "exact match", + handlerDomain: "example.com.", + queryDomain: "example.com.", + isWildcard: false, + matchSubdomains: false, + shouldMatch: true, + }, + { + name: "subdomain with non-wildcard and MatchSubdomains true", + handlerDomain: "example.com.", + queryDomain: "sub.example.com.", + isWildcard: false, + matchSubdomains: true, + shouldMatch: true, + }, + { + name: "subdomain with non-wildcard and MatchSubdomains false", + handlerDomain: "example.com.", + queryDomain: "sub.example.com.", + isWildcard: false, + matchSubdomains: false, + shouldMatch: false, + }, + { + name: "wildcard match", + handlerDomain: "*.example.com.", + queryDomain: "sub.example.com.", + isWildcard: true, + matchSubdomains: false, + shouldMatch: true, + }, + { + name: "wildcard no match on apex", + handlerDomain: "*.example.com.", + queryDomain: "example.com.", + isWildcard: true, + matchSubdomains: false, + shouldMatch: false, + }, + { + name: "root zone match", + handlerDomain: ".", + queryDomain: "anything.com.", + isWildcard: false, + matchSubdomains: false, + shouldMatch: true, + }, + { + name: "no match different domain", + handlerDomain: "example.com.", + queryDomain: "example.org.", + isWildcard: false, + matchSubdomains: false, + shouldMatch: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + chain := nbdns.NewHandlerChain() + var handler dns.Handler + + if tt.matchSubdomains { + mockSubHandler := &nbdns.MockSubdomainHandler{Subdomains: true} + handler = mockSubHandler + if tt.shouldMatch { + mockSubHandler.On("ServeDNS", mock.Anything, mock.Anything).Once() + } + } else { + mockHandler := &nbdns.MockHandler{} + handler = mockHandler + if tt.shouldMatch { + mockHandler.On("ServeDNS", mock.Anything, mock.Anything).Once() + } + } + + pattern := tt.handlerDomain + if tt.isWildcard { + pattern = "*." 
+ tt.handlerDomain[2:] + } + + chain.AddHandler(pattern, handler, nbdns.PriorityDefault, nil) + + r := new(dns.Msg) + r.SetQuestion(tt.queryDomain, dns.TypeA) + w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + + chain.ServeDNS(w, r) + + if h, ok := handler.(*nbdns.MockHandler); ok { + h.AssertExpectations(t) + } else if h, ok := handler.(*nbdns.MockSubdomainHandler); ok { + h.AssertExpectations(t) + } + }) + } +} + +// TestHandlerChain_ServeDNS_OverlappingDomains tests behavior with overlapping domain patterns +func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) { + tests := []struct { + name string + handlers []struct { + pattern string + priority int + } + queryDomain string + expectedCalls int + expectedHandler int // index of the handler that should be called + }{ + { + name: "wildcard and exact same priority - exact should win", + handlers: []struct { + pattern string + priority int + }{ + {pattern: "*.example.com.", priority: nbdns.PriorityDefault}, + {pattern: "example.com.", priority: nbdns.PriorityDefault}, + }, + queryDomain: "example.com.", + expectedCalls: 1, + expectedHandler: 1, // exact match handler should be called + }, + { + name: "higher priority wildcard over lower priority exact", + handlers: []struct { + pattern string + priority int + }{ + {pattern: "example.com.", priority: nbdns.PriorityDefault}, + {pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute}, + }, + queryDomain: "test.example.com.", + expectedCalls: 1, + expectedHandler: 1, // higher priority wildcard handler should be called + }, + { + name: "multiple wildcards different priorities", + handlers: []struct { + pattern string + priority int + }{ + {pattern: "*.example.com.", priority: nbdns.PriorityDefault}, + {pattern: "*.example.com.", priority: nbdns.PriorityMatchDomain}, + {pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute}, + }, + queryDomain: "test.example.com.", + expectedCalls: 1, + expectedHandler: 2, // highest priority handler should be called + }, + { + name: "subdomain with mix of patterns", + handlers: []struct { + pattern string + priority int + }{ + {pattern: "*.example.com.", priority: nbdns.PriorityDefault}, + {pattern: "test.example.com.", priority: nbdns.PriorityMatchDomain}, + {pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute}, + }, + queryDomain: "sub.test.example.com.", + expectedCalls: 1, + expectedHandler: 2, // highest priority matching handler should be called + }, + { + name: "root zone with specific domain", + handlers: []struct { + pattern string + priority int + }{ + {pattern: ".", priority: nbdns.PriorityDefault}, + {pattern: "example.com.", priority: nbdns.PriorityDNSRoute}, + }, + queryDomain: "example.com.", + expectedCalls: 1, + expectedHandler: 1, // higher priority specific domain should win over root + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + chain := nbdns.NewHandlerChain() + var handlers []*nbdns.MockHandler + + // Setup handlers and expectations + for i := range tt.handlers { + handler := &nbdns.MockHandler{} + handlers = append(handlers, handler) + + // Set expectation based on whether this handler should be called + if i == tt.expectedHandler { + handler.On("ServeDNS", mock.Anything, mock.Anything).Once() + } else { + handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe() + } + + chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority, nil) + } + + // Create and execute request + r := new(dns.Msg) + r.SetQuestion(tt.queryDomain, 
dns.TypeA) + w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + chain.ServeDNS(w, r) + + // Verify expectations + for _, handler := range handlers { + handler.AssertExpectations(t) + } + }) + } +} + +// TestHandlerChain_ServeDNS_ChainContinuation tests the chain continuation functionality +func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) { + chain := nbdns.NewHandlerChain() + + // Create handlers + handler1 := &nbdns.MockHandler{} + handler2 := &nbdns.MockHandler{} + handler3 := &nbdns.MockHandler{} + + // Add handlers in priority order + chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil) + chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain, nil) + chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault, nil) + + // Create test request + r := new(dns.Msg) + r.SetQuestion("example.com.", dns.TypeA) + + // Setup mock responses to simulate chain continuation + handler1.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) { + // First handler signals continue + w := args.Get(0).(*nbdns.ResponseWriterChain) + resp := new(dns.Msg) + resp.SetRcode(r, dns.RcodeNameError) + resp.MsgHdr.Zero = true // Signal to continue + assert.NoError(t, w.WriteMsg(resp)) + }).Once() + + handler2.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) { + // Second handler signals continue + w := args.Get(0).(*nbdns.ResponseWriterChain) + resp := new(dns.Msg) + resp.SetRcode(r, dns.RcodeNameError) + resp.MsgHdr.Zero = true + assert.NoError(t, w.WriteMsg(resp)) + }).Once() + + handler3.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) { + // Last handler responds normally + w := args.Get(0).(*nbdns.ResponseWriterChain) + resp := new(dns.Msg) + resp.SetRcode(r, dns.RcodeSuccess) + assert.NoError(t, w.WriteMsg(resp)) + }).Once() + + // Execute + w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + chain.ServeDNS(w, r) + + // Verify all handlers were called in order + handler1.AssertExpectations(t) + handler2.AssertExpectations(t) + handler3.AssertExpectations(t) +} + +// mockResponseWriter implements dns.ResponseWriter for testing +type mockResponseWriter struct { + mock.Mock +} + +func (m *mockResponseWriter) LocalAddr() net.Addr { return nil } +func (m *mockResponseWriter) RemoteAddr() net.Addr { return nil } +func (m *mockResponseWriter) WriteMsg(*dns.Msg) error { return nil } +func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil } +func (m *mockResponseWriter) Close() error { return nil } +func (m *mockResponseWriter) TsigStatus() error { return nil } +func (m *mockResponseWriter) TsigTimersOnly(bool) {} +func (m *mockResponseWriter) Hijack() {} + +func TestHandlerChain_PriorityDeregistration(t *testing.T) { + tests := []struct { + name string + ops []struct { + action string // "add" or "remove" + pattern string + priority int + } + query string + expectedCalls map[int]bool // map[priority]shouldBeCalled + }{ + { + name: "remove high priority keeps lower priority handler", + ops: []struct { + action string + pattern string + priority int + }{ + {"add", "example.com.", nbdns.PriorityDNSRoute}, + {"add", "example.com.", nbdns.PriorityMatchDomain}, + {"remove", "example.com.", nbdns.PriorityDNSRoute}, + }, + query: "example.com.", + expectedCalls: map[int]bool{ + nbdns.PriorityDNSRoute: false, + nbdns.PriorityMatchDomain: true, + }, + }, + { + name: "remove lower priority keeps high priority handler", + ops: []struct { + action string + pattern string + priority 
int + }{ + {"add", "example.com.", nbdns.PriorityDNSRoute}, + {"add", "example.com.", nbdns.PriorityMatchDomain}, + {"remove", "example.com.", nbdns.PriorityMatchDomain}, + }, + query: "example.com.", + expectedCalls: map[int]bool{ + nbdns.PriorityDNSRoute: true, + nbdns.PriorityMatchDomain: false, + }, + }, + { + name: "remove all handlers in order", + ops: []struct { + action string + pattern string + priority int + }{ + {"add", "example.com.", nbdns.PriorityDNSRoute}, + {"add", "example.com.", nbdns.PriorityMatchDomain}, + {"add", "example.com.", nbdns.PriorityDefault}, + {"remove", "example.com.", nbdns.PriorityDNSRoute}, + {"remove", "example.com.", nbdns.PriorityMatchDomain}, + }, + query: "example.com.", + expectedCalls: map[int]bool{ + nbdns.PriorityDNSRoute: false, + nbdns.PriorityMatchDomain: false, + nbdns.PriorityDefault: true, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + chain := nbdns.NewHandlerChain() + handlers := make(map[int]*nbdns.MockHandler) + + // Execute operations + for _, op := range tt.ops { + if op.action == "add" { + handler := &nbdns.MockHandler{} + handlers[op.priority] = handler + chain.AddHandler(op.pattern, handler, op.priority, nil) + } else { + chain.RemoveHandler(op.pattern, op.priority) + } + } + + // Create test request + r := new(dns.Msg) + r.SetQuestion(tt.query, dns.TypeA) + w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + + // Setup expectations + for priority, handler := range handlers { + if shouldCall, exists := tt.expectedCalls[priority]; exists && shouldCall { + handler.On("ServeDNS", mock.Anything, r).Once() + } else { + handler.On("ServeDNS", mock.Anything, r).Maybe() + } + } + + // Execute request + chain.ServeDNS(w, r) + + // Verify expectations + for _, handler := range handlers { + handler.AssertExpectations(t) + } + + // Verify handler exists check + for priority, shouldExist := range tt.expectedCalls { + if shouldExist { + assert.True(t, chain.HasHandlers(tt.ops[0].pattern), + "Handler chain should have handlers for pattern after removing priority %d", priority) + } + } + }) + } +} + +func TestHandlerChain_MultiPriorityHandling(t *testing.T) { + chain := nbdns.NewHandlerChain() + + testDomain := "example.com." + testQuery := "test.example.com." 
+ + // Create handlers with MatchSubdomains enabled + routeHandler := &nbdns.MockSubdomainHandler{Subdomains: true} + matchHandler := &nbdns.MockSubdomainHandler{Subdomains: true} + defaultHandler := &nbdns.MockSubdomainHandler{Subdomains: true} + + // Create test request that will be reused + r := new(dns.Msg) + r.SetQuestion(testQuery, dns.TypeA) + + // Add handlers in mixed order + chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault, nil) + chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute, nil) + chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain, nil) + + // Test 1: Initial state with all three handlers + w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + // Highest priority handler (routeHandler) should be called + routeHandler.On("ServeDNS", mock.Anything, r).Return().Once() + + chain.ServeDNS(w, r) + routeHandler.AssertExpectations(t) + + // Test 2: Remove highest priority handler + chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute) + assert.True(t, chain.HasHandlers(testDomain)) + + w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + // Now middle priority handler (matchHandler) should be called + matchHandler.On("ServeDNS", mock.Anything, r).Return().Once() + + chain.ServeDNS(w, r) + matchHandler.AssertExpectations(t) + + // Test 3: Remove middle priority handler + chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain) + assert.True(t, chain.HasHandlers(testDomain)) + + w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + // Now lowest priority handler (defaultHandler) should be called + defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once() + + chain.ServeDNS(w, r) + defaultHandler.AssertExpectations(t) + + // Test 4: Remove last handler + chain.RemoveHandler(testDomain, nbdns.PriorityDefault) + + assert.False(t, chain.HasHandlers(testDomain)) +} + +func TestHandlerChain_CaseSensitivity(t *testing.T) { + tests := []struct { + name string + scenario string + addHandlers []struct { + pattern string + priority int + subdomains bool + shouldMatch bool + } + query string + expectedCalls int + }{ + { + name: "case insensitive exact match", + scenario: "handler registered lowercase, query uppercase", + addHandlers: []struct { + pattern string + priority int + subdomains bool + shouldMatch bool + }{ + {"example.com.", nbdns.PriorityDefault, false, true}, + }, + query: "EXAMPLE.COM.", + expectedCalls: 1, + }, + { + name: "case insensitive wildcard match", + scenario: "handler registered mixed case wildcard, query different case", + addHandlers: []struct { + pattern string + priority int + subdomains bool + shouldMatch bool + }{ + {"*.Example.Com.", nbdns.PriorityDefault, false, true}, + }, + query: "sub.EXAMPLE.COM.", + expectedCalls: 1, + }, + { + name: "multiple handlers different case same domain", + scenario: "second handler should replace first despite case difference", + addHandlers: []struct { + pattern string + priority int + subdomains bool + shouldMatch bool + }{ + {"EXAMPLE.COM.", nbdns.PriorityDefault, false, false}, + {"example.com.", nbdns.PriorityDefault, false, true}, + }, + query: "ExAmPlE.cOm.", + expectedCalls: 1, + }, + { + name: "subdomain matching case insensitive", + scenario: "handler with MatchSubdomains true should match regardless of case", + addHandlers: []struct { + pattern string + priority int + subdomains bool + shouldMatch bool + }{ + {"example.com.", nbdns.PriorityDefault, true, true}, + }, + query: "SUB.EXAMPLE.COM.", + 
expectedCalls: 1, + }, + { + name: "root zone case insensitive", + scenario: "root zone handler should match regardless of case", + addHandlers: []struct { + pattern string + priority int + subdomains bool + shouldMatch bool + }{ + {".", nbdns.PriorityDefault, false, true}, + }, + query: "EXAMPLE.COM.", + expectedCalls: 1, + }, + { + name: "multiple handlers different priority", + scenario: "should call higher priority handler despite case differences", + addHandlers: []struct { + pattern string + priority int + subdomains bool + shouldMatch bool + }{ + {"EXAMPLE.COM.", nbdns.PriorityDefault, false, false}, + {"example.com.", nbdns.PriorityMatchDomain, false, false}, + {"Example.Com.", nbdns.PriorityDNSRoute, false, true}, + }, + query: "example.com.", + expectedCalls: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + chain := nbdns.NewHandlerChain() + handlerCalls := make(map[string]bool) // track which patterns were called + + // Add handlers according to test case + for _, h := range tt.addHandlers { + var handler dns.Handler + pattern := h.pattern // capture pattern for closure + + if h.subdomains { + subHandler := &nbdns.MockSubdomainHandler{ + Subdomains: true, + } + if h.shouldMatch { + subHandler.On("ServeDNS", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + handlerCalls[pattern] = true + w := args.Get(0).(dns.ResponseWriter) + r := args.Get(1).(*dns.Msg) + resp := new(dns.Msg) + resp.SetRcode(r, dns.RcodeSuccess) + assert.NoError(t, w.WriteMsg(resp)) + }).Once() + } + handler = subHandler + } else { + mockHandler := &nbdns.MockHandler{} + if h.shouldMatch { + mockHandler.On("ServeDNS", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + handlerCalls[pattern] = true + w := args.Get(0).(dns.ResponseWriter) + r := args.Get(1).(*dns.Msg) + resp := new(dns.Msg) + resp.SetRcode(r, dns.RcodeSuccess) + assert.NoError(t, w.WriteMsg(resp)) + }).Once() + } + handler = mockHandler + } + + chain.AddHandler(pattern, handler, h.priority, nil) + } + + // Execute request + r := new(dns.Msg) + r.SetQuestion(tt.query, dns.TypeA) + chain.ServeDNS(&mockResponseWriter{}, r) + + // Verify each handler was called exactly as expected + for _, h := range tt.addHandlers { + wasCalled := handlerCalls[h.pattern] + assert.Equal(t, h.shouldMatch, wasCalled, + "Handler for pattern %q was %s when it should%s have been", + h.pattern, + map[bool]string{true: "called", false: "not called"}[wasCalled], + map[bool]string{true: "", false: " not"}[wasCalled == h.shouldMatch]) + } + + // Verify total number of calls + assert.Equal(t, tt.expectedCalls, len(handlerCalls), + "Wrong number of total handler calls") + }) + } +} diff --git a/client/internal/dns/host.go b/client/internal/dns/host.go index e2b5f699a..fbe8c4dbb 100644 --- a/client/internal/dns/host.go +++ b/client/internal/dns/host.go @@ -102,3 +102,17 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD return config } + +type noopHostConfigurator struct{} + +func (n noopHostConfigurator) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error { + return nil +} + +func (n noopHostConfigurator) restoreHostDNS() error { + return nil +} + +func (n noopHostConfigurator) supportCustomPort() bool { + return true +} diff --git a/client/internal/dns/local.go b/client/internal/dns/local.go index 6a459794b..9a78d4d50 100644 --- a/client/internal/dns/local.go +++ b/client/internal/dns/local.go @@ -17,12 +17,24 @@ type localResolver struct { records sync.Map } +func (d 
*localResolver) MatchSubdomains() bool { + return true +} + func (d *localResolver) stop() { } +// String returns a string representation of the local resolver +func (d *localResolver) String() string { + return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap)) +} + // ServeDNS handles a DNS request func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { - log.Tracef("received question: %#v", r.Question[0]) + if len(r.Question) > 0 { + log.Tracef("received question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) + } + replyMessage := &dns.Msg{} replyMessage.SetReply(r) replyMessage.RecursionAvailable = true diff --git a/client/internal/dns/mock_server.go b/client/internal/dns/mock_server.go index 0739f0542..7e36ea5df 100644 --- a/client/internal/dns/mock_server.go +++ b/client/internal/dns/mock_server.go @@ -3,14 +3,30 @@ package dns import ( "fmt" + "github.com/miekg/dns" + nbdns "github.com/netbirdio/netbird/dns" ) // MockServer is the mock instance of a dns server type MockServer struct { - InitializeFunc func() error - StopFunc func() - UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error + InitializeFunc func() error + StopFunc func() + UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error + RegisterHandlerFunc func([]string, dns.Handler, int) + DeregisterHandlerFunc func([]string, int) +} + +func (m *MockServer) RegisterHandler(domains []string, handler dns.Handler, priority int) { + if m.RegisterHandlerFunc != nil { + m.RegisterHandlerFunc(domains, handler, priority) + } +} + +func (m *MockServer) DeregisterHandler(domains []string, priority int) { + if m.DeregisterHandlerFunc != nil { + m.DeregisterHandlerFunc(domains, priority) + } } // Initialize mock implementation of Initialize from Server interface diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 929e1e60c..bb097c4cb 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -7,7 +7,6 @@ import ( "runtime" "strings" "sync" - "time" "github.com/miekg/dns" "github.com/mitchellh/hashstructure/v2" @@ -31,6 +30,8 @@ type IosDnsManager interface { // Server is a dns server interface type Server interface { + RegisterHandler(domains []string, handler dns.Handler, priority int) + DeregisterHandler(domains []string, priority int) Initialize() error Stop() DnsIP() string @@ -46,15 +47,18 @@ type registeredHandlerMap map[string]handlerWithStop type DefaultServer struct { ctx context.Context ctxCancel context.CancelFunc + disableSys bool mux sync.Mutex service service dnsMuxMap registeredHandlerMap + handlerPriorities map[string]int localResolver *localResolver wgInterface WGIface hostManager hostManager updateSerial uint64 previousConfigHash uint64 currentConfig HostDNSConfig + handlerChain *HandlerChain // permanent related properties permanent bool @@ -75,12 +79,20 @@ type handlerWithStop interface { } type muxUpdate struct { - domain string - handler handlerWithStop + domain string + handler handlerWithStop + priority int } // NewDefaultServer returns a new dns server -func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string, statusRecorder *peer.Status, stateManager *statemanager.Manager) (*DefaultServer, error) { +func NewDefaultServer( + ctx context.Context, + wgInterface WGIface, + customAddress string, + statusRecorder *peer.Status, + stateManager *statemanager.Manager, + disableSys bool, +) (*DefaultServer, error) { var addrPort *netip.AddrPort if 
customAddress != "" { parsedAddrPort, err := netip.ParseAddrPort(customAddress) @@ -97,7 +109,7 @@ func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress st dnsService = newServiceViaListener(wgInterface, addrPort) } - return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager), nil + return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager, disableSys), nil } // NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems @@ -108,9 +120,10 @@ func NewDefaultServerPermanentUpstream( config nbdns.Config, listener listener.NetworkChangeListener, statusRecorder *peer.Status, + disableSys bool, ) *DefaultServer { log.Debugf("host dns address list is: %v", hostsDnsList) - ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil) + ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys) ds.hostsDNSHolder.set(hostsDnsList) ds.permanent = true ds.addHostRootZone() @@ -127,19 +140,30 @@ func NewDefaultServerIos( wgInterface WGIface, iosDnsManager IosDnsManager, statusRecorder *peer.Status, + disableSys bool, ) *DefaultServer { - ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil) + ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys) ds.iosDnsManager = iosDnsManager return ds } -func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status, stateManager *statemanager.Manager) *DefaultServer { +func newDefaultServer( + ctx context.Context, + wgInterface WGIface, + dnsService service, + statusRecorder *peer.Status, + stateManager *statemanager.Manager, + disableSys bool, +) *DefaultServer { ctx, stop := context.WithCancel(ctx) defaultServer := &DefaultServer{ - ctx: ctx, - ctxCancel: stop, - service: dnsService, - dnsMuxMap: make(registeredHandlerMap), + ctx: ctx, + ctxCancel: stop, + disableSys: disableSys, + service: dnsService, + handlerChain: NewHandlerChain(), + dnsMuxMap: make(registeredHandlerMap), + handlerPriorities: make(map[string]int), localResolver: &localResolver{ registeredMap: make(registrationMap), }, @@ -152,6 +176,51 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi return defaultServer } +func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler, priority int) { + s.mux.Lock() + defer s.mux.Unlock() + + s.registerHandler(domains, handler, priority) +} + +func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) { + log.Debugf("registering handler %s with priority %d", handler, priority) + + for _, domain := range domains { + if domain == "" { + log.Warn("skipping empty domain") + continue + } + s.handlerChain.AddHandler(domain, handler, priority, nil) + s.handlerPriorities[domain] = priority + s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain) + } +} + +func (s *DefaultServer) DeregisterHandler(domains []string, priority int) { + s.mux.Lock() + defer s.mux.Unlock() + + s.deregisterHandler(domains, priority) +} + +func (s *DefaultServer) deregisterHandler(domains []string, priority int) { + log.Debugf("deregistering handler %v with priority %d", domains, priority) + + for _, domain := range domains { + s.handlerChain.RemoveHandler(domain, priority) + + // Only deregister from service if no handlers remain + if 
!s.handlerChain.HasHandlers(domain) { + if domain == "" { + log.Warn("skipping empty domain") + continue + } + s.service.DeregisterMux(nbdns.NormalizeZone(domain)) + } + } +} + // Initialize instantiate host manager and the dns service func (s *DefaultServer) Initialize() (err error) { s.mux.Lock() @@ -169,6 +238,13 @@ func (s *DefaultServer) Initialize() (err error) { } s.stateManager.RegisterState(&ShutdownState{}) + + if s.disableSys { + log.Info("system DNS is disabled, not setting up host manager") + s.hostManager = &noopHostConfigurator{} + return nil + } + s.hostManager, err = s.initialize() if err != nil { return fmt.Errorf("initialize: %w", err) @@ -217,47 +293,47 @@ func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) { // UpdateDNSServer processes an update received from the management service func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error { - select { - case <-s.ctx.Done(): + if s.ctx.Err() != nil { log.Infof("not updating DNS server as context is closed") return s.ctx.Err() - default: - if serial < s.updateSerial { - return fmt.Errorf("not applying dns update, error: "+ - "network update is %d behind the last applied update", s.updateSerial-serial) - } - s.mux.Lock() - defer s.mux.Unlock() + } - if s.hostManager == nil { - return fmt.Errorf("dns service is not initialized yet") - } + if serial < s.updateSerial { + return fmt.Errorf("not applying dns update, error: "+ + "network update is %d behind the last applied update", s.updateSerial-serial) + } - hash, err := hashstructure.Hash(update, hashstructure.FormatV2, &hashstructure.HashOptions{ - ZeroNil: true, - IgnoreZeroValue: true, - SlicesAsSets: true, - UseStringer: true, - }) - if err != nil { - log.Errorf("unable to hash the dns configuration update, got error: %s", err) - } + s.mux.Lock() + defer s.mux.Unlock() - if s.previousConfigHash == hash { - log.Debugf("not applying the dns configuration update as there is nothing new") - s.updateSerial = serial - return nil - } + if s.hostManager == nil { + return fmt.Errorf("dns service is not initialized yet") + } - if err := s.applyConfiguration(update); err != nil { - return fmt.Errorf("apply configuration: %w", err) - } + hash, err := hashstructure.Hash(update, hashstructure.FormatV2, &hashstructure.HashOptions{ + ZeroNil: true, + IgnoreZeroValue: true, + SlicesAsSets: true, + UseStringer: true, + }) + if err != nil { + log.Errorf("unable to hash the dns configuration update, got error: %s", err) + } + if s.previousConfigHash == hash { + log.Debugf("not applying the dns configuration update as there is nothing new") s.updateSerial = serial - s.previousConfigHash = hash - return nil } + + if err := s.applyConfiguration(update); err != nil { + return fmt.Errorf("apply configuration: %w", err) + } + + s.updateSerial = serial + s.previousConfigHash = hash + + return nil } func (s *DefaultServer) SearchDomains() []string { @@ -323,12 +399,12 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { log.Error(err) } - // persist dns state right away - ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second) - defer cancel() - if err := s.stateManager.PersistState(ctx); err != nil { - log.Errorf("Failed to persist dns state: %v", err) - } + go func() { + // persist dns state right away + if err := s.stateManager.PersistState(s.ctx); err != nil { + log.Errorf("Failed to persist dns state: %v", err) + } + }() if s.searchDomainNotifier != nil { s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains()) @@ -344,14 
+420,14 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) localRecords := make(map[string]nbdns.SimpleRecord, 0) for _, customZone := range customZones { - if len(customZone.Records) == 0 { return nil, nil, fmt.Errorf("received an empty list of records") } muxUpdates = append(muxUpdates, muxUpdate{ - domain: customZone.Domain, - handler: s.localResolver, + domain: customZone.Domain, + handler: s.localResolver, + priority: PriorityMatchDomain, }) for _, record := range customZone.Records { @@ -413,8 +489,9 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam if nsGroup.Primary { muxUpdates = append(muxUpdates, muxUpdate{ - domain: nbdns.RootZone, - handler: handler, + domain: nbdns.RootZone, + handler: handler, + priority: PriorityDefault, }) continue } @@ -430,8 +507,9 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam return nil, fmt.Errorf("received a nameserver group with an empty domain element") } muxUpdates = append(muxUpdates, muxUpdate{ - domain: domain, - handler: handler, + domain: domain, + handler: handler, + priority: PriorityMatchDomain, }) } } @@ -441,12 +519,16 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) { muxUpdateMap := make(registeredHandlerMap) + handlersByPriority := make(map[string]int) var isContainRootUpdate bool + // First register new handlers for _, update := range muxUpdates { - s.service.RegisterMux(update.domain, update.handler) + s.registerHandler([]string{update.domain}, update.handler, update.priority) muxUpdateMap[update.domain] = update.handler + handlersByPriority[update.domain] = update.priority + if existingHandler, ok := s.dnsMuxMap[update.domain]; ok { existingHandler.stop() } @@ -456,6 +538,7 @@ func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) { } } + // Then deregister old handlers not in the update for key, existingHandler := range s.dnsMuxMap { _, found := muxUpdateMap[key] if !found { @@ -464,12 +547,16 @@ func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) { existingHandler.stop() } else { existingHandler.stop() - s.service.DeregisterMux(key) + // Deregister with the priority that was used to register + if oldPriority, ok := s.handlerPriorities[key]; ok { + s.deregisterHandler([]string{key}, oldPriority) + } } } } s.dnsMuxMap = muxUpdateMap + s.handlerPriorities = handlersByPriority } func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) { @@ -518,13 +605,13 @@ func (s *DefaultServer) upstreamCallbacks( if nsGroup.Primary { removeIndex[nbdns.RootZone] = -1 s.currentConfig.RouteAll = false - s.service.DeregisterMux(nbdns.RootZone) + s.deregisterHandler([]string{nbdns.RootZone}, PriorityDefault) } for i, item := range s.currentConfig.Domains { if _, found := removeIndex[item.Domain]; found { s.currentConfig.Domains[i].Disabled = true - s.service.DeregisterMux(item.Domain) + s.deregisterHandler([]string{item.Domain}, PriorityMatchDomain) removeIndex[item.Domain] = i } } @@ -533,12 +620,11 @@ func (s *DefaultServer) upstreamCallbacks( l.Errorf("Failed to apply nameserver deactivation on the host: %v", err) } - // persist dns state right away - ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second) - defer cancel() - if err := s.stateManager.PersistState(ctx); err != nil { - l.Errorf("Failed to persist dns state: %v", err) - } + go func() { + if err := s.stateManager.PersistState(s.ctx); err != nil { + 
l.Errorf("Failed to persist dns state: %v", err) + } + }() if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 { s.addHostRootZone() @@ -556,7 +642,7 @@ func (s *DefaultServer) upstreamCallbacks( continue } s.currentConfig.Domains[i].Disabled = false - s.service.RegisterMux(domain, handler) + s.registerHandler([]string{domain}, handler, PriorityMatchDomain) } l := log.WithField("nameservers", nsGroup.NameServers) @@ -564,10 +650,13 @@ func (s *DefaultServer) upstreamCallbacks( if nsGroup.Primary { s.currentConfig.RouteAll = true - s.service.RegisterMux(nbdns.RootZone, handler) + s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault) } - if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil { - l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply") + + if s.hostManager != nil { + if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil { + l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply") + } } s.updateNSState(nsGroup, nil, true) @@ -595,7 +684,8 @@ func (s *DefaultServer) addHostRootZone() { } handler.deactivate = func(error) {} handler.reactivate = func() {} - s.service.RegisterMux(nbdns.RootZone, handler) + + s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault) } func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) { diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 21f1f1b7d..c166820c4 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -11,7 +11,9 @@ import ( "time" "github.com/golang/mock/gomock" + "github.com/miekg/dns" log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/mock" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/firewall/uspfilter" @@ -292,7 +294,7 @@ func TestUpdateDNSServer(t *testing.T) { t.Log(err) } }() - dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil) + dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil, false) if err != nil { t.Fatal(err) } @@ -401,7 +403,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { return } - dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil) + dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil, false) if err != nil { t.Errorf("create DNS server: %v", err) return @@ -496,7 +498,7 @@ func TestDNSServerStartStop(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}, nil) + dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}, nil, false) if err != nil { t.Fatalf("%v", err) } @@ -512,7 +514,7 @@ func TestDNSServerStartStop(t *testing.T) { t.Error(err) } - dnsServer.service.RegisterMux("netbird.cloud", dnsServer.localResolver) + dnsServer.registerHandler([]string{"netbird.cloud"}, dnsServer.localResolver, 1) resolver := &net.Resolver{ PreferGo: true, @@ -560,7 +562,9 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { localResolver: &localResolver{ registeredMap: make(registrationMap), }, - hostManager: hostManager, + handlerChain: NewHandlerChain(), + handlerPriorities: make(map[string]int), + hostManager: hostManager, 
currentConfig: HostDNSConfig{ Domains: []DomainConfig{ {false, "domain0", false}, @@ -629,7 +633,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) { var dnsList []string dnsConfig := nbdns.Config{} - dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, &peer.Status{}) + dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, &peer.Status{}, false) err = dnsServer.Initialize() if err != nil { t.Errorf("failed to initialize DNS server: %v", err) @@ -653,7 +657,7 @@ func TestDNSPermanent_updateUpstream(t *testing.T) { } defer wgIFace.Close() dnsConfig := nbdns.Config{} - dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{}) + dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{}, false) err = dnsServer.Initialize() if err != nil { t.Errorf("failed to initialize DNS server: %v", err) @@ -745,7 +749,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) { } defer wgIFace.Close() dnsConfig := nbdns.Config{} - dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{}) + dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{}, false) err = dnsServer.Initialize() if err != nil { t.Errorf("failed to initialize DNS server: %v", err) @@ -782,7 +786,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) { Port: 53, }, }, - Domains: []string{"customdomain.com"}, + Domains: []string{"google.com"}, Primary: false, }, }, @@ -804,7 +808,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) { if ips[0] != zoneRecords[0].RData { t.Fatalf("invalid zone record: %v", err) } - _, err = resolver.LookupHost(context.Background(), "customdomain.com") + _, err = resolver.LookupHost(context.Background(), "google.com") if err != nil { t.Errorf("failed to resolve: %s", err) } @@ -872,3 +876,86 @@ func newDnsResolver(ip string, port int) *net.Resolver { }, } } + +// MockHandler implements dns.Handler interface for testing +type MockHandler struct { + mock.Mock +} + +func (m *MockHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + m.Called(w, r) +} + +type MockSubdomainHandler struct { + MockHandler + Subdomains bool +} + +func (m *MockSubdomainHandler) MatchSubdomains() bool { + return m.Subdomains +} + +func TestHandlerChain_DomainPriorities(t *testing.T) { + chain := NewHandlerChain() + + dnsRouteHandler := &MockHandler{} + upstreamHandler := &MockSubdomainHandler{ + Subdomains: true, + } + + chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute, nil) + chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain, nil) + + testCases := []struct { + name string + query string + expectedHandler dns.Handler + }{ + { + name: "exact domain with dns route handler", + query: "example.com.", + expectedHandler: dnsRouteHandler, + }, + { + name: "subdomain should use upstream handler", + query: "sub.example.com.", + expectedHandler: upstreamHandler, + }, + { + name: "deep subdomain should use upstream handler", + query: "deep.sub.example.com.", + expectedHandler: upstreamHandler, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r := new(dns.Msg) + r.SetQuestion(tc.query, dns.TypeA) + w := &ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + + if mh, ok := 
tc.expectedHandler.(*MockHandler); ok { + mh.On("ServeDNS", mock.Anything, r).Once() + } else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok { + mh.On("ServeDNS", mock.Anything, r).Once() + } + + chain.ServeDNS(w, r) + + if mh, ok := tc.expectedHandler.(*MockHandler); ok { + mh.AssertExpectations(t) + } else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok { + mh.AssertExpectations(t) + } + + // Reset mocks + if mh, ok := tc.expectedHandler.(*MockHandler); ok { + mh.ExpectedCalls = nil + mh.Calls = nil + } else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok { + mh.ExpectedCalls = nil + mh.Calls = nil + } + }) + } +} diff --git a/client/internal/dns/service_listener.go b/client/internal/dns/service_listener.go index e0f9da26f..72dc4bc6e 100644 --- a/client/internal/dns/service_listener.go +++ b/client/internal/dns/service_listener.go @@ -105,6 +105,7 @@ func (s *serviceViaListener) Stop() { } func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) { + log.Debugf("registering dns handler for pattern: %s", pattern) s.dnsMux.Handle(pattern, handler) } diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index b3baf2fa8..f0aa12b65 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -66,6 +66,15 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) * } } +// String returns a string representation of the upstream resolver +func (u *upstreamResolverBase) String() string { + return fmt.Sprintf("upstream %v", u.upstreamServers) +} + +func (u *upstreamResolverBase) MatchSubdomains() bool { + return true +} + func (u *upstreamResolverBase) stop() { log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers) u.cancel() diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go new file mode 100644 index 000000000..ae31ffac6 --- /dev/null +++ b/client/internal/dnsfwd/forwarder.go @@ -0,0 +1,157 @@ +package dnsfwd + +import ( + "context" + "errors" + "net" + + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" + + nbdns "github.com/netbirdio/netbird/dns" +) + +const errResolveFailed = "failed to resolve query for domain=%s: %v" + +type DNSForwarder struct { + listenAddress string + ttl uint32 + domains []string + + dnsServer *dns.Server + mux *dns.ServeMux +} + +func NewDNSForwarder(listenAddress string, ttl uint32) *DNSForwarder { + log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl) + return &DNSForwarder{ + listenAddress: listenAddress, + ttl: ttl, + } +} + +func (f *DNSForwarder) Listen(domains []string) error { + log.Infof("listen DNS forwarder on address=%s", f.listenAddress) + mux := dns.NewServeMux() + + dnsServer := &dns.Server{ + Addr: f.listenAddress, + Net: "udp", + Handler: mux, + } + f.dnsServer = dnsServer + f.mux = mux + + f.UpdateDomains(domains) + + return dnsServer.ListenAndServe() +} + +func (f *DNSForwarder) UpdateDomains(domains []string) { + log.Debugf("Updating domains from %v to %v", f.domains, domains) + + for _, d := range f.domains { + f.mux.HandleRemove(d) + } + + newDomains := filterDomains(domains) + for _, d := range newDomains { + f.mux.HandleFunc(d, f.handleDNSQuery) + } + f.domains = newDomains +} + +func (f *DNSForwarder) Close(ctx context.Context) error { + if f.dnsServer == nil { + return nil + } + return f.dnsServer.ShutdownContext(ctx) +} + +func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { + if 
len(query.Question) == 0 { + return + } + log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v", + query.Question[0].Name, query.Question[0].Qtype, query.Question[0].Qclass) + + question := query.Question[0] + domain := question.Name + + resp := query.SetReply(query) + + ips, err := net.LookupIP(domain) + if err != nil { + var dnsErr *net.DNSError + + switch { + case errors.As(err, &dnsErr): + resp.Rcode = dns.RcodeServerFailure + if dnsErr.IsNotFound { + // Pass through NXDOMAIN + resp.Rcode = dns.RcodeNameError + } + + if dnsErr.Server != "" { + log.Warnf("failed to resolve query for domain=%s server=%s: %v", domain, dnsErr.Server, err) + } else { + log.Warnf(errResolveFailed, domain, err) + } + default: + resp.Rcode = dns.RcodeServerFailure + log.Warnf(errResolveFailed, domain, err) + } + + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed to write failure DNS response: %v", err) + } + return + } + + for _, ip := range ips { + var respRecord dns.RR + if ip.To4() == nil { + log.Tracef("resolved domain=%s to IPv6=%s", domain, ip) + rr := dns.AAAA{ + AAAA: ip, + Hdr: dns.RR_Header{ + Name: domain, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: f.ttl, + }, + } + respRecord = &rr + } else { + log.Tracef("resolved domain=%s to IPv4=%s", domain, ip) + rr := dns.A{ + A: ip, + Hdr: dns.RR_Header{ + Name: domain, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: f.ttl, + }, + } + respRecord = &rr + } + resp.Answer = append(resp.Answer, respRecord) + } + + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed to write DNS response: %v", err) + } +} + +// filterDomains returns a list of normalized domains +func filterDomains(domains []string) []string { + newDomains := make([]string, 0, len(domains)) + for _, d := range domains { + if d == "" { + log.Warn("empty domain in DNS forwarder") + continue + } + newDomains = append(newDomains, nbdns.NormalizeZone(d)) + } + return newDomains +} diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go new file mode 100644 index 000000000..e6dfd278e --- /dev/null +++ b/client/internal/dnsfwd/manager.go @@ -0,0 +1,111 @@ +package dnsfwd + +import ( + "context" + "fmt" + "net" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + + nberrors "github.com/netbirdio/netbird/client/errors" + firewall "github.com/netbirdio/netbird/client/firewall/manager" +) + +const ( + // ListenPort is the port that the DNS forwarder listens on. 
It has been used by the client peers also + ListenPort = 5353 + dnsTTL = 60 //seconds +) + +type Manager struct { + firewall firewall.Manager + + fwRules []firewall.Rule + dnsForwarder *DNSForwarder +} + +func NewManager(fw firewall.Manager) *Manager { + return &Manager{ + firewall: fw, + } +} + +func (m *Manager) Start(domains []string) error { + log.Infof("starting DNS forwarder") + if m.dnsForwarder != nil { + return nil + } + + if err := m.allowDNSFirewall(); err != nil { + return err + } + + m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL) + go func() { + if err := m.dnsForwarder.Listen(domains); err != nil { + // todo handle close error if it is exists + log.Errorf("failed to start DNS forwarder, err: %v", err) + } + }() + + return nil +} + +func (m *Manager) UpdateDomains(domains []string) { + if m.dnsForwarder == nil { + return + } + + m.dnsForwarder.UpdateDomains(domains) +} + +func (m *Manager) Stop(ctx context.Context) error { + if m.dnsForwarder == nil { + return nil + } + + var mErr *multierror.Error + if err := m.dropDNSFirewall(); err != nil { + mErr = multierror.Append(mErr, err) + } + + if err := m.dnsForwarder.Close(ctx); err != nil { + mErr = multierror.Append(mErr, err) + } + + m.dnsForwarder = nil + return nberrors.FormatErrorOrNil(mErr) +} + +func (h *Manager) allowDNSFirewall() error { + dport := &firewall.Port{ + IsRange: false, + Values: []int{ListenPort}, + } + + if h.firewall == nil { + return nil + } + + dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "") + if err != nil { + log.Errorf("failed to add allow DNS router rules, err: %v", err) + return err + } + h.fwRules = dnsRules + + return nil +} + +func (h *Manager) dropDNSFirewall() error { + var mErr *multierror.Error + for _, rule := range h.fwRules { + if err := h.firewall.DeletePeerRule(rule); err != nil { + mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err)) + } + } + + h.fwRules = nil + return nberrors.FormatErrorOrNil(mErr) +} diff --git a/client/internal/engine.go b/client/internal/engine.go index 190d795cd..b50532b7d 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -4,13 +4,13 @@ import ( "context" "errors" "fmt" - "maps" "math/rand" "net" "net/netip" "reflect" "runtime" "slices" + "sort" "strings" "sync" "sync/atomic" @@ -20,6 +20,7 @@ import ( "github.com/pion/stun/v2" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "google.golang.org/protobuf/proto" "github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall/manager" @@ -28,16 +29,18 @@ import ( "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/networkmonitor" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer/guard" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" + "github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/rosenpass" "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" - + 
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" @@ -61,6 +64,7 @@ import ( const ( PeerConnectionTimeoutMax = 45000 // ms PeerConnectionTimeoutMin = 30000 // ms + connInitLimit = 200 ) var ErrResetConnection = fmt.Errorf("reset connection") @@ -104,6 +108,11 @@ type EngineConfig struct { ServerSSHAllowed bool DNSRouteInterval time.Duration + + DisableClientRoutes bool + DisableServerRoutes bool + DisableDNS bool + DisableFirewall bool } // Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers. @@ -114,7 +123,7 @@ type Engine struct { // mgmClient is a Management Service client mgmClient mgm.Client // peerConns is a map that holds all the peers that are known to this peer - peerConns map[string]*peer.Conn + peerStore *peerstore.Store beforePeerHook nbnet.AddHookFunc afterPeerHook nbnet.RemoveHookFunc @@ -134,10 +143,6 @@ type Engine struct { TURNs []*stun.URI stunTurn atomic.Value - // clientRoutes is the most recent list of clientRoutes received from the Management Service - clientRoutes route.HAMap - clientRoutesMu sync.RWMutex - clientCtx context.Context clientCancel context.CancelFunc @@ -158,9 +163,10 @@ type Engine struct { statusRecorder *peer.Status - firewall manager.Manager - routeManager routemanager.Manager - acl acl.Manager + firewall manager.Manager + routeManager routemanager.Manager + acl acl.Manager + dnsForwardMgr *dnsfwd.Manager dnsServer dns.Server @@ -171,7 +177,12 @@ type Engine struct { relayManager *relayClient.Manager stateManager *statemanager.Manager - srWatcher *guard.SRWatcher + srWatcher *guard.SRWatcher + + // Network map persistence + persistNetworkMap bool + latestNetworkMap *mgmProto.NetworkMap + connSemaphore *semaphoregroup.SemaphoreGroup } // Peer is an instance of the Connection Peer @@ -226,7 +237,7 @@ func NewEngineWithProbes( signaler: peer.NewSignaler(signalClient, config.WgPrivateKey), mgmClient: mgmClient, relayManager: relayManager, - peerConns: make(map[string]*peer.Conn), + peerStore: peerstore.NewConnStore(), syncMsgMux: &sync.Mutex{}, config: config, mobileDep: mobileDep, @@ -237,6 +248,18 @@ func NewEngineWithProbes( statusRecorder: statusRecorder, probes: probes, checks: checks, + connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), + } + if runtime.GOOS == "ios" { + if !fileExists(mobileDep.StateFilePath) { + err := createFile(mobileDep.StateFilePath) + if err != nil { + log.Errorf("failed to create state file: %v", err) + // we are not exiting as we can run without the state manager + } + } + + engine.stateManager = statemanager.New(mobileDep.StateFilePath) } if path := statemanager.GetDefaultStatePath(); path != "" { engine.stateManager = statemanager.New(path) @@ -267,19 +290,26 @@ func (e *Engine) Stop() error { e.routeManager.Stop(e.stateManager) } + if e.dnsForwardMgr != nil { + if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { + log.Errorf("failed to stop DNS forward: %v", err) + } + e.dnsForwardMgr = nil + } + if e.srWatcher != nil { e.srWatcher.Close() } + e.statusRecorder.ReplaceOfflinePeers([]peer.State{}) + e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{}) + e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{}) + err := e.removeAllPeers() if err != nil { return fmt.Errorf("failed to remove all peers: %s", err) } - e.clientRoutesMu.Lock() - e.clientRoutes = nil - e.clientRoutesMu.Unlock() - if e.cancel != 
nil { e.cancel() } @@ -297,7 +327,7 @@ func (e *Engine) Stop() error { if err := e.stateManager.Stop(ctx); err != nil { return fmt.Errorf("failed to stop state manager: %w", err) } - if err := e.stateManager.PersistState(ctx); err != nil { + if err := e.stateManager.PersistState(context.Background()); err != nil { log.Errorf("failed to persist state: %v", err) } @@ -349,8 +379,21 @@ func (e *Engine) Start() error { } e.dnsServer = dnsServer - e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes) - beforePeerHook, afterPeerHook, err := e.routeManager.Init(e.stateManager) + e.routeManager = routemanager.NewManager(routemanager.ManagerConfig{ + Context: e.ctx, + PublicKey: e.config.WgPrivateKey.PublicKey().String(), + DNSRouteInterval: e.config.DNSRouteInterval, + WGInterface: e.wgInterface, + StatusRecorder: e.statusRecorder, + RelayManager: e.relayManager, + InitialRoutes: initialRoutes, + StateManager: e.stateManager, + DNSServer: dnsServer, + PeerStore: e.peerStore, + DisableClientRoutes: e.config.DisableClientRoutes, + DisableServerRoutes: e.config.DisableServerRoutes, + }) + beforePeerHook, afterPeerHook, err := e.routeManager.Init() if err != nil { log.Errorf("Failed to initialize route manager: %s", err) } else { @@ -367,17 +410,8 @@ func (e *Engine) Start() error { return fmt.Errorf("create wg interface: %w", err) } - e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager) - if err != nil { - log.Errorf("failed creating firewall manager: %s", err) - } - - if e.firewall != nil && e.firewall.IsServerRouteSupported() { - err = e.routeManager.EnableServerRouter(e.firewall) - if err != nil { - e.close() - return fmt.Errorf("enable server router: %w", err) - } + if err := e.createFirewall(); err != nil { + return err } e.udpMux, err = e.wgInterface.Up() @@ -419,6 +453,61 @@ func (e *Engine) Start() error { return nil } +func (e *Engine) createFirewall() error { + if e.config.DisableFirewall { + log.Infof("firewall is disabled") + return nil + } + + var err error + e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager) + if err != nil || e.firewall == nil { + log.Errorf("failed creating firewall manager: %s", err) + return nil + } + + if err := e.initFirewall(); err != nil { + return err + } + + return nil +} + +func (e *Engine) initFirewall() error { + if e.firewall.IsServerRouteSupported() { + if err := e.routeManager.EnableServerRouter(e.firewall); err != nil { + e.close() + return fmt.Errorf("enable server router: %w", err) + } + } + + if e.rpManager == nil || !e.config.RosenpassEnabled { + return nil + } + + rosenpassPort := e.rpManager.GetAddress().Port + port := manager.Port{Values: []int{rosenpassPort}} + + // this rule is static and will be torn down on engine down by the firewall manager + if _, err := e.firewall.AddPeerFiltering( + net.IP{0, 0, 0, 0}, + manager.ProtocolUDP, + nil, + &port, + manager.RuleDirectionIN, + manager.ActionAccept, + "", + "", + ); err != nil { + log.Errorf("failed to allow rosenpass interface traffic: %v", err) + return nil + } + + log.Infof("rosenpass interface traffic allowed on port %d", rosenpassPort) + + return nil +} + // modifyPeers updates peers that have been modified (e.g. IP address has been changed). // It closes the existing connection, removes it from the peerConns map, and creates a new one. 
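// Illustrative sketch (not from this patch): with the peerConns map replaced by the peer store,
// the "modified peer" check below reduces to comparing the stored comma-joined AllowedIPs string
// against the incoming update. Assuming AllowedIPs returns the same string that was handed to
// createPeerConn, the decision looks roughly like:
//
//	stored, ok := e.peerStore.AllowedIPs(p.GetWgPubKey()) // e.g. "100.64.0.5/32"
//	incoming := strings.Join(p.AllowedIps, ",")
//	needsReset := ok && stored != incoming // close and re-create the peer connection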
func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { @@ -427,8 +516,8 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { var modified []*mgmProto.RemotePeerConfig for _, p := range peersUpdate { peerPubKey := p.GetWgPubKey() - if peerConn, ok := e.peerConns[peerPubKey]; ok { - if peerConn.WgConfig().AllowedIps != strings.Join(p.AllowedIps, ",") { + if allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey); ok { + if allowedIPs != strings.Join(p.AllowedIps, ",") { modified = append(modified, p) continue } @@ -459,17 +548,12 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { // removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service. // It also removes peers that have been modified (e.g. change of IP address). They will be added again in addPeers method. func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error { - currentPeers := make([]string, 0, len(e.peerConns)) - for p := range e.peerConns { - currentPeers = append(currentPeers, p) - } - newPeers := make([]string, 0, len(peersUpdate)) for _, p := range peersUpdate { newPeers = append(newPeers, p.GetWgPubKey()) } - toRemove := util.SliceDiff(currentPeers, newPeers) + toRemove := util.SliceDiff(e.peerStore.PeersPubKey(), newPeers) for _, p := range toRemove { err := e.removePeer(p) @@ -483,7 +567,7 @@ func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error { func (e *Engine) removeAllPeers() error { log.Debugf("removing all peer connections") - for p := range e.peerConns { + for _, p := range e.peerStore.PeersPubKey() { err := e.removePeer(p) if err != nil { return err @@ -507,9 +591,8 @@ func (e *Engine) removePeer(peerKey string) error { } }() - conn, exists := e.peerConns[peerKey] + conn, exists := e.peerStore.Remove(peerKey) if exists { - delete(e.peerConns, peerKey) conn.Close() } return nil @@ -538,6 +621,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { relayMsg := wCfg.GetRelay() if relayMsg != nil { + // when we receive token we expect valid address list too c := &auth.Token{ Payload: relayMsg.GetTokenPayload(), Signature: relayMsg.GetTokenSignature(), @@ -546,9 +630,16 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { log.Errorf("failed to update relay token: %v", err) return fmt.Errorf("update relay token: %w", err) } + + e.relayManager.UpdateServerURLs(relayMsg.Urls) + + // Just in case the agent started with an MGM server where the relay was disabled but was later enabled. + // We can ignore all errors because the guard will manage the reconnection retries. 
+ _ = e.relayManager.Serve() + } else { + e.relayManager.UpdateServerURLs(nil) } - // todo update relay address in the relay manager // todo update signal } @@ -556,13 +647,22 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { return err } - if update.GetNetworkMap() != nil { - // only apply new changes and ignore old ones - err := e.updateNetworkMap(update.GetNetworkMap()) - if err != nil { - return err - } + nm := update.GetNetworkMap() + if nm == nil { + return nil } + + // Store network map if persistence is enabled + if e.persistNetworkMap { + e.latestNetworkMap = nm + log.Debugf("network map persisted with serial %d", nm.GetSerial()) + } + + // only apply new changes and ignore old ones + if err := e.updateNetworkMap(nm); err != nil { + return err + } + return nil } @@ -641,6 +741,10 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error { } func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { + if e.wgInterface == nil { + return errors.New("wireguard interface is not initialized") + } + if e.wgInterface.Address().String() != conf.Address { oldAddr := e.wgInterface.Address().String() log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address) @@ -659,12 +763,13 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { } } - e.statusRecorder.UpdateLocalPeerState(peer.LocalPeerState{ - IP: e.config.WgAddr, - PubKey: e.config.WgPrivateKey.PublicKey().String(), - KernelInterface: device.WireGuardModuleIsLoaded(), - FQDN: conf.GetFqdn(), - }) + state := e.statusRecorder.GetLocalPeerState() + state.IP = e.config.WgAddr + state.PubKey = e.config.WgPrivateKey.PublicKey().String() + state.KernelInterface = device.WireGuardModuleIsLoaded() + state.FQDN = conf.GetFqdn() + + e.statusRecorder.UpdateLocalPeerState(state) return nil } @@ -732,7 +837,6 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error { } func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { - // intentionally leave it before checking serial because for now it can happen that peer IP changed but serial didn't if networkMap.GetPeerConfig() != nil { err := e.updateConfig(networkMap.GetPeerConfig()) @@ -752,20 +856,16 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { e.acl.ApplyFiltering(networkMap) } - protoRoutes := networkMap.GetRoutes() - if protoRoutes == nil { - protoRoutes = []*mgmProto.Route{} - } + // DNS forwarder + dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap) + dnsRouteDomains := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), networkMap.GetRoutes()) + e.updateDNSForwarder(dnsRouteFeatureFlag, dnsRouteDomains) - _, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes)) - if err != nil { + routes := toRoutes(networkMap.GetRoutes()) + if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil { log.Errorf("failed to update clientRoutes, err: %v", err) } - e.clientRoutesMu.Lock() - e.clientRoutes = clientRoutes - e.clientRoutesMu.Unlock() - log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers())) e.updateOfflinePeers(networkMap.GetOfflinePeers()) @@ -813,8 +913,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { protoDNSConfig = &mgmProto.DNSConfig{} } - err = e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig)) - if err != nil { + if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig)); err != nil { 
log.Errorf("failed to update dns server, err: %v", err) } @@ -827,7 +926,18 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { return nil } +func toDNSFeatureFlag(networkMap *mgmProto.NetworkMap) bool { + if networkMap.PeerConfig != nil { + return networkMap.PeerConfig.RoutingPeerDnsResolutionEnabled + } + return false +} + func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route { + if protoRoutes == nil { + protoRoutes = []*mgmProto.Route{} + } + routes := make([]*route.Route, 0) for _, protoRoute := range protoRoutes { var prefix netip.Prefix @@ -838,6 +948,7 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route { continue } } + convertedRoute := &route.Route{ ID: route.ID(protoRoute.ID), Network: prefix, @@ -854,6 +965,23 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route { return routes } +func toRouteDomains(myPubKey string, protoRoutes []*mgmProto.Route) []string { + if protoRoutes == nil { + protoRoutes = []*mgmProto.Route{} + } + + var dnsRoutes []string + for _, protoRoute := range protoRoutes { + if len(protoRoute.Domains) == 0 { + continue + } + if protoRoute.Peer == myPubKey { + dnsRoutes = append(dnsRoutes, protoRoute.Domains...) + } + } + return dnsRoutes +} + func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config { dnsUpdate := nbdns.Config{ ServiceEnable: protoDNSConfig.GetServiceEnable(), @@ -928,12 +1056,16 @@ func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error { peerKey := peerConfig.GetWgPubKey() peerIPs := peerConfig.GetAllowedIps() - if _, ok := e.peerConns[peerKey]; !ok { + if _, ok := e.peerStore.PeerConn(peerKey); !ok { conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ",")) if err != nil { return fmt.Errorf("create peer connection: %w", err) } - e.peerConns[peerKey] = conn + + if ok := e.peerStore.AddPeerConn(peerKey, conn); !ok { + conn.Close() + return fmt.Errorf("peer already exists: %s", peerKey) + } if e.beforePeerHook != nil && e.afterPeerHook != nil { conn.AddBeforeAddPeerHook(e.beforePeerHook) @@ -1001,7 +1133,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e }, } - peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher) + peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher, e.connSemaphore) if err != nil { return nil, err } @@ -1022,8 +1154,8 @@ func (e *Engine) receiveSignalEvents() { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() - conn := e.peerConns[msg.Key] - if conn == nil { + conn, ok := e.peerStore.PeerConn(msg.Key) + if !ok { return fmt.Errorf("wrongly addressed message %s", msg.Key) } @@ -1081,7 +1213,7 @@ func (e *Engine) receiveSignalEvents() { return err } - go conn.OnRemoteCandidate(candidate, e.GetClientRoutes()) + go conn.OnRemoteCandidate(candidate, e.routeManager.GetClientRoutes()) case sProto.Body_MODE: } @@ -1239,6 +1371,7 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) { if e.dnsServer != nil { return nil, e.dnsServer, nil } + switch runtime.GOOS { case "android": routes, dnsConfig, err := e.readInitialSettings() @@ -1252,14 +1385,17 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) { *dnsConfig, e.mobileDep.NetworkChangeListener, e.statusRecorder, + e.config.DisableDNS, ) go e.mobileDep.DnsReadyListener.OnReady() return routes, 
dnsServer, nil + case "ios": - dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder) + dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS) return nil, dnsServer, nil + default: - dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager) + dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS) if err != nil { return nil, nil, err } @@ -1268,26 +1404,6 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) { } } -// GetClientRoutes returns the current routes from the route map -func (e *Engine) GetClientRoutes() route.HAMap { - e.clientRoutesMu.RLock() - defer e.clientRoutesMu.RUnlock() - - return maps.Clone(e.clientRoutes) -} - -// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only -func (e *Engine) GetClientRoutesWithNetID() map[route.NetID][]*route.Route { - e.clientRoutesMu.RLock() - defer e.clientRoutesMu.RUnlock() - - routes := make(map[route.NetID][]*route.Route, len(e.clientRoutes)) - for id, v := range e.clientRoutes { - routes[id.NetID()] = v - } - return routes -} - // GetRouteManager returns the route manager func (e *Engine) GetRouteManager() routemanager.Manager { return e.routeManager @@ -1372,9 +1488,8 @@ func (e *Engine) receiveProbeEvents() { go e.probes.WgProbe.Receive(e.ctx, func() bool { log.Debug("received wg probe request") - for _, peer := range e.peerConns { - key := peer.GetKey() - wgStats, err := peer.WgConfig().WgInterface.GetStats(key) + for _, key := range e.peerStore.PeersPubKey() { + wgStats, err := e.wgInterface.GetStats(key) if err != nil { log.Debugf("failed to get wg stats for peer %s: %s", key, err) } @@ -1451,7 +1566,7 @@ func (e *Engine) startNetworkMonitor() { func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) { var vpnRoutes []netip.Prefix - for _, routes := range e.GetClientRoutes() { + for _, routes := range e.routeManager.GetClientRoutes() { if len(routes) > 0 && routes[0] != nil { vpnRoutes = append(vpnRoutes, routes[0].Network) } @@ -1479,8 +1594,93 @@ func (e *Engine) stopDNSServer() { e.statusRecorder.UpdateDNSStates(nsGroupStates) } +// SetNetworkMapPersistence enables or disables network map persistence +func (e *Engine) SetNetworkMapPersistence(enabled bool) { + e.syncMsgMux.Lock() + defer e.syncMsgMux.Unlock() + + if enabled == e.persistNetworkMap { + return + } + e.persistNetworkMap = enabled + log.Debugf("Network map persistence is set to %t", enabled) + + if !enabled { + e.latestNetworkMap = nil + } +} + +// GetLatestNetworkMap returns the stored network map if persistence is enabled +func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) { + e.syncMsgMux.Lock() + defer e.syncMsgMux.Unlock() + + if !e.persistNetworkMap { + return nil, errors.New("network map persistence is disabled") + } + + if e.latestNetworkMap == nil { + //nolint:nilnil + return nil, nil + } + + log.Debugf("Retrieving latest network map with size %d bytes", proto.Size(e.latestNetworkMap)) + nm, ok := proto.Clone(e.latestNetworkMap).(*mgmProto.NetworkMap) + if !ok { + + return nil, fmt.Errorf("failed to clone network map") + } + + return nm, nil +} + +// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag +func (e *Engine) 
updateDNSForwarder(enabled bool, domains []string) { + if !enabled { + if e.dnsForwardMgr == nil { + return + } + if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { + log.Errorf("failed to stop DNS forward: %v", err) + } + return + } + + if len(domains) > 0 { + log.Infof("enable domain router service for domains: %v", domains) + if e.dnsForwardMgr == nil { + e.dnsForwardMgr = dnsfwd.NewManager(e.firewall) + + if err := e.dnsForwardMgr.Start(domains); err != nil { + log.Errorf("failed to start DNS forward: %v", err) + e.dnsForwardMgr = nil + } + } else { + log.Infof("update domain router service for domains: %v", domains) + e.dnsForwardMgr.UpdateDomains(domains) + } + } else if e.dnsForwardMgr != nil { + log.Infof("disable domain router service") + if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { + log.Errorf("failed to stop DNS forward: %v", err) + } + e.dnsForwardMgr = nil + } +} + // isChecksEqual checks if two slices of checks are equal. func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool { + for _, check := range checks { + sort.Slice(check.Files, func(i, j int) bool { + return check.Files[i] < check.Files[j] + }) + } + for _, oCheck := range oChecks { + sort.Slice(oCheck.Files, func(i, j int) bool { + return oCheck.Files[i] < oCheck.Files[j] + }) + } + return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool { return slices.Equal(checks.Files, oChecks.Files) }) diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 0018af6df..1deea1cb8 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -39,6 +39,8 @@ import ( mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" relayClient "github.com/netbirdio/netbird/relay/client" "github.com/netbirdio/netbird/route" @@ -245,12 +247,22 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { nil) wgIface := &iface.MockWGIface{ + NameFunc: func() string { return "utun102" }, RemovePeerFunc: func(peerKey string) error { return nil }, } engine.wgInterface = wgIface - engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil) + engine.routeManager = routemanager.NewManager(routemanager.ManagerConfig{ + Context: ctx, + PublicKey: key.PublicKey().String(), + DNSRouteInterval: time.Minute, + WGInterface: engine.wgInterface, + StatusRecorder: engine.statusRecorder, + RelayManager: relayMgr, + }) + _, _, err = engine.routeManager.Init() + require.NoError(t, err) engine.dnsServer = &dns.MockServer{ UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, } @@ -388,8 +400,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { return } - if len(engine.peerConns) != c.expectedLen { - t.Errorf("expecting Engine.peerConns to be of size %d, got %d", c.expectedLen, len(engine.peerConns)) + if len(engine.peerStore.PeersPubKey()) != c.expectedLen { + t.Errorf("expecting Engine.peerConns to be of size %d, got %d", c.expectedLen, len(engine.peerStore.PeersPubKey())) } if engine.networkSerial != c.expectedSerial { @@ -397,7 +409,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { } for _, p := range c.expectedPeers { - conn, ok := 
engine.peerConns[p.GetWgPubKey()] + conn, ok := engine.peerStore.PeerConn(p.GetWgPubKey()) if !ok { t.Errorf("expecting Engine.peerConns to contain peer %s", p) } @@ -622,10 +634,10 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { }{} mockRouteManager := &routemanager.MockManager{ - UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) { + UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error { input.inputSerial = updateSerial input.inputRoutes = newRoutes - return nil, nil, testCase.inputErr + return testCase.inputErr }, } @@ -798,8 +810,8 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { assert.NoError(t, err, "shouldn't return error") mockRouteManager := &routemanager.MockManager{ - UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) { - return nil, nil, nil + UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error { + return nil }, } @@ -1006,6 +1018,99 @@ func Test_ParseNATExternalIPMappings(t *testing.T) { } } +func Test_CheckFilesEqual(t *testing.T) { + testCases := []struct { + name string + inputChecks1 []*mgmtProto.Checks + inputChecks2 []*mgmtProto.Checks + expectedBool bool + }{ + { + name: "Equal Files In Equal Order Should Return True", + inputChecks1: []*mgmtProto.Checks{ + { + Files: []string{ + "testfile1", + "testfile2", + }, + }, + }, + inputChecks2: []*mgmtProto.Checks{ + { + Files: []string{ + "testfile1", + "testfile2", + }, + }, + }, + expectedBool: true, + }, + { + name: "Equal Files In Reverse Order Should Return True", + inputChecks1: []*mgmtProto.Checks{ + { + Files: []string{ + "testfile1", + "testfile2", + }, + }, + }, + inputChecks2: []*mgmtProto.Checks{ + { + Files: []string{ + "testfile2", + "testfile1", + }, + }, + }, + expectedBool: true, + }, + { + name: "Unequal Files Should Return False", + inputChecks1: []*mgmtProto.Checks{ + { + Files: []string{ + "testfile1", + "testfile2", + }, + }, + }, + inputChecks2: []*mgmtProto.Checks{ + { + Files: []string{ + "testfile1", + "testfile3", + }, + }, + }, + expectedBool: false, + }, + { + name: "Compared With Empty Should Return False", + inputChecks1: []*mgmtProto.Checks{ + { + Files: []string{ + "testfile1", + "testfile2", + }, + }, + }, + inputChecks2: []*mgmtProto.Checks{ + { + Files: []string{}, + }, + }, + expectedBool: false, + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + result := isChecksEqual(testCase.inputChecks1, testCase.inputChecks2) + assert.Equal(t, testCase.expectedBool, result, "result should match expected bool") + }) + } +} + func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) { key, err := wgtypes.GeneratePrivateKey() if err != nil { @@ -1100,7 +1205,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanUp, err := server.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir) if err != nil { return nil, "", err } @@ -1122,7 +1227,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri } secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, 
config.Relay) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil) if err != nil { return nil, "", err } @@ -1141,7 +1246,8 @@ func getConnectedPeers(e *Engine) int { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() i := 0 - for _, conn := range e.peerConns { + for _, id := range e.peerStore.PeersPubKey() { + conn, _ := e.peerStore.PeerConn(id) if conn.Status() == peer.StatusConnected { i++ } @@ -1153,5 +1259,5 @@ func getPeers(e *Engine) int { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() - return len(e.peerConns) + return len(e.peerStore.PeersPubKey()) } diff --git a/client/internal/mobile_dependency.go b/client/internal/mobile_dependency.go index 2b0c92cc6..4ac0fc141 100644 --- a/client/internal/mobile_dependency.go +++ b/client/internal/mobile_dependency.go @@ -19,4 +19,5 @@ type MobileDependency struct { // iOS only DnsManager dns.IosDnsManager FileDescriptor int32 + StateFilePath string } diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 81c456db7..b8cb2582f 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -23,6 +23,7 @@ import ( relayClient "github.com/netbirdio/netbird/relay/client" "github.com/netbirdio/netbird/route" nbnet "github.com/netbirdio/netbird/util/net" + semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" ) type ConnPriority int @@ -83,7 +84,6 @@ type Conn struct { signaler *Signaler relayManager *relayClient.Manager allowedIP net.IP - allowedNet string handshaker *Handshaker onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) @@ -105,13 +105,14 @@ type Conn struct { wgProxyICE wgproxy.Proxy wgProxyRelay wgproxy.Proxy - guard *guard.Guard + guard *guard.Guard + semaphore *semaphoregroup.SemaphoreGroup } // NewConn creates a new not opened Conn to the remote peer. // To establish a connection run Conn.Open -func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher) (*Conn, error) { - allowedIP, allowedNet, err := net.ParseCIDR(config.WgConfig.AllowedIps) +func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher, semaphore *semaphoregroup.SemaphoreGroup) (*Conn, error) { + allowedIP, _, err := net.ParseCIDR(config.WgConfig.AllowedIps) if err != nil { log.Errorf("failed to parse allowedIPS: %v", err) return nil, err @@ -129,9 +130,9 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu signaler: signaler, relayManager: relayManager, allowedIP: allowedIP, - allowedNet: allowedNet.String(), statusRelay: NewAtomicConnStatus(), statusICE: NewAtomicConnStatus(), + semaphore: semaphore, } rFns := WorkerRelayCallbacks{ @@ -171,6 +172,7 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu // It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will // be used. 
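// Illustrative sketch (not from this patch): the semaphore group passed into NewConn bounds how
// many peer connections run their initial handshake at once (connInitLimit in the engine). Open
// acquires a slot and startHandshakeAndReconnect releases it via the deferred Done call, so,
// assuming Add blocks while the group is full, usage follows:
//
//	sem := semaphoregroup.NewSemaphoreGroup(200) // connInitLimit
//	conn, err := NewConn(ctx, cfg, statusRecorder, signaler, ifaceDiscover, relayMgr, srWatcher, sem)
//	if err == nil {
//		conn.Open() // waits in sem.Add(ctx) once 200 opens are already in flight
//	}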
func (conn *Conn) Open() { + conn.semaphore.Add(conn.ctx) conn.log.Debugf("open connection to peer") conn.mu.Lock() @@ -193,6 +195,7 @@ func (conn *Conn) Open() { } func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) { + defer conn.semaphore.Done(conn.ctx) conn.waitInitialRandomSleepTime(ctx) err := conn.handshaker.sendOffer() @@ -594,14 +597,13 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd } if conn.onConnected != nil { - conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedNet, remoteRosenpassAddr) + conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedIP.String(), remoteRosenpassAddr) } } func (conn *Conn) waitInitialRandomSleepTime(ctx context.Context) { - minWait := 100 - maxWait := 800 - duration := time.Duration(rand.Intn(maxWait-minWait)+minWait) * time.Millisecond + maxWait := 300 + duration := time.Duration(rand.Intn(maxWait)) * time.Millisecond timeout := time.NewTimer(duration) defer timeout.Stop() @@ -745,6 +747,11 @@ func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) { conn.wgProxyRelay = proxy } +// AllowedIP returns the allowed IP of the remote peer +func (conn *Conn) AllowedIP() net.IP { + return conn.allowedIP +} + func isController(config ConnConfig) bool { return config.LocalKey > config.Key } diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go index 039952588..b3e9d5b60 100644 --- a/client/internal/peer/conn_test.go +++ b/client/internal/peer/conn_test.go @@ -14,6 +14,7 @@ import ( "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/util" + semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" ) var connConf = ConnConfig{ @@ -46,7 +47,7 @@ func TestNewConn_interfaceFilter(t *testing.T) { func TestConn_GetKey(t *testing.T) { swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) - conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil, swWatcher) + conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1)) if err != nil { return } @@ -58,7 +59,7 @@ func TestConn_GetKey(t *testing.T) { func TestConn_OnRemoteOffer(t *testing.T) { swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) - conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher) + conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1)) if err != nil { return } @@ -92,7 +93,7 @@ func TestConn_OnRemoteOffer(t *testing.T) { func TestConn_OnRemoteAnswer(t *testing.T) { swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) - conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher) + conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1)) if err != nil { return } @@ -125,7 +126,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) { } func TestConn_Status(t *testing.T) { swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) - conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher) + conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1)) if err != nil { return } diff 
--git a/client/internal/peer/status.go b/client/internal/peer/status.go index 0444dc60b..0df2a2e81 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -17,6 +17,11 @@ import ( relayClient "github.com/netbirdio/netbird/relay/client" ) +type ResolvedDomainInfo struct { + Prefixes []netip.Prefix + ParentDomain domain.Domain +} + // State contains the latest state of a peer type State struct { Mux *sync.RWMutex @@ -79,6 +84,12 @@ type LocalPeerState struct { Routes map[string]struct{} } +// Clone returns a copy of the LocalPeerState +func (l LocalPeerState) Clone() LocalPeerState { + l.Routes = maps.Clone(l.Routes) + return l +} + // SignalState contains the latest state of a signal connection type SignalState struct { URL string @@ -138,7 +149,7 @@ type Status struct { rosenpassEnabled bool rosenpassPermissive bool nsGroupStates []NSGroupState - resolvedDomainsStates map[domain.Domain][]netip.Prefix + resolvedDomainsStates map[domain.Domain]ResolvedDomainInfo // To reduce the number of notification invocation this bool will be true when need to call the notification // Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events @@ -156,7 +167,7 @@ func NewRecorder(mgmAddress string) *Status { offlinePeers: make([]State, 0), notifier: newNotifier(), mgmAddress: mgmAddress, - resolvedDomainsStates: make(map[domain.Domain][]netip.Prefix), + resolvedDomainsStates: map[domain.Domain]ResolvedDomainInfo{}, } } @@ -496,7 +507,7 @@ func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} { func (d *Status) GetLocalPeerState() LocalPeerState { d.mux.Lock() defer d.mux.Unlock() - return d.localPeer + return d.localPeer.Clone() } // UpdateLocalPeerState updates local peer status @@ -591,16 +602,27 @@ func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) { d.nsGroupStates = dnsStates } -func (d *Status) UpdateResolvedDomainsStates(domain domain.Domain, prefixes []netip.Prefix) { +func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix) { d.mux.Lock() defer d.mux.Unlock() - d.resolvedDomainsStates[domain] = prefixes + + // Store both the original domain pattern and resolved domain + d.resolvedDomainsStates[resolvedDomain] = ResolvedDomainInfo{ + Prefixes: prefixes, + ParentDomain: originalDomain, + } } func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) { d.mux.Lock() defer d.mux.Unlock() - delete(d.resolvedDomainsStates, domain) + + // Remove all entries that have this domain as their parent + for k, v := range d.resolvedDomainsStates { + if v.ParentDomain == domain { + delete(d.resolvedDomainsStates, k) + } + } } func (d *Status) GetRosenpassState() RosenpassState { @@ -676,25 +698,23 @@ func (d *Status) GetRelayStates() []relay.ProbeResult { // extend the list of stun, turn servers with relay address relayStates := slices.Clone(d.relayStates) - var relayState relay.ProbeResult - // if the server connection is not established then we will use the general address // in case of connection we will use the instance specific address instanceAddr, err := d.relayMgr.RelayInstanceAddress() if err != nil { // TODO add their status - if errors.Is(err, relayClient.ErrRelayClientNotConnected) { - for _, r := range d.relayMgr.ServerURLs() { - relayStates = append(relayStates, relay.ProbeResult{ - URI: r, - }) - } - return relayStates + for _, r := range d.relayMgr.ServerURLs() { + relayStates = append(relayStates, 
relay.ProbeResult{ + URI: r, + Err: err, + }) } - relayState.Err = err + return relayStates } - relayState.URI = instanceAddr + relayState := relay.ProbeResult{ + URI: instanceAddr, + } return append(relayStates, relayState) } @@ -704,7 +724,7 @@ func (d *Status) GetDNSStates() []NSGroupState { return d.nsGroupStates } -func (d *Status) GetResolvedDomainsStates() map[domain.Domain][]netip.Prefix { +func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo { d.mux.Lock() defer d.mux.Unlock() return maps.Clone(d.resolvedDomainsStates) diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 4c67cb781..4cdd18ff1 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -46,8 +46,6 @@ type WorkerICE struct { hasRelayOnLocally bool conn WorkerICECallbacks - selectedPriority ConnPriority - agent *ice.Agent muxAgent sync.Mutex @@ -95,10 +93,8 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { var preferredCandidateTypes []ice.CandidateType if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" { - w.selectedPriority = connPriorityICEP2P preferredCandidateTypes = icemaker.CandidateTypesP2P() } else { - w.selectedPriority = connPriorityICETurn preferredCandidateTypes = icemaker.CandidateTypes() } @@ -159,7 +155,7 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { RelayedOnLocal: isRelayCandidate(pair.Local), } w.log.Debugf("on ICE conn read to use ready") - go w.conn.OnConnReady(w.selectedPriority, ci) + go w.conn.OnConnReady(selectedPriority(pair), ci) } // OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer. @@ -268,7 +264,13 @@ func (w *WorkerICE) closeAgent(cancel context.CancelFunc) { func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) { // wait local endpoint configuration time.Sleep(time.Second) - addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", pair.Remote.Address(), remoteWgPort)) + addrString := pair.Remote.Address() + parsed, err := netip.ParseAddr(addrString) + if (err == nil) && (parsed.Is6()) { + addrString = fmt.Sprintf("[%s]", addrString) + //IPv6 Literals need to be wrapped in brackets for Resolve*Addr() + } + addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", addrString, remoteWgPort)) if err != nil { w.log.Warnf("got an error while resolving the udp address, err: %s", err) return @@ -394,3 +396,11 @@ func isRelayed(pair *ice.CandidatePair) bool { } return false } + +func selectedPriority(pair *ice.CandidatePair) ConnPriority { + if isRelayed(pair) { + return connPriorityICETurn + } else { + return connPriorityICEP2P + } +} diff --git a/client/internal/peerstore/store.go b/client/internal/peerstore/store.go new file mode 100644 index 000000000..6b3385ff5 --- /dev/null +++ b/client/internal/peerstore/store.go @@ -0,0 +1,87 @@ +package peerstore + +import ( + "net" + "sync" + + "golang.org/x/exp/maps" + + "github.com/netbirdio/netbird/client/internal/peer" +) + +// Store is a thread-safe store for peer connections. 
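// Illustrative usage (not from this patch): the store is the engine-side replacement for the plain
// peerConns map, so the call sites in addNewPeer, receiveSignalEvents and removePeer pair up as:
//
//	store := NewConnStore()
//	if ok := store.AddPeerConn(pubKey, conn); !ok {
//		// a connection for pubKey already exists; caller closes the new one
//	}
//	if c, ok := store.PeerConn(pubKey); ok {
//		_ = c // read-only access, e.g. when handling a signal message
//	}
//	if c, ok := store.Remove(pubKey); ok {
//		c.Close()
//	}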
+type Store struct { + peerConns map[string]*peer.Conn + peerConnsMu sync.RWMutex +} + +func NewConnStore() *Store { + return &Store{ + peerConns: make(map[string]*peer.Conn), + } +} + +func (s *Store) AddPeerConn(pubKey string, conn *peer.Conn) bool { + s.peerConnsMu.Lock() + defer s.peerConnsMu.Unlock() + + _, ok := s.peerConns[pubKey] + if ok { + return false + } + + s.peerConns[pubKey] = conn + return true +} + +func (s *Store) Remove(pubKey string) (*peer.Conn, bool) { + s.peerConnsMu.Lock() + defer s.peerConnsMu.Unlock() + + p, ok := s.peerConns[pubKey] + if !ok { + return nil, false + } + delete(s.peerConns, pubKey) + return p, true +} + +func (s *Store) AllowedIPs(pubKey string) (string, bool) { + s.peerConnsMu.RLock() + defer s.peerConnsMu.RUnlock() + + p, ok := s.peerConns[pubKey] + if !ok { + return "", false + } + return p.WgConfig().AllowedIps, true +} + +func (s *Store) AllowedIP(pubKey string) (net.IP, bool) { + s.peerConnsMu.RLock() + defer s.peerConnsMu.RUnlock() + + p, ok := s.peerConns[pubKey] + if !ok { + return nil, false + } + return p.AllowedIP(), true +} + +func (s *Store) PeerConn(pubKey string) (*peer.Conn, bool) { + s.peerConnsMu.RLock() + defer s.peerConnsMu.RUnlock() + + p, ok := s.peerConns[pubKey] + if !ok { + return nil, false + } + return p, true +} + +func (s *Store) PeersPubKey() []string { + s.peerConnsMu.RLock() + defer s.peerConnsMu.RUnlock() + + return maps.Keys(s.peerConns) +} diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index 13e45b3a3..73f552aab 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -13,12 +13,20 @@ import ( "github.com/netbirdio/netbird/client/iface" nbdns "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/peerstore" + "github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor" "github.com/netbirdio/netbird/client/internal/routemanager/dynamic" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/static" "github.com/netbirdio/netbird/route" ) +const ( + handlerTypeDynamic = iota + handlerTypeDomain + handlerTypeStatic +) + type routerPeerStatus struct { connected bool relayed bool @@ -53,7 +61,18 @@ type clientNetwork struct { updateSerial uint64 } -func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface iface.IWGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork { +func newClientNetworkWatcher( + ctx context.Context, + dnsRouteInterval time.Duration, + wgInterface iface.IWGIface, + statusRecorder *peer.Status, + rt *route.Route, + routeRefCounter *refcounter.RouteRefCounter, + allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, + dnsServer nbdns.Server, + peerStore *peerstore.Store, + useNewDNSRoute bool, +) *clientNetwork { ctx, cancel := context.WithCancel(ctx) client := &clientNetwork{ @@ -65,7 +84,17 @@ func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration routePeersNotifiers: make(map[string]chan struct{}), routeUpdate: make(chan routesUpdate), peerStateUpdate: make(chan struct{}), - handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder, wgInterface), + handler: handlerFromRoute( + rt, + routeRefCounter, + 
allowedIPsRefCounter, + dnsRouteInterval, + statusRecorder, + wgInterface, + dnsServer, + peerStore, + useNewDNSRoute, + ), } return client } @@ -368,10 +397,50 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() { } } -func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface iface.IWGIface) RouteHandler { - if rt.IsDynamic() { +func handlerFromRoute( + rt *route.Route, + routeRefCounter *refcounter.RouteRefCounter, + allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, + dnsRouterInteval time.Duration, + statusRecorder *peer.Status, + wgInterface iface.IWGIface, + dnsServer nbdns.Server, + peerStore *peerstore.Store, + useNewDNSRoute bool, +) RouteHandler { + switch handlerType(rt, useNewDNSRoute) { + case handlerTypeDomain: + return dnsinterceptor.New( + rt, + routeRefCounter, + allowedIPsRefCounter, + statusRecorder, + dnsServer, + peerStore, + ) + case handlerTypeDynamic: dns := nbdns.NewServiceViaMemory(wgInterface) - return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort())) + return dynamic.NewRoute( + rt, + routeRefCounter, + allowedIPsRefCounter, + dnsRouterInteval, + statusRecorder, + wgInterface, + fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()), + ) + default: + return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter) } - return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter) +} + +func handlerType(rt *route.Route, useNewDNSRoute bool) int { + if !rt.IsDynamic() { + return handlerTypeStatic + } + + if useNewDNSRoute { + return handlerTypeDomain + } + return handlerTypeDynamic } diff --git a/client/internal/routemanager/client_test.go b/client/internal/routemanager/client_test.go index 583156e4d..56fcf1613 100644 --- a/client/internal/routemanager/client_test.go +++ b/client/internal/routemanager/client_test.go @@ -1,6 +1,7 @@ package routemanager import ( + "fmt" "net/netip" "testing" "time" @@ -227,6 +228,64 @@ func TestGetBestrouteFromStatuses(t *testing.T) { currentRoute: "route1", expectedRouteID: "route1", }, + { + name: "relayed routes with latency 0 should maintain previous choice", + statuses: map[route.ID]routerPeerStatus{ + "route1": { + connected: true, + relayed: true, + latency: 0 * time.Millisecond, + }, + "route2": { + connected: true, + relayed: true, + latency: 0 * time.Millisecond, + }, + }, + existingRoutes: map[route.ID]*route.Route{ + "route1": { + ID: "route1", + Metric: route.MaxMetric, + Peer: "peer1", + }, + "route2": { + ID: "route2", + Metric: route.MaxMetric, + Peer: "peer2", + }, + }, + currentRoute: "route1", + expectedRouteID: "route1", + }, + { + name: "p2p routes with latency 0 should maintain previous choice", + statuses: map[route.ID]routerPeerStatus{ + "route1": { + connected: true, + relayed: false, + latency: 0 * time.Millisecond, + }, + "route2": { + connected: true, + relayed: false, + latency: 0 * time.Millisecond, + }, + }, + existingRoutes: map[route.ID]*route.Route{ + "route1": { + ID: "route1", + Metric: route.MaxMetric, + Peer: "peer1", + }, + "route2": { + ID: "route2", + Metric: route.MaxMetric, + Peer: "peer2", + }, + }, + currentRoute: "route1", + expectedRouteID: "route1", + }, { name: "current route with bad score should be changed to route with better score", statuses: map[route.ID]routerPeerStatus{ @@ -287,6 +346,45 @@ func 
TestGetBestrouteFromStatuses(t *testing.T) { }, } + // fill the test data with random routes + for _, tc := range testCases { + for i := 0; i < 50; i++ { + dummyRoute := &route.Route{ + ID: route.ID(fmt.Sprintf("dummy_p1_%d", i)), + Metric: route.MinMetric, + Peer: fmt.Sprintf("dummy_p1_%d", i), + } + tc.existingRoutes[dummyRoute.ID] = dummyRoute + } + for i := 0; i < 50; i++ { + dummyRoute := &route.Route{ + ID: route.ID(fmt.Sprintf("dummy_p2_%d", i)), + Metric: route.MinMetric, + Peer: fmt.Sprintf("dummy_p1_%d", i), + } + tc.existingRoutes[dummyRoute.ID] = dummyRoute + } + + for i := 0; i < 50; i++ { + id := route.ID(fmt.Sprintf("dummy_p1_%d", i)) + dummyStatus := routerPeerStatus{ + connected: false, + relayed: true, + latency: 0, + } + tc.statuses[id] = dummyStatus + } + for i := 0; i < 50; i++ { + id := route.ID(fmt.Sprintf("dummy_p2_%d", i)) + dummyStatus := routerPeerStatus{ + connected: false, + relayed: true, + latency: 0, + } + tc.statuses[id] = dummyStatus + } + } + for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { currentRoute := &route.Route{ diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go new file mode 100644 index 000000000..10cb03f1d --- /dev/null +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -0,0 +1,356 @@ +package dnsinterceptor + +import ( + "context" + "fmt" + "net" + "net/netip" + "strings" + "sync" + "time" + + "github.com/hashicorp/go-multierror" + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" + + nberrors "github.com/netbirdio/netbird/client/errors" + nbdns "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/dnsfwd" + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/peerstore" + "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" + "github.com/netbirdio/netbird/management/domain" + "github.com/netbirdio/netbird/route" +) + +type domainMap map[domain.Domain][]netip.Prefix + +type DnsInterceptor struct { + mu sync.RWMutex + route *route.Route + routeRefCounter *refcounter.RouteRefCounter + allowedIPsRefcounter *refcounter.AllowedIPsRefCounter + statusRecorder *peer.Status + dnsServer nbdns.Server + currentPeerKey string + interceptedDomains domainMap + peerStore *peerstore.Store +} + +func New( + rt *route.Route, + routeRefCounter *refcounter.RouteRefCounter, + allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, + statusRecorder *peer.Status, + dnsServer nbdns.Server, + peerStore *peerstore.Store, +) *DnsInterceptor { + return &DnsInterceptor{ + route: rt, + routeRefCounter: routeRefCounter, + allowedIPsRefcounter: allowedIPsRefCounter, + statusRecorder: statusRecorder, + dnsServer: dnsServer, + interceptedDomains: make(domainMap), + peerStore: peerStore, + } +} + +func (d *DnsInterceptor) String() string { + return d.route.Domains.SafeString() +} + +func (d *DnsInterceptor) AddRoute(context.Context) error { + d.dnsServer.RegisterHandler(d.route.Domains.ToPunycodeList(), d, nbdns.PriorityDNSRoute) + return nil +} + +func (d *DnsInterceptor) RemoveRoute() error { + d.mu.Lock() + + var merr *multierror.Error + for domain, prefixes := range d.interceptedDomains { + for _, prefix := range prefixes { + if _, err := d.routeRefCounter.Decrement(prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", prefix, err)) + } + if d.currentPeerKey != "" { + if _, err := 
d.allowedIPsRefcounter.Decrement(prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err)) + } + } + } + log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", ")) + + } + for _, domain := range d.route.Domains { + d.statusRecorder.DeleteResolvedDomainsStates(domain) + } + + clear(d.interceptedDomains) + d.mu.Unlock() + + d.dnsServer.DeregisterHandler(d.route.Domains.ToPunycodeList(), nbdns.PriorityDNSRoute) + + return nberrors.FormatErrorOrNil(merr) +} + +func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error { + d.mu.Lock() + defer d.mu.Unlock() + + var merr *multierror.Error + for domain, prefixes := range d.interceptedDomains { + for _, prefix := range prefixes { + if ref, err := d.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil { + merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err)) + } else if ref.Count > 1 && ref.Out != peerKey { + log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled", + prefix.Addr(), + domain.SafeString(), + ref.Out, + ) + } + } + } + + d.currentPeerKey = peerKey + return nberrors.FormatErrorOrNil(merr) +} + +func (d *DnsInterceptor) RemoveAllowedIPs() error { + d.mu.Lock() + defer d.mu.Unlock() + + var merr *multierror.Error + for _, prefixes := range d.interceptedDomains { + for _, prefix := range prefixes { + if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err)) + } + } + } + + d.currentPeerKey = "" + return nberrors.FormatErrorOrNil(merr) +} + +// ServeDNS implements the dns.Handler interface +func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + if len(r.Question) == 0 { + return + } + log.Tracef("received DNS request for domain=%s type=%v class=%v", + r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) + + d.mu.RLock() + peerKey := d.currentPeerKey + d.mu.RUnlock() + + if peerKey == "" { + log.Tracef("no current peer key set, letting next handler try for domain=%s", r.Question[0].Name) + + d.continueToNextHandler(w, r, "no current peer key") + return + } + + upstreamIP, err := d.getUpstreamIP(peerKey) + if err != nil { + log.Errorf("failed to get upstream IP: %v", err) + d.continueToNextHandler(w, r, fmt.Sprintf("failed to get upstream IP: %v", err)) + return + } + + client := &dns.Client{ + Timeout: 5 * time.Second, + Net: "udp", + } + upstream := fmt.Sprintf("%s:%d", upstreamIP, dnsfwd.ListenPort) + reply, _, err := client.ExchangeContext(context.Background(), r, upstream) + + var answer []dns.RR + if reply != nil { + answer = reply.Answer + } + log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP, peerKey, r.Question[0].Name, answer) + + if err != nil { + log.Errorf("failed to exchange DNS request with %s: %v", upstream, err) + if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil { + log.Errorf("failed writing DNS response: %v", err) + } + return + } + + reply.Id = r.Id + if err := d.writeMsg(w, reply); err != nil { + log.Errorf("failed writing DNS response: %v", err) + } +} + +// continueToNextHandler signals the handler chain to try the next handler +func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) { + log.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason) + + resp 
:= new(dns.Msg) + resp.SetRcode(r, dns.RcodeNameError) + // Set Zero bit to signal handler chain to continue + resp.MsgHdr.Zero = true + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed writing DNS continue response: %v", err) + } +} + +func (d *DnsInterceptor) getUpstreamIP(peerKey string) (net.IP, error) { + peerAllowedIP, exists := d.peerStore.AllowedIP(peerKey) + if !exists { + return nil, fmt.Errorf("peer connection not found for key: %s", peerKey) + } + return peerAllowedIP, nil +} + +func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error { + if r == nil { + return fmt.Errorf("received nil DNS message") + } + + if len(r.Answer) > 0 && len(r.Question) > 0 { + origPattern := "" + if writer, ok := w.(*nbdns.ResponseWriterChain); ok { + origPattern = writer.GetOrigPattern() + } + + resolvedDomain := domain.Domain(r.Question[0].Name) + + // already punycode via RegisterHandler() + originalDomain := domain.Domain(origPattern) + if originalDomain == "" { + originalDomain = resolvedDomain + } + + var newPrefixes []netip.Prefix + for _, answer := range r.Answer { + var ip netip.Addr + switch rr := answer.(type) { + case *dns.A: + addr, ok := netip.AddrFromSlice(rr.A) + if !ok { + log.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A) + continue + } + ip = addr + case *dns.AAAA: + addr, ok := netip.AddrFromSlice(rr.AAAA) + if !ok { + log.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA) + continue + } + ip = addr + default: + continue + } + + prefix := netip.PrefixFrom(ip, ip.BitLen()) + newPrefixes = append(newPrefixes, prefix) + } + + if len(newPrefixes) > 0 { + if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil { + log.Errorf("failed to update domain prefixes: %v", err) + } + } + } + + if err := w.WriteMsg(r); err != nil { + return fmt.Errorf("failed to write DNS response: %v", err) + } + + return nil +} + +func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error { + d.mu.Lock() + defer d.mu.Unlock() + + oldPrefixes := d.interceptedDomains[resolvedDomain] + toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes) + + var merr *multierror.Error + + // Add new prefixes + for _, prefix := range toAdd { + if _, err := d.routeRefCounter.Increment(prefix, struct{}{}); err != nil { + merr = multierror.Append(merr, fmt.Errorf("add route for IP %s: %v", prefix, err)) + continue + } + + if d.currentPeerKey == "" { + continue + } + if ref, err := d.allowedIPsRefcounter.Increment(prefix, d.currentPeerKey); err != nil { + merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err)) + } else if ref.Count > 1 && ref.Out != d.currentPeerKey { + log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. 
HA routing disabled", + prefix.Addr(), + resolvedDomain.SafeString(), + ref.Out, + ) + } + } + + if !d.route.KeepRoute { + // Remove old prefixes + for _, prefix := range toRemove { + if _, err := d.routeRefCounter.Decrement(prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", prefix, err)) + } + if d.currentPeerKey != "" { + if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err)) + } + } + } + } + + // Update domain prefixes using resolved domain as key + if len(toAdd) > 0 || len(toRemove) > 0 { + d.interceptedDomains[resolvedDomain] = newPrefixes + originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), ".")) + d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes) + + if len(toAdd) > 0 { + log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s", + resolvedDomain.SafeString(), + originalDomain.SafeString(), + toAdd) + } + if len(toRemove) > 0 { + log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s", + resolvedDomain.SafeString(), + originalDomain.SafeString(), + toRemove) + } + } + + return nberrors.FormatErrorOrNil(merr) +} + +func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) { + prefixSet := make(map[netip.Prefix]bool) + for _, prefix := range oldPrefixes { + prefixSet[prefix] = false + } + for _, prefix := range newPrefixes { + if _, exists := prefixSet[prefix]; exists { + prefixSet[prefix] = true + } else { + toAdd = append(toAdd, prefix) + } + } + for prefix, inUse := range prefixSet { + if !inUse { + toRemove = append(toRemove, prefix) + } + } + return +} diff --git a/client/internal/routemanager/dynamic/route.go b/client/internal/routemanager/dynamic/route.go index ac94d4a5c..a0fff7713 100644 --- a/client/internal/routemanager/dynamic/route.go +++ b/client/internal/routemanager/dynamic/route.go @@ -74,11 +74,7 @@ func NewRoute( } func (r *Route) String() string { - s, err := r.route.Domains.String() - if err != nil { - return r.route.Domains.PunycodeString() - } - return s + return r.route.Domains.SafeString() } func (r *Route) AddRoute(ctx context.Context) error { @@ -292,7 +288,7 @@ func (r *Route) updateDynamicRoutes(ctx context.Context, newDomains domainMap) e updatedPrefixes := combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes) r.dynamicDomains[domain] = updatedPrefixes - r.statusRecorder.UpdateResolvedDomainsStates(domain, updatedPrefixes) + r.statusRecorder.UpdateResolvedDomainsStates(domain, domain, updatedPrefixes) } return nberrors.FormatErrorOrNil(merr) diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 0a1c7dc56..6f73fb166 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -12,12 +12,16 @@ import ( "time" log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/routemanager/notifier" 
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" @@ -32,16 +36,33 @@ import ( // Manager is a route manager interface type Manager interface { - Init(*statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) - UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) + Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) + UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, useNewDNSRoute bool) error TriggerSelection(route.HAMap) GetRouteSelector() *routeselector.RouteSelector + GetClientRoutes() route.HAMap + GetClientRoutesWithNetID() map[route.NetID][]*route.Route SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string EnableServerRouter(firewall firewall.Manager) error Stop(stateManager *statemanager.Manager) } +type ManagerConfig struct { + Context context.Context + PublicKey string + DNSRouteInterval time.Duration + WGInterface iface.IWGIface + StatusRecorder *peer.Status + RelayManager *relayClient.Manager + InitialRoutes []*route.Route + StateManager *statemanager.Manager + DNSServer dns.Server + PeerStore *peerstore.Store + DisableClientRoutes bool + DisableServerRoutes bool +} + // DefaultManager is the default instance of a route manager type DefaultManager struct { ctx context.Context @@ -49,7 +70,7 @@ type DefaultManager struct { mux sync.Mutex clientNetworks map[route.HAUniqueID]*clientNetwork routeSelector *routeselector.RouteSelector - serverRouter serverRouter + serverRouter *serverRouter sysOps *systemops.SysOps statusRecorder *peer.Status relayMgr *relayClient.Manager @@ -59,51 +80,81 @@ type DefaultManager struct { routeRefCounter *refcounter.RouteRefCounter allowedIPsRefCounter *refcounter.AllowedIPsRefCounter dnsRouteInterval time.Duration + stateManager *statemanager.Manager + // clientRoutes is the most recent list of clientRoutes received from the Management Service + clientRoutes route.HAMap + dnsServer dns.Server + peerStore *peerstore.Store + useNewDNSRoute bool + disableClientRoutes bool + disableServerRoutes bool } -func NewManager( - ctx context.Context, - pubKey string, - dnsRouteInterval time.Duration, - wgInterface iface.IWGIface, - statusRecorder *peer.Status, - relayMgr *relayClient.Manager, - initialRoutes []*route.Route, -) *DefaultManager { - mCTX, cancel := context.WithCancel(ctx) +func NewManager(config ManagerConfig) *DefaultManager { + mCTX, cancel := context.WithCancel(config.Context) notifier := notifier.NewNotifier() - sysOps := systemops.NewSysOps(wgInterface, notifier) + sysOps := systemops.NewSysOps(config.WGInterface, notifier) dm := &DefaultManager{ - ctx: mCTX, - stop: cancel, - dnsRouteInterval: dnsRouteInterval, - clientNetworks: make(map[route.HAUniqueID]*clientNetwork), - relayMgr: relayMgr, - routeSelector: routeselector.NewRouteSelector(), - sysOps: sysOps, - statusRecorder: statusRecorder, - wgInterface: wgInterface, - pubKey: pubKey, - notifier: notifier, + ctx: mCTX, + stop: cancel, + dnsRouteInterval: config.DNSRouteInterval, + clientNetworks: make(map[route.HAUniqueID]*clientNetwork), + relayMgr: config.RelayManager, + sysOps: sysOps, + statusRecorder: config.StatusRecorder, + wgInterface: config.WGInterface, + pubKey: config.PublicKey, + notifier: notifier, + stateManager: config.StateManager, + dnsServer: config.DNSServer, + peerStore: config.PeerStore, + disableClientRoutes: config.DisableClientRoutes, + disableServerRoutes: 
config.DisableServerRoutes, } - dm.routeRefCounter = refcounter.New( + // don't proceed with client routes if it is disabled + if config.DisableClientRoutes { + return dm + } + + dm.setupRefCounters() + + if runtime.GOOS == "android" { + cr := dm.initialClientRoutes(config.InitialRoutes) + dm.notifier.SetInitialClientRoutes(cr) + } + return dm +} + +func (m *DefaultManager) setupRefCounters() { + m.routeRefCounter = refcounter.New( func(prefix netip.Prefix, _ struct{}) (struct{}, error) { - return struct{}{}, sysOps.AddVPNRoute(prefix, wgInterface.ToInterface()) + return struct{}{}, m.sysOps.AddVPNRoute(prefix, m.wgInterface.ToInterface()) }, func(prefix netip.Prefix, _ struct{}) error { - return sysOps.RemoveVPNRoute(prefix, wgInterface.ToInterface()) + return m.sysOps.RemoveVPNRoute(prefix, m.wgInterface.ToInterface()) }, ) - dm.allowedIPsRefCounter = refcounter.New( + if netstack.IsEnabled() { + m.routeRefCounter = refcounter.New( + func(netip.Prefix, struct{}) (struct{}, error) { + return struct{}{}, refcounter.ErrIgnore + }, + func(netip.Prefix, struct{}) error { + return nil + }, + ) + } + + m.allowedIPsRefCounter = refcounter.New( func(prefix netip.Prefix, peerKey string) (string, error) { // save peerKey to use it in the remove function - return peerKey, wgInterface.AddAllowedIP(peerKey, prefix.String()) + return peerKey, m.wgInterface.AddAllowedIP(peerKey, prefix.String()) }, func(prefix netip.Prefix, peerKey string) error { - if err := wgInterface.RemoveAllowedIP(peerKey, prefix.String()); err != nil { + if err := m.wgInterface.RemoveAllowedIP(peerKey, prefix.String()); err != nil { if !errors.Is(err, configurer.ErrPeerNotFound) && !errors.Is(err, configurer.ErrAllowedIPNotFound) { return err } @@ -112,17 +163,13 @@ func NewManager( return nil }, ) - - if runtime.GOOS == "android" { - cr := dm.clientRoutes(initialRoutes) - dm.notifier.SetInitialClientRoutes(cr) - } - return dm } // Init sets up the routing -func (m *DefaultManager) Init(stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { - if nbnet.CustomRoutingDisabled() { +func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { + m.routeSelector = m.initSelector() + + if nbnet.CustomRoutingDisabled() || m.disableClientRoutes { return nil, nil, nil } @@ -137,15 +184,46 @@ func (m *DefaultManager) Init(stateManager *statemanager.Manager) (nbnet.AddHook ips := resolveURLsToIPs(initialAddresses) - beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, stateManager) + beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, m.stateManager) if err != nil { return nil, nil, fmt.Errorf("setup routing: %w", err) } + log.Info("Routing setup complete") return beforePeerHook, afterPeerHook, nil } +func (m *DefaultManager) initSelector() *routeselector.RouteSelector { + var state *SelectorState + m.stateManager.RegisterState(state) + + // restore selector state if it exists + if err := m.stateManager.LoadState(state); err != nil { + log.Warnf("failed to load state: %v", err) + return routeselector.NewRouteSelector() + } + + if state := m.stateManager.GetState(state); state != nil { + if selector, ok := state.(*SelectorState); ok { + return (*routeselector.RouteSelector)(selector) + } + + log.Warnf("failed to convert state with type %T to SelectorState", state) + } + + return routeselector.NewRouteSelector() +} + func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error { + if m.disableServerRoutes { + log.Info("server routes are disabled") + 
return nil + } + + if firewall == nil { + return errors.New("firewall manager is not set") + } + var err error m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder) if err != nil { @@ -172,7 +250,7 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { } } - if !nbnet.CustomRoutingDisabled() { + if !nbnet.CustomRoutingDisabled() && !m.disableClientRoutes { if err := m.sysOps.CleanupRouting(stateManager); err != nil { log.Errorf("Error cleaning up routing: %v", err) } else { @@ -181,33 +259,43 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { } m.ctx = nil + + m.mux.Lock() + defer m.mux.Unlock() + m.clientRoutes = nil } // UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps -func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) { +func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, useNewDNSRoute bool) error { select { case <-m.ctx.Done(): log.Infof("not updating routes as context is closed") - return nil, nil, m.ctx.Err() + return nil default: - m.mux.Lock() - defer m.mux.Unlock() + } - newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes) + m.mux.Lock() + defer m.mux.Unlock() + m.useNewDNSRoute = useNewDNSRoute + newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes) + + if !m.disableClientRoutes { filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap) m.updateClientNetworks(updateSerial, filteredClientRoutes) m.notifier.OnNewRoutes(filteredClientRoutes) - - if m.serverRouter != nil { - err := m.serverRouter.updateRoutes(newServerRoutesMap) - if err != nil { - return nil, nil, fmt.Errorf("update routes: %w", err) - } - } - - return newServerRoutesMap, newClientRoutesIDMap, nil } + + if m.serverRouter != nil { + err := m.serverRouter.updateRoutes(newServerRoutesMap) + if err != nil { + return err + } + } + + m.clientRoutes = newClientRoutesIDMap + + return nil } // SetRouteChangeListener set RouteListener for route change Notifier @@ -225,9 +313,24 @@ func (m *DefaultManager) GetRouteSelector() *routeselector.RouteSelector { return m.routeSelector } -// GetClientRoutes returns the client routes -func (m *DefaultManager) GetClientRoutes() map[route.HAUniqueID]*clientNetwork { - return m.clientNetworks +// GetClientRoutes returns most recent list of clientRoutes received from the Management Service +func (m *DefaultManager) GetClientRoutes() route.HAMap { + m.mux.Lock() + defer m.mux.Unlock() + + return maps.Clone(m.clientRoutes) +} + +// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only +func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route { + m.mux.Lock() + defer m.mux.Unlock() + + routes := make(map[route.NetID][]*route.Route, len(m.clientRoutes)) + for id, v := range m.clientRoutes { + routes[id.NetID()] = v + } + return routes } // TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones @@ -247,11 +350,26 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) { continue } - clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter) + clientNetworkWatcher := newClientNetworkWatcher( + m.ctx, + m.dnsRouteInterval, + 
m.wgInterface, + m.statusRecorder, + routes[0], + m.routeRefCounter, + m.allowedIPsRefCounter, + m.dnsServer, + m.peerStore, + m.useNewDNSRoute, + ) m.clientNetworks[id] = clientNetworkWatcher go clientNetworkWatcher.peersStateAndUpdateWatcher() clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes}) } + + if err := m.stateManager.UpdateState((*SelectorState)(m.routeSelector)); err != nil { + log.Errorf("failed to update state: %v", err) + } } // stopObsoleteClients stops the client network watcher for the networks that are not in the new list @@ -272,7 +390,18 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout for id, routes := range networks { clientNetworkWatcher, found := m.clientNetworks[id] if !found { - clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter) + clientNetworkWatcher = newClientNetworkWatcher( + m.ctx, + m.dnsRouteInterval, + m.wgInterface, + m.statusRecorder, + routes[0], + m.routeRefCounter, + m.allowedIPsRefCounter, + m.dnsServer, + m.peerStore, + m.useNewDNSRoute, + ) m.clientNetworks[id] = clientNetworkWatcher go clientNetworkWatcher.peersStateAndUpdateWatcher() } @@ -315,7 +444,7 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID] return newServerRoutesMap, newClientRoutesIDMap } -func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route { +func (m *DefaultManager) initialClientRoutes(initialRoutes []*route.Route) []*route.Route { _, crMap := m.classifyRoutes(initialRoutes) rs := make([]*route.Route, 0, len(crMap)) for _, routes := range crMap { diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index e669bc44a..318ef5ae5 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -424,9 +424,14 @@ func TestManagerUpdateRoutes(t *testing.T) { statusRecorder := peer.NewRecorder("https://mgm") ctx := context.TODO() - routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil) + routeManager := NewManager(ManagerConfig{ + Context: ctx, + PublicKey: localPeerKey, + WGInterface: wgInterface, + StatusRecorder: statusRecorder, + }) - _, _, err = routeManager.Init(nil) + _, _, err = routeManager.Init() require.NoError(t, err, "should init route manager") defer routeManager.Stop(nil) @@ -436,11 +441,11 @@ func TestManagerUpdateRoutes(t *testing.T) { } if len(testCase.inputInitRoutes) > 0 { - _, _, err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes) + _ = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes, false) require.NoError(t, err, "should update routes with init routes") } - _, _, err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes) + _ = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes, false) require.NoError(t, err, "should update routes") expectedWatchers := testCase.clientNetworkWatchersExpected @@ -450,8 +455,7 @@ func TestManagerUpdateRoutes(t *testing.T) { require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match") if runtime.GOOS == "linux" && routeManager.serverRouter != nil { - sr := routeManager.serverRouter.(*defaultServerRouter) - require.Len(t, sr.routes, testCase.serverRoutesExpected, "server networks size 
should match") + require.Len(t, routeManager.serverRouter.routes, testCase.serverRoutesExpected, "server networks size should match") } }) } diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 503185f03..64fdffceb 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -2,7 +2,6 @@ package routemanager import ( "context" - "fmt" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" @@ -15,13 +14,15 @@ import ( // MockManager is the mock instance of a route manager type MockManager struct { - UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) - TriggerSelectionFunc func(haMap route.HAMap) - GetRouteSelectorFunc func() *routeselector.RouteSelector - StopFunc func(manager *statemanager.Manager) + UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) error + TriggerSelectionFunc func(haMap route.HAMap) + GetRouteSelectorFunc func() *routeselector.RouteSelector + GetClientRoutesFunc func() route.HAMap + GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route + StopFunc func(manager *statemanager.Manager) } -func (m *MockManager) Init(*statemanager.Manager) (net.AddHookFunc, net.RemoveHookFunc, error) { +func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) { return nil, nil, nil } @@ -31,11 +32,11 @@ func (m *MockManager) InitialRouteRange() []string { } // UpdateRoutes mock implementation of UpdateRoutes from Manager interface -func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) { +func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, b bool) error { if m.UpdateRoutesFunc != nil { return m.UpdateRoutesFunc(updateSerial, newRoutes) } - return nil, nil, fmt.Errorf("method UpdateRoutes is not implemented") + return nil } func (m *MockManager) TriggerSelection(networks route.HAMap) { @@ -52,6 +53,22 @@ func (m *MockManager) GetRouteSelector() *routeselector.RouteSelector { return nil } +// GetClientRoutes mock implementation of GetClientRoutes from Manager interface +func (m *MockManager) GetClientRoutes() route.HAMap { + if m.GetClientRoutesFunc != nil { + return m.GetClientRoutesFunc() + } + return nil +} + +// GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface +func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route { + if m.GetClientRoutesWithNetIDFunc != nil { + return m.GetClientRoutesWithNetIDFunc() + } + return nil +} + // Start mock implementation of Start from Manager interface func (m *MockManager) Start(ctx context.Context, iface *iface.WGIface) { } diff --git a/client/internal/routemanager/refcounter/refcounter.go b/client/internal/routemanager/refcounter/refcounter.go index 0e230ef40..27a724f50 100644 --- a/client/internal/routemanager/refcounter/refcounter.go +++ b/client/internal/routemanager/refcounter/refcounter.go @@ -47,10 +47,9 @@ type RemoveFunc[Key, O any] func(key Key, out O) error type Counter[Key comparable, I, O any] struct { // refCountMap keeps track of the reference Ref for keys refCountMap map[Key]Ref[O] - refCountMu sync.Mutex + mu sync.Mutex // idMap keeps track of the keys associated with an ID for removal idMap map[string][]Key - idMu sync.Mutex add AddFunc[Key, I, O] remove RemoveFunc[Key, O] } @@ -72,13 +71,14 @@ func New[Key comparable, I, O 
any](add AddFunc[Key, I, O], remove RemoveFunc[Key } // LoadData loads the data from the existing counter +// The passed counter should not be used any longer after calling this function. func (rm *Counter[Key, I, O]) LoadData( existingCounter *Counter[Key, I, O], ) { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() - rm.idMu.Lock() - defer rm.idMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() + existingCounter.mu.Lock() + defer existingCounter.mu.Unlock() rm.refCountMap = existingCounter.refCountMap rm.idMap = existingCounter.idMap @@ -87,8 +87,8 @@ func (rm *Counter[Key, I, O]) LoadData( // Get retrieves the current reference count and associated data for a key. // If the key doesn't exist, it returns a zero value Ref and false. func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() ref, ok := rm.refCountMap[key] return ref, ok @@ -97,9 +97,13 @@ func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) { // Increment increments the reference count for the given key. // If this is the first reference to the key, the AddFunc is called. func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() + return rm.increment(key, in) +} + +func (rm *Counter[Key, I, O]) increment(key Key, in I) (Ref[O], error) { ref := rm.refCountMap[key] logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out) @@ -126,10 +130,10 @@ func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) { // IncrementWithID increments the reference count for the given key and groups it under the given ID. // If this is the first reference to the key, the AddFunc is called. func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) { - rm.idMu.Lock() - defer rm.idMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() - ref, err := rm.Increment(key, in) + ref, err := rm.increment(key, in) if err != nil { return ref, fmt.Errorf("with ID: %w", err) } @@ -141,9 +145,12 @@ func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], // Decrement decrements the reference count for the given key. // If the reference count reaches 0, the RemoveFunc is called. func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() + return rm.decrement(key) +} +func (rm *Counter[Key, I, O]) decrement(key Key) (Ref[O], error) { ref, ok := rm.refCountMap[key] if !ok { logCallerF("No reference found for key %v", key) @@ -168,12 +175,12 @@ func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) { // DecrementWithID decrements the reference count for all keys associated with the given ID. // If the reference count reaches 0, the RemoveFunc is called. func (rm *Counter[Key, I, O]) DecrementWithID(id string) error { - rm.idMu.Lock() - defer rm.idMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() var merr *multierror.Error for _, key := range rm.idMap[id] { - if _, err := rm.Decrement(key); err != nil { + if _, err := rm.decrement(key); err != nil { merr = multierror.Append(merr, err) } } @@ -184,10 +191,8 @@ func (rm *Counter[Key, I, O]) DecrementWithID(id string) error { // Flush removes all references and calls RemoveFunc for each key. 
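// For illustration: a rough usage sketch of the refcounter API touched above,
// based on the signatures visible in this patch (New, Increment,
// IncrementWithID, Decrement, DecrementWithID). With the single mutex, the
// exported methods lock once and delegate to the unexported increment/decrement
// helpers, so IncrementWithID/DecrementWithID no longer take a second lock.
// The sketch assumes it compiles inside the netbird module, since the package
// lives under client/internal.
package main

import (
	"fmt"

	"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
)

func main() {
	counter := refcounter.New(
		// AddFunc: called only for the first reference to a key.
		func(key string, in int) (int, error) {
			fmt.Printf("first reference to %s, in=%d\n", key, in)
			return in, nil
		},
		// RemoveFunc: called only when the last reference is released.
		func(key string, out int) error {
			fmt.Printf("last reference to %s released, out=%d\n", key, out)
			return nil
		},
	)

	// Two references to the same key: AddFunc runs once.
	if _, err := counter.Increment("10.0.0.0/8", 1); err != nil {
		fmt.Println("add:", err)
	}
	if _, err := counter.IncrementWithID("conn-1", "10.0.0.0/8", 1); err != nil {
		fmt.Println("add with id:", err)
	}

	// Releasing both references triggers RemoveFunc exactly once.
	if err := counter.DecrementWithID("conn-1"); err != nil {
		fmt.Println("remove with id:", err)
	}
	if _, err := counter.Decrement("10.0.0.0/8"); err != nil {
		fmt.Println("remove:", err)
	}
}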
func (rm *Counter[Key, I, O]) Flush() error { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() - rm.idMu.Lock() - defer rm.idMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() var merr *multierror.Error for key := range rm.refCountMap { @@ -206,10 +211,8 @@ func (rm *Counter[Key, I, O]) Flush() error { // Clear removes all references without calling RemoveFunc. func (rm *Counter[Key, I, O]) Clear() { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() - rm.idMu.Lock() - defer rm.idMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() clear(rm.refCountMap) clear(rm.idMap) @@ -217,10 +220,8 @@ func (rm *Counter[Key, I, O]) Clear() { // MarshalJSON implements the json.Marshaler interface for Counter. func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() - rm.idMu.Lock() - defer rm.idMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() return json.Marshal(struct { RefCountMap map[Key]Ref[O] `json:"refCountMap"` @@ -233,6 +234,9 @@ func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements the json.Unmarshaler interface for Counter. func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error { + rm.mu.Lock() + defer rm.mu.Unlock() + var temp struct { RefCountMap map[Key]Ref[O] `json:"refCountMap"` IDMap map[string][]Key `json:"idMap"` @@ -243,6 +247,13 @@ func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error { rm.refCountMap = temp.RefCountMap rm.idMap = temp.IDMap + if temp.RefCountMap == nil { + temp.RefCountMap = map[Key]Ref[O]{} + } + if temp.IDMap == nil { + temp.IDMap = map[string][]Key{} + } + return nil } diff --git a/client/internal/routemanager/server.go b/client/internal/routemanager/server.go deleted file mode 100644 index 368421eb7..000000000 --- a/client/internal/routemanager/server.go +++ /dev/null @@ -1,9 +0,0 @@ -package routemanager - -import "github.com/netbirdio/netbird/route" - -type serverRouter interface { - updateRoutes(map[route.ID]*route.Route) error - removeFromServerNetwork(*route.Route) error - cleanUp() -} diff --git a/client/internal/routemanager/server_android.go b/client/internal/routemanager/server_android.go index c75a0a7f2..e9cfa0826 100644 --- a/client/internal/routemanager/server_android.go +++ b/client/internal/routemanager/server_android.go @@ -9,8 +9,19 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/route" ) -func newServerRouter(context.Context, iface.IWGIface, firewall.Manager, *peer.Status) (serverRouter, error) { +type serverRouter struct { +} + +func (r serverRouter) cleanUp() { +} + +func (r serverRouter) updateRoutes(map[route.ID]*route.Route) error { + return nil +} + +func newServerRouter(context.Context, iface.IWGIface, firewall.Manager, *peer.Status) (*serverRouter, error) { return nil, fmt.Errorf("server route not supported on this os") } diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index ef38d5707..b60cb318e 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -17,7 +17,7 @@ import ( "github.com/netbirdio/netbird/route" ) -type defaultServerRouter struct { +type serverRouter struct { mux sync.Mutex ctx context.Context routes map[route.ID]*route.Route @@ -26,8 +26,8 @@ type defaultServerRouter struct { statusRecorder *peer.Status } -func 
newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) { - return &defaultServerRouter{ +func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall firewall.Manager, statusRecorder *peer.Status) (*serverRouter, error) { + return &serverRouter{ ctx: ctx, routes: make(map[route.ID]*route.Route), firewall: firewall, @@ -36,7 +36,7 @@ func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall f }, nil } -func (m *defaultServerRouter) updateRoutes(routesMap map[route.ID]*route.Route) error { +func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error { serverRoutesToRemove := make([]route.ID, 0) for routeID := range m.routes { @@ -80,74 +80,72 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[route.ID]*route.Route) return nil } -func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error { - select { - case <-m.ctx.Done(): +func (m *serverRouter) removeFromServerNetwork(route *route.Route) error { + if m.ctx.Err() != nil { log.Infof("Not removing from server network because context is done") return m.ctx.Err() - default: - m.mux.Lock() - defer m.mux.Unlock() - - routerPair, err := routeToRouterPair(route) - if err != nil { - return fmt.Errorf("parse prefix: %w", err) - } - - err = m.firewall.RemoveNatRule(routerPair) - if err != nil { - return fmt.Errorf("remove routing rules: %w", err) - } - - delete(m.routes, route.ID) - - state := m.statusRecorder.GetLocalPeerState() - delete(state.Routes, route.Network.String()) - m.statusRecorder.UpdateLocalPeerState(state) - - return nil } + + m.mux.Lock() + defer m.mux.Unlock() + + routerPair, err := routeToRouterPair(route) + if err != nil { + return fmt.Errorf("parse prefix: %w", err) + } + + err = m.firewall.RemoveNatRule(routerPair) + if err != nil { + return fmt.Errorf("remove routing rules: %w", err) + } + + delete(m.routes, route.ID) + + state := m.statusRecorder.GetLocalPeerState() + delete(state.Routes, route.Network.String()) + m.statusRecorder.UpdateLocalPeerState(state) + + return nil } -func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error { - select { - case <-m.ctx.Done(): +func (m *serverRouter) addToServerNetwork(route *route.Route) error { + if m.ctx.Err() != nil { log.Infof("Not adding to server network because context is done") return m.ctx.Err() - default: - m.mux.Lock() - defer m.mux.Unlock() - - routerPair, err := routeToRouterPair(route) - if err != nil { - return fmt.Errorf("parse prefix: %w", err) - } - - err = m.firewall.AddNatRule(routerPair) - if err != nil { - return fmt.Errorf("insert routing rules: %w", err) - } - - m.routes[route.ID] = route - - state := m.statusRecorder.GetLocalPeerState() - if state.Routes == nil { - state.Routes = map[string]struct{}{} - } - - routeStr := route.Network.String() - if route.IsDynamic() { - routeStr = route.Domains.SafeString() - } - state.Routes[routeStr] = struct{}{} - - m.statusRecorder.UpdateLocalPeerState(state) - - return nil } + + m.mux.Lock() + defer m.mux.Unlock() + + routerPair, err := routeToRouterPair(route) + if err != nil { + return fmt.Errorf("parse prefix: %w", err) + } + + err = m.firewall.AddNatRule(routerPair) + if err != nil { + return fmt.Errorf("insert routing rules: %w", err) + } + + m.routes[route.ID] = route + + state := m.statusRecorder.GetLocalPeerState() + if state.Routes == nil { + state.Routes = map[string]struct{}{} + } + + routeStr := route.Network.String() + if 
route.IsDynamic() { + routeStr = route.Domains.SafeString() + } + state.Routes[routeStr] = struct{}{} + + m.statusRecorder.UpdateLocalPeerState(state) + + return nil } -func (m *defaultServerRouter) cleanUp() { +func (m *serverRouter) cleanUp() { m.mux.Lock() defer m.mux.Unlock() for _, r := range m.routes { diff --git a/client/internal/routemanager/state.go b/client/internal/routemanager/state.go new file mode 100644 index 000000000..a45c32b50 --- /dev/null +++ b/client/internal/routemanager/state.go @@ -0,0 +1,19 @@ +package routemanager + +import ( + "github.com/netbirdio/netbird/client/internal/routeselector" +) + +type SelectorState routeselector.RouteSelector + +func (s *SelectorState) Name() string { + return "routeselector_state" +} + +func (s *SelectorState) MarshalJSON() ([]byte, error) { + return (*routeselector.RouteSelector)(s).MarshalJSON() +} + +func (s *SelectorState) UnmarshalJSON(data []byte) error { + return (*routeselector.RouteSelector)(s).UnmarshalJSON(data) +} diff --git a/client/internal/routemanager/systemops/state.go b/client/internal/routemanager/systemops/state.go index 425908922..8e158711e 100644 --- a/client/internal/routemanager/systemops/state.go +++ b/client/internal/routemanager/systemops/state.go @@ -2,31 +2,28 @@ package systemops import ( "net/netip" - "sync" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" ) -type ShutdownState struct { - Counter *ExclusionCounter `json:"counter,omitempty"` - mu sync.RWMutex -} +type ShutdownState ExclusionCounter func (s *ShutdownState) Name() string { return "route_state" } func (s *ShutdownState) Cleanup() error { - s.mu.RLock() - defer s.mu.RUnlock() - - if s.Counter == nil { - return nil - } - sysops := NewSysOps(nil, nil) sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable) - sysops.refCounter.LoadData(s.Counter) + sysops.refCounter.LoadData((*ExclusionCounter)(s)) return sysops.refCounter.Flush() } + +func (s *ShutdownState) MarshalJSON() ([]byte, error) { + return (*ExclusionCounter)(s).MarshalJSON() +} + +func (s *ShutdownState) UnmarshalJSON(data []byte) error { + return (*ExclusionCounter)(s).UnmarshalJSON(data) +} diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index 4ff34aa51..31b7f3ac2 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -17,6 +17,7 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/client/internal/routemanager/vars" @@ -57,30 +58,30 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana return nexthop, refcounter.ErrIgnore } - r.updateState(stateManager) - return nexthop, err }, - func(prefix netip.Prefix, nexthop Nexthop) error { - // remove from state even if we have trouble removing it from the route table - // it could be already gone - r.updateState(stateManager) - - return r.removeFromRouteTable(prefix, nexthop) - }, + r.removeFromRouteTable, ) + if netstack.IsEnabled() { + refCounter = refcounter.New( + func(netip.Prefix, struct{}) (Nexthop, error) { + return Nexthop{}, refcounter.ErrIgnore + }, + func(netip.Prefix, Nexthop) error { + 
return nil + }, + ) + } + r.refCounter = refCounter - return r.setupHooks(initAddresses) + return r.setupHooks(initAddresses, stateManager) } +// updateState updates state on every change so it will be persisted regularly func (r *SysOps) updateState(stateManager *statemanager.Manager) { - state := getState(stateManager) - - state.Counter = r.refCounter - - if err := stateManager.UpdateState(state); err != nil { + if err := stateManager.UpdateState((*ShutdownState)(r.refCounter)); err != nil { log.Errorf("failed to update state: %v", err) } } @@ -336,7 +337,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) return r.removeFromRouteTable(prefix, nextHop) } -func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { prefix, err := util.GetPrefixFromIP(ip) if err != nil { @@ -347,6 +348,8 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re return fmt.Errorf("adding route reference: %v", err) } + r.updateState(stateManager) + return nil } afterHook := func(connID nbnet.ConnectionID) error { @@ -354,6 +357,8 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re return fmt.Errorf("remove route reference: %w", err) } + r.updateState(stateManager) + return nil } @@ -532,14 +537,3 @@ func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.P // Return true if the longest matching prefix is from vpnRoutes return isVpn, longestPrefix } - -func getState(stateManager *statemanager.Manager) *ShutdownState { - var shutdownState *ShutdownState - if state := stateManager.GetState(shutdownState); state != nil { - shutdownState = state.(*ShutdownState) - } else { - shutdownState = &ShutdownState{} - } - - return shutdownState -} diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index 0124fd95e..1da92cc80 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -55,7 +55,7 @@ type ruleParams struct { // isLegacy determines whether to use the legacy routing setup func isLegacy() bool { - return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() + return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || nbnet.SkipSocketMark() } // setIsLegacy sets the legacy routing setup @@ -92,17 +92,6 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager return r.setupRefCounter(initAddresses, stateManager) } - if err = addRoutingTableName(); err != nil { - log.Errorf("Error adding routing table name: %v", err) - } - - originalValues, err := sysctl.Setup(r.wgInterface) - if err != nil { - log.Errorf("Error setting up sysctl: %v", err) - sysctlFailed = true - } - originalSysctl = originalValues - defer func() { if err != nil { if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil { @@ -123,6 +112,17 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager } } + if err = addRoutingTableName(); err != nil { + log.Errorf("Error adding routing table name: %v", err) + } + + originalValues, err := sysctl.Setup(r.wgInterface) + if err != nil { + log.Errorf("Error setting up sysctl: %v", err) 
+ sysctlFailed = true + } + originalSysctl = originalValues + return nil, nil, nil } @@ -266,7 +266,7 @@ func addRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error { return fmt.Errorf("add gateway and device: %w", err) } - if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) { + if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !isOpErr(err) { return fmt.Errorf("netlink add route: %w", err) } @@ -289,7 +289,7 @@ func addUnreachableRoute(prefix netip.Prefix, tableID int) error { Dst: ipNet, } - if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) { + if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !isOpErr(err) { return fmt.Errorf("netlink add unreachable route: %w", err) } @@ -312,7 +312,7 @@ func removeUnreachableRoute(prefix netip.Prefix, tableID int) error { if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.ENOENT) && - !errors.Is(err, syscall.EAFNOSUPPORT) { + !isOpErr(err) { return fmt.Errorf("netlink remove unreachable route: %w", err) } @@ -338,7 +338,7 @@ func removeRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error { return fmt.Errorf("add gateway and device: %w", err) } - if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) { + if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !isOpErr(err) { return fmt.Errorf("netlink remove route: %w", err) } @@ -362,7 +362,7 @@ func flushRoutes(tableID, family int) error { routes[i].Dst = &net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)} } } - if err := netlink.RouteDel(&routes[i]); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { + if err := netlink.RouteDel(&routes[i]); err != nil && !isOpErr(err) { result = multierror.Append(result, fmt.Errorf("failed to delete route %v from table %d: %w", routes[i], tableID, err)) } } @@ -450,7 +450,7 @@ func addRule(params ruleParams) error { rule.Invert = params.invert rule.SuppressPrefixlen = params.suppressPrefix - if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) { + if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !isOpErr(err) { return fmt.Errorf("add routing rule: %w", err) } @@ -467,7 +467,7 @@ func removeRule(params ruleParams) error { rule.Priority = params.priority rule.SuppressPrefixlen = params.suppressPrefix - if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EAFNOSUPPORT) { + if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !isOpErr(err) { return fmt.Errorf("remove routing rule: %w", err) } @@ -509,3 +509,13 @@ func hasSeparateRouting() ([]netip.Prefix, error) { } return nil, ErrRoutingIsSeparate } + +func isOpErr(err error) bool { + // EAFTNOSUPPORT when ipv6 is disabled via sysctl, EOPNOTSUPP when disabled in boot options or otherwise not supported + if errors.Is(err, syscall.EAFNOSUPPORT) || errors.Is(err, syscall.EOPNOTSUPP) { + log.Debugf("route operation not supported: %v", err) + return true + } + + return false +} diff --git a/client/internal/routemanager/systemops/systemops_windows.go b/client/internal/routemanager/systemops/systemops_windows.go index b1732a080..ad325e123 100644 --- 
a/client/internal/routemanager/systemops/systemops_windows.go +++ b/client/internal/routemanager/systemops/systemops_windows.go @@ -230,10 +230,13 @@ func (rm *RouteMonitor) parseUpdate(row *MIB_IPFORWARD_ROW2, notificationType MI if idx != 0 { intf, err := net.InterfaceByIndex(idx) if err != nil { - return update, fmt.Errorf("get interface name: %w", err) + log.Warnf("failed to get interface name for index %d: %v", idx, err) + update.Interface = &net.Interface{ + Index: idx, + } + } else { + update.Interface = intf } - - update.Interface = intf } log.Tracef("Received route update with destination %v, next hop %v, interface %v", row.DestinationPrefix, row.NextHop, update.Interface) diff --git a/client/internal/routeselector/routeselector.go b/client/internal/routeselector/routeselector.go index 00128a27b..2874604fd 100644 --- a/client/internal/routeselector/routeselector.go +++ b/client/internal/routeselector/routeselector.go @@ -1,8 +1,10 @@ package routeselector import ( + "encoding/json" "fmt" "slices" + "sync" "github.com/hashicorp/go-multierror" "golang.org/x/exp/maps" @@ -12,6 +14,7 @@ import ( ) type RouteSelector struct { + mu sync.RWMutex selectedRoutes map[route.NetID]struct{} selectAll bool } @@ -26,6 +29,9 @@ func NewRouteSelector() *RouteSelector { // SelectRoutes updates the selected routes based on the provided route IDs. func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, allRoutes []route.NetID) error { + rs.mu.Lock() + defer rs.mu.Unlock() + if !appendRoute { rs.selectedRoutes = map[route.NetID]struct{}{} } @@ -46,6 +52,9 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al // SelectAllRoutes sets the selector to select all routes. func (rs *RouteSelector) SelectAllRoutes() { + rs.mu.Lock() + defer rs.mu.Unlock() + rs.selectAll = true rs.selectedRoutes = map[route.NetID]struct{}{} } @@ -53,6 +62,9 @@ func (rs *RouteSelector) SelectAllRoutes() { // DeselectRoutes removes specific routes from the selection. // If the selector is in "select all" mode, it will transition to "select specific" mode. func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error { + rs.mu.Lock() + defer rs.mu.Unlock() + if rs.selectAll { rs.selectAll = false rs.selectedRoutes = map[route.NetID]struct{}{} @@ -76,12 +88,18 @@ func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route. // DeselectAllRoutes deselects all routes, effectively disabling route selection. func (rs *RouteSelector) DeselectAllRoutes() { + rs.mu.Lock() + defer rs.mu.Unlock() + rs.selectAll = false rs.selectedRoutes = map[route.NetID]struct{}{} } // IsSelected checks if a specific route is selected. func (rs *RouteSelector) IsSelected(routeID route.NetID) bool { + rs.mu.RLock() + defer rs.mu.RUnlock() + if rs.selectAll { return true } @@ -91,6 +109,9 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool { // FilterSelected removes unselected routes from the provided map. 
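// For illustration: a rough sketch of the new RouteSelector persistence path,
// assuming the exported API shown in this patch (NewRouteSelector, SelectRoutes,
// IsSelected) and the added MarshalJSON/UnmarshalJSON methods. Per the
// UnmarshalJSON comment above, a selector restored from empty or "null" data
// behaves like a fresh selector with select-all enabled. The sketch assumes it
// compiles inside the netbird module, since routeselector is an internal package.
package main

import (
	"encoding/json"
	"fmt"

	"github.com/netbirdio/netbird/client/internal/routeselector"
	"github.com/netbirdio/netbird/route"
)

func main() {
	all := []route.NetID{"route1", "route2"}

	rs := routeselector.NewRouteSelector()
	if err := rs.SelectRoutes([]route.NetID{"route1"}, false, all); err != nil {
		fmt.Println("select:", err)
	}

	// Persist the selection, e.g. via the state manager.
	data, err := json.Marshal(rs)
	if err != nil {
		fmt.Println("marshal:", err)
		return
	}

	// Restore it into a new selector instance on the next start.
	restored := routeselector.NewRouteSelector()
	if err := json.Unmarshal(data, restored); err != nil {
		fmt.Println("unmarshal:", err)
		return
	}

	fmt.Println(restored.IsSelected("route1"), restored.IsSelected("route2")) // true false
}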
func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap { + rs.mu.RLock() + defer rs.mu.RUnlock() + if rs.selectAll { return maps.Clone(routes) } @@ -103,3 +124,49 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap { } return filtered } + +// MarshalJSON implements the json.Marshaler interface +func (rs *RouteSelector) MarshalJSON() ([]byte, error) { + rs.mu.RLock() + defer rs.mu.RUnlock() + + return json.Marshal(struct { + SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"` + SelectAll bool `json:"select_all"` + }{ + SelectAll: rs.selectAll, + SelectedRoutes: rs.selectedRoutes, + }) +} + +// UnmarshalJSON implements the json.Unmarshaler interface +// If the JSON is empty or null, it will initialize like a NewRouteSelector. +func (rs *RouteSelector) UnmarshalJSON(data []byte) error { + rs.mu.Lock() + defer rs.mu.Unlock() + + // Check for null or empty JSON + if len(data) == 0 || string(data) == "null" { + rs.selectedRoutes = map[route.NetID]struct{}{} + rs.selectAll = true + return nil + } + + var temp struct { + SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"` + SelectAll bool `json:"select_all"` + } + + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + rs.selectedRoutes = temp.SelectedRoutes + rs.selectAll = temp.SelectAll + + if rs.selectedRoutes == nil { + rs.selectedRoutes = map[route.NetID]struct{}{} + } + + return nil +} diff --git a/client/internal/routeselector/routeselector_test.go b/client/internal/routeselector/routeselector_test.go index 7df433f92..b1671f254 100644 --- a/client/internal/routeselector/routeselector_test.go +++ b/client/internal/routeselector/routeselector_test.go @@ -273,3 +273,88 @@ func TestRouteSelector_FilterSelected(t *testing.T) { "route2|192.168.0.0/16": {}, }, filtered) } + +func TestRouteSelector_NewRoutesBehavior(t *testing.T) { + initialRoutes := []route.NetID{"route1", "route2", "route3"} + newRoutes := []route.NetID{"route1", "route2", "route3", "route4", "route5"} + + tests := []struct { + name string + initialState func(rs *routeselector.RouteSelector) error // Setup initial state + wantNewSelected []route.NetID // Expected selected routes after new routes appear + }{ + { + name: "New routes with initial selectAll state", + initialState: func(rs *routeselector.RouteSelector) error { + rs.SelectAllRoutes() + return nil + }, + // When selectAll is true, all routes including new ones should be selected + wantNewSelected: []route.NetID{"route1", "route2", "route3", "route4", "route5"}, + }, + { + name: "New routes after specific selection", + initialState: func(rs *routeselector.RouteSelector) error { + return rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, initialRoutes) + }, + // When specific routes were selected, new routes should remain unselected + wantNewSelected: []route.NetID{"route1", "route2"}, + }, + { + name: "New routes after deselect all", + initialState: func(rs *routeselector.RouteSelector) error { + rs.DeselectAllRoutes() + return nil + }, + // After deselect all, new routes should remain unselected + wantNewSelected: []route.NetID{}, + }, + { + name: "New routes after deselecting specific routes", + initialState: func(rs *routeselector.RouteSelector) error { + rs.SelectAllRoutes() + return rs.DeselectRoutes([]route.NetID{"route1"}, initialRoutes) + }, + // After deselecting specific routes, new routes should remain unselected + wantNewSelected: []route.NetID{"route2", "route3"}, + }, + { + name: "New routes after selecting with 
append", + initialState: func(rs *routeselector.RouteSelector) error { + return rs.SelectRoutes([]route.NetID{"route1"}, true, initialRoutes) + }, + // When routes were appended, new routes should remain unselected + wantNewSelected: []route.NetID{"route1"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rs := routeselector.NewRouteSelector() + + // Setup initial state + err := tt.initialState(rs) + require.NoError(t, err) + + // Verify selection state with new routes + for _, id := range newRoutes { + assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantNewSelected, id), + "Route %s selection state incorrect", id) + } + + // Additional verification using FilterSelected + routes := route.HAMap{ + "route1|10.0.0.0/8": {}, + "route2|192.168.0.0/16": {}, + "route3|172.16.0.0/12": {}, + "route4|10.10.0.0/16": {}, + "route5|192.168.1.0/24": {}, + } + + filtered := rs.FilterSelected(routes) + expectedLen := len(tt.wantNewSelected) + assert.Equal(t, expectedLen, len(filtered), + "FilterSelected returned wrong number of routes, got %d want %d", len(filtered), expectedLen) + }) + } +} diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index a5a14f807..9a99c76f1 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -16,14 +16,39 @@ import ( "golang.org/x/exp/maps" nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/util" +) + +const ( + errStateNotRegistered = "state %s not registered" + errLoadStateFile = "load state file: %w" ) // State interface defines the methods that all state types must implement type State interface { Name() string +} + +// CleanableState interface extends State with cleanup capability +type CleanableState interface { + State Cleanup() error } +// RawState wraps raw JSON data for unregistered states +type RawState struct { + data json.RawMessage +} + +func (r *RawState) Name() string { + return "" // This is a placeholder implementation +} + +// MarshalJSON implements json.Marshaler to preserve the original JSON +func (r *RawState) MarshalJSON() ([]byte, error) { + return r.data, nil +} + // Manager handles the persistence and management of various states type Manager struct { mu sync.Mutex @@ -73,15 +98,15 @@ func (m *Manager) Stop(ctx context.Context) error { m.mu.Lock() defer m.mu.Unlock() - if m.cancel != nil { - m.cancel() + if m.cancel == nil { + return nil + } + m.cancel() - select { - case <-ctx.Done(): - return ctx.Err() - case <-m.done: - return nil - } + select { + case <-ctx.Done(): + return ctx.Err() + case <-m.done: } return nil @@ -139,7 +164,7 @@ func (m *Manager) setState(name string, state State) error { defer m.mu.Unlock() if _, exists := m.states[name]; !exists { - return fmt.Errorf("state %s not registered", name) + return fmt.Errorf(errStateNotRegistered, name) } m.states[name] = state @@ -148,6 +173,63 @@ func (m *Manager) setState(name string, state State) error { return nil } +// DeleteStateByName handles deletion of states without cleanup. +// It doesn't require the state to be registered. 
+func (m *Manager) DeleteStateByName(stateName string) error { + if m == nil { + return nil + } + + m.mu.Lock() + defer m.mu.Unlock() + + rawStates, err := m.loadStateFile(false) + if err != nil { + return fmt.Errorf(errLoadStateFile, err) + } + if rawStates == nil { + return nil + } + + if _, exists := rawStates[stateName]; !exists { + return fmt.Errorf("state %s not found", stateName) + } + + // Mark state as deleted by setting it to nil and marking it dirty + m.states[stateName] = nil + m.dirty[stateName] = struct{}{} + + return nil +} + +// DeleteAllStates removes all states. +func (m *Manager) DeleteAllStates() (int, error) { + if m == nil { + return 0, nil + } + + m.mu.Lock() + defer m.mu.Unlock() + + rawStates, err := m.loadStateFile(false) + if err != nil { + return 0, fmt.Errorf(errLoadStateFile, err) + } + if rawStates == nil { + return 0, nil + } + + count := len(rawStates) + + // Mark all states as deleted and dirty + for name := range rawStates { + m.states[name] = nil + m.dirty[name] = struct{}{} + } + + return count, nil +} + func (m *Manager) periodicStateSave(ctx context.Context) { ticker := time.NewTicker(10 * time.Second) defer ticker.Stop() @@ -178,25 +260,18 @@ func (m *Manager) PersistState(ctx context.Context) error { return nil } - ctx, cancel := context.WithTimeout(ctx, 3*time.Second) + bs, err := marshalWithPanicRecovery(m.states) + if err != nil { + return fmt.Errorf("marshal states: %w", err) + } + + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() done := make(chan error, 1) - + start := time.Now() go func() { - data, err := json.MarshalIndent(m.states, "", " ") - if err != nil { - done <- fmt.Errorf("marshal states: %w", err) - return - } - - // nolint:gosec - if err := os.WriteFile(m.filePath, data, 0640); err != nil { - done <- fmt.Errorf("write state file: %w", err) - return - } - - done <- nil + done <- util.WriteBytesWithRestrictedPermission(ctx, m.filePath, bs) }() select { @@ -208,63 +283,175 @@ func (m *Manager) PersistState(ctx context.Context) error { } } - log.Debugf("persisted shutdown states: %v", maps.Keys(m.dirty)) + log.Debugf("persisted states: %v, took %v", maps.Keys(m.dirty), time.Since(start)) clear(m.dirty) return nil } -// loadState loads the existing state from the state file -func (m *Manager) loadState() error { +// loadStateFile reads and unmarshals the state file into a map of raw JSON messages +func (m *Manager) loadStateFile(deleteCorrupt bool) (map[string]json.RawMessage, error) { data, err := os.ReadFile(m.filePath) if err != nil { if errors.Is(err, fs.ErrNotExist) { log.Debug("state file does not exist") - return nil + return nil, nil // nolint:nilnil } - return fmt.Errorf("read state file: %w", err) + return nil, fmt.Errorf("read state file: %w", err) } var rawStates map[string]json.RawMessage if err := json.Unmarshal(data, &rawStates); err != nil { - log.Warn("State file appears to be corrupted, attempting to delete it") - if err := os.Remove(m.filePath); err != nil { - log.Errorf("Failed to delete corrupted state file: %v", err) - } else { - log.Info("State file deleted") + if deleteCorrupt { + log.Warn("State file appears to be corrupted, attempting to delete it", err) + if err := os.Remove(m.filePath); err != nil { + log.Errorf("Failed to delete corrupted state file: %v", err) + } else { + log.Info("State file deleted") + } } - return fmt.Errorf("unmarshal states: %w", err) + return nil, fmt.Errorf("unmarshal states: %w", err) } - var merr *multierror.Error + return rawStates, nil +} - for name, 
rawState := range rawStates { - stateType, ok := m.stateTypes[name] - if !ok { - merr = multierror.Append(merr, fmt.Errorf("unknown state type: %s", name)) - continue - } +// loadSingleRawState unmarshals a raw state into a concrete state object +func (m *Manager) loadSingleRawState(name string, rawState json.RawMessage) (State, error) { + stateType, ok := m.stateTypes[name] + if !ok { + return nil, fmt.Errorf(errStateNotRegistered, name) + } - if string(rawState) == "null" { - continue - } + if string(rawState) == "null" { + return nil, nil //nolint:nilnil + } - statePtr := reflect.New(stateType).Interface().(State) - if err := json.Unmarshal(rawState, statePtr); err != nil { - merr = multierror.Append(merr, fmt.Errorf("unmarshal state %s: %w", name, err)) - continue - } + statePtr := reflect.New(stateType).Interface().(State) + if err := json.Unmarshal(rawState, statePtr); err != nil { + return nil, fmt.Errorf("unmarshal state %s: %w", name, err) + } - m.states[name] = statePtr + return statePtr, nil +} + +// LoadState loads a specific state from the state file +func (m *Manager) LoadState(state State) error { + if m == nil { + return nil + } + + m.mu.Lock() + defer m.mu.Unlock() + + rawStates, err := m.loadStateFile(false) + if err != nil { + return err + } + if rawStates == nil { + return nil + } + + name := state.Name() + rawState, exists := rawStates[name] + if !exists { + return nil + } + + loadedState, err := m.loadSingleRawState(name, rawState) + if err != nil { + return err + } + + m.states[name] = loadedState + if loadedState != nil { log.Debugf("loaded state: %s", name) } - return nberrors.FormatErrorOrNil(merr) + return nil } -// PerformCleanup retrieves all states from the state file for the registered states and calls Cleanup on them. -// If the cleanup is successful, the state is marked for deletion. +// cleanupSingleState handles the cleanup of a specific state and returns any error. +// The caller must hold the mutex. +func (m *Manager) cleanupSingleState(name string, rawState json.RawMessage) error { + // For unregistered states, preserve the raw JSON + if _, registered := m.stateTypes[name]; !registered { + m.states[name] = &RawState{data: rawState} + return nil + } + + // Load the state + loadedState, err := m.loadSingleRawState(name, rawState) + if err != nil { + return err + } + + if loadedState == nil { + return nil + } + + // Check if state supports cleanup + cleanableState, isCleanable := loadedState.(CleanableState) + if !isCleanable { + // If it doesn't support cleanup, keep it as-is + m.states[name] = loadedState + return nil + } + + // Perform cleanup + log.Infof("cleaning up state %s", name) + if err := cleanableState.Cleanup(); err != nil { + // On cleanup error, preserve the state + m.states[name] = loadedState + return fmt.Errorf("cleanup state: %w", err) + } + + // Successfully cleaned up - mark for deletion + m.states[name] = nil + m.dirty[name] = struct{}{} + return nil +} + +// CleanupStateByName loads and cleans up a specific state by name if it implements CleanableState. +// Returns an error if the state doesn't exist, isn't registered, or cleanup fails. 
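// For illustration: how a state registered with the Manager flows through the
// calls shown in this patch, mirroring the route manager's initSelector above.
// All method names (RegisterState, LoadState, GetState, UpdateState) are taken
// from their usage in this diff; the *statemanager.Manager itself is created
// elsewhere in the client, so this is a comment-only sketch.
//
//	var state *SelectorState
//	mgr.RegisterState(state)                      // make the concrete type known for decoding
//	if err := mgr.LoadState(state); err != nil {  // decode it from the state file, if present
//		log.Warnf("failed to load state: %v", err)
//	}
//	if st := mgr.GetState(state); st != nil {     // returns the loaded instance or nil
//		if selector, ok := st.(*SelectorState); ok {
//			_ = selector // restore the route selection from it
//		}
//	}
//	_ = mgr.UpdateState(state)                    // mark dirty; written out by periodicStateSave/PersistState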
+func (m *Manager) CleanupStateByName(name string) error { + if m == nil { + return nil + } + + m.mu.Lock() + defer m.mu.Unlock() + + // Check if state is registered + if _, registered := m.stateTypes[name]; !registered { + return fmt.Errorf(errStateNotRegistered, name) + } + + // Load raw states from file + rawStates, err := m.loadStateFile(false) + if err != nil { + return err + } + if rawStates == nil { + return nil + } + + // Check if state exists in file + rawState, exists := rawStates[name] + if !exists { + return nil + } + + if err := m.cleanupSingleState(name, rawState); err != nil { + return fmt.Errorf("%s: %w", name, err) + } + + return nil +} + +// PerformCleanup retrieves all states from the state file and calls Cleanup on registered states that support it. +// Unregistered states are preserved in their original state. func (m *Manager) PerformCleanup() error { if m == nil { return nil @@ -273,26 +460,63 @@ func (m *Manager) PerformCleanup() error { m.mu.Lock() defer m.mu.Unlock() - if err := m.loadState(); err != nil { - log.Warnf("Failed to load state during cleanup: %v", err) + // Load raw states from file + rawStates, err := m.loadStateFile(true) + if err != nil { + return fmt.Errorf(errLoadStateFile, err) + } + if rawStates == nil { + return nil } var merr *multierror.Error - for name, state := range m.states { - if state == nil { - // If no state was found in the state file, we don't mark the state dirty nor return an error - continue - } - log.Infof("client was not shut down properly, cleaning up %s", name) - if err := state.Cleanup(); err != nil { - merr = multierror.Append(merr, fmt.Errorf("cleanup state for %s: %w", name, err)) - } else { - // mark for deletion on cleanup success - m.states[name] = nil - m.dirty[name] = struct{}{} + // Process each state in the file + for name, rawState := range rawStates { + if err := m.cleanupSingleState(name, rawState); err != nil { + merr = multierror.Append(merr, fmt.Errorf("%s: %w", name, err)) } } return nberrors.FormatErrorOrNil(merr) } + +// GetSavedStateNames returns all state names that are currently saved in the state file. +func (m *Manager) GetSavedStateNames() ([]string, error) { + if m == nil { + return nil, nil + } + + rawStates, err := m.loadStateFile(false) + if err != nil { + return nil, fmt.Errorf(errLoadStateFile, err) + } + if rawStates == nil { + return nil, nil + } + + var states []string + for name, state := range rawStates { + if len(state) != 0 && string(state) != "null" { + states = append(states, name) + } + } + + return states, nil +} + +func marshalWithPanicRecovery(v any) ([]byte, error) { + var bs []byte + var err error + + func() { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("panic during marshal: %v", r) + } + }() + bs, err = json.Marshal(v) + }() + + return bs, err +} diff --git a/client/internal/statemanager/path.go b/client/internal/statemanager/path.go index 96d6a9f12..d232e5f0c 100644 --- a/client/internal/statemanager/path.go +++ b/client/internal/statemanager/path.go @@ -1,35 +1,16 @@ package statemanager import ( + "github.com/netbirdio/netbird/client/configs" "os" "path/filepath" - "runtime" - - log "github.com/sirupsen/logrus" ) // GetDefaultStatePath returns the path to the state file based on the operating system -// It returns an empty string if the path cannot be determined. It also creates the directory if it does not exist. +// It returns an empty string if the path cannot be determined. 
func GetDefaultStatePath() string { - var path string - - switch runtime.GOOS { - case "windows": - path = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json") - case "darwin", "linux": - path = "/var/lib/netbird/state.json" - case "freebsd", "openbsd", "netbsd", "dragonfly": - path = "/var/db/netbird/state.json" - // ios/android don't need state - default: - return "" + if path := os.Getenv("NB_DNS_STATE_FILE"); path != "" { + return path } - - dir := filepath.Dir(path) - if err := os.MkdirAll(dir, 0755); err != nil { - log.Errorf("Error creating directory %s: %v. Continuing without state support.", dir, err) - return "" - } - - return path + return filepath.Join(configs.StateDir, "state.json") } diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index 9d65bdbe0..befce56a2 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -59,6 +59,7 @@ func init() { // Client struct manage the life circle of background service type Client struct { cfgFile string + stateFile string recorder *peer.Status ctxCancel context.CancelFunc ctxCancelLock *sync.Mutex @@ -73,9 +74,10 @@ type Client struct { } // NewClient instantiate a new Client -func NewClient(cfgFile, deviceName string, osVersion string, osName string, networkChangeListener NetworkChangeListener, dnsManager DnsManager) *Client { +func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName string, networkChangeListener NetworkChangeListener, dnsManager DnsManager) *Client { return &Client{ cfgFile: cfgFile, + stateFile: stateFile, deviceName: deviceName, osName: osName, osVersion: osVersion, @@ -91,7 +93,8 @@ func (c *Client) Run(fd int32, interfaceName string) error { log.Infof("Starting NetBird client") log.Debugf("Tunnel uses interface: %s", interfaceName) cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{ - ConfigPath: c.cfgFile, + ConfigPath: c.cfgFile, + StateFilePath: c.stateFile, }) if err != nil { return err @@ -124,7 +127,7 @@ func (c *Client) Run(fd int32, interfaceName string) error { cfg.WgIface = interfaceName c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) - return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager) + return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile) } // Stop the internal client and free the resources @@ -269,8 +272,8 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) { return nil, fmt.Errorf("not connected") } - routesMap := engine.GetClientRoutesWithNetID() routeManager := engine.GetRouteManager() + routesMap := routeManager.GetClientRoutesWithNetID() if routeManager == nil { return nil, fmt.Errorf("could not get route manager") } @@ -314,7 +317,7 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) { } -func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain][]netip.Prefix) *RoutesSelectionDetails { +func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) *RoutesSelectionDetails { var routeSelection []RoutesSelectionInfo for _, r := range routes { domainList := make([]DomainInfo, 0) @@ -322,9 +325,10 @@ func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[dom domainResp := DomainInfo{ Domain: d.SafeString(), } - if prefixes, exists := resolvedDomains[d]; exists { + + if info, exists := resolvedDomains[d]; exists { var ipStrings []string - for _, prefix := range 
prefixes { + for _, prefix := range info.Prefixes { ipStrings = append(ipStrings, prefix.Addr().String()) } domainResp.ResolvedIPs = strings.Join(ipStrings, ", ") @@ -362,12 +366,12 @@ func (c *Client) SelectRoute(id string) error { } else { log.Debugf("select route with id: %s", id) routes := toNetIDs([]string{id}) - if err := routeSelector.SelectRoutes(routes, true, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil { + if err := routeSelector.SelectRoutes(routes, true, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil { log.Debugf("error when selecting routes: %s", err) return fmt.Errorf("select routes: %w", err) } } - routeManager.TriggerSelection(engine.GetClientRoutes()) + routeManager.TriggerSelection(routeManager.GetClientRoutes()) return nil } @@ -389,12 +393,12 @@ func (c *Client) DeselectRoute(id string) error { } else { log.Debugf("deselect route with id: %s", id) routes := toNetIDs([]string{id}) - if err := routeSelector.DeselectRoutes(routes, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil { + if err := routeSelector.DeselectRoutes(routes, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil { log.Debugf("error when deselecting routes: %s", err) return fmt.Errorf("deselect routes: %w", err) } } - routeManager.TriggerSelection(engine.GetClientRoutes()) + routeManager.TriggerSelection(routeManager.GetClientRoutes()) return nil } diff --git a/client/ios/NetBirdSDK/preferences.go b/client/ios/NetBirdSDK/preferences.go index b78146679..5a0abd9a7 100644 --- a/client/ios/NetBirdSDK/preferences.go +++ b/client/ios/NetBirdSDK/preferences.go @@ -10,9 +10,10 @@ type Preferences struct { } // NewPreferences create new Preferences instance -func NewPreferences(configPath string) *Preferences { +func NewPreferences(configPath string, stateFilePath string) *Preferences { ci := internal.ConfigInput{ - ConfigPath: configPath, + ConfigPath: configPath, + StateFilePath: stateFilePath, } return &Preferences{ci} } diff --git a/client/ios/NetBirdSDK/preferences_test.go b/client/ios/NetBirdSDK/preferences_test.go index aa6a475ae..7e5325a00 100644 --- a/client/ios/NetBirdSDK/preferences_test.go +++ b/client/ios/NetBirdSDK/preferences_test.go @@ -9,7 +9,8 @@ import ( func TestPreferences_DefaultValues(t *testing.T) { cfgFile := filepath.Join(t.TempDir(), "netbird.json") - p := NewPreferences(cfgFile) + stateFile := filepath.Join(t.TempDir(), "state.json") + p := NewPreferences(cfgFile, stateFile) defaultVar, err := p.GetAdminURL() if err != nil { t.Fatalf("failed to read default value: %s", err) @@ -42,7 +43,8 @@ func TestPreferences_DefaultValues(t *testing.T) { func TestPreferences_ReadUncommitedValues(t *testing.T) { exampleString := "exampleString" cfgFile := filepath.Join(t.TempDir(), "netbird.json") - p := NewPreferences(cfgFile) + stateFile := filepath.Join(t.TempDir(), "state.json") + p := NewPreferences(cfgFile, stateFile) p.SetAdminURL(exampleString) resp, err := p.GetAdminURL() @@ -79,7 +81,8 @@ func TestPreferences_Commit(t *testing.T) { exampleURL := "https://myurl.com:443" examplePresharedKey := "topsecret" cfgFile := filepath.Join(t.TempDir(), "netbird.json") - p := NewPreferences(cfgFile) + stateFile := filepath.Join(t.TempDir(), "state.json") + p := NewPreferences(cfgFile, stateFile) p.SetAdminURL(exampleURL) p.SetManagementURL(exampleURL) @@ -90,7 +93,7 @@ func TestPreferences_Commit(t *testing.T) { t.Fatalf("failed to save changes: %s", err) } - p = NewPreferences(cfgFile) + p = NewPreferences(cfgFile, stateFile) resp, err := 
p.GetAdminURL() if err != nil { t.Fatalf("failed to read admin url: %s", err) diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index b942d8b6e..659277570 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v3.21.12 +// protoc v4.23.4 // source: daemon.proto package proto @@ -122,6 +122,10 @@ type LoginRequest struct { ExtraIFaceBlacklist []string `protobuf:"bytes,17,rep,name=extraIFaceBlacklist,proto3" json:"extraIFaceBlacklist,omitempty"` NetworkMonitor *bool `protobuf:"varint,18,opt,name=networkMonitor,proto3,oneof" json:"networkMonitor,omitempty"` DnsRouteInterval *durationpb.Duration `protobuf:"bytes,19,opt,name=dnsRouteInterval,proto3,oneof" json:"dnsRouteInterval,omitempty"` + DisableClientRoutes *bool `protobuf:"varint,20,opt,name=disable_client_routes,json=disableClientRoutes,proto3,oneof" json:"disable_client_routes,omitempty"` + DisableServerRoutes *bool `protobuf:"varint,21,opt,name=disable_server_routes,json=disableServerRoutes,proto3,oneof" json:"disable_server_routes,omitempty"` + DisableDns *bool `protobuf:"varint,22,opt,name=disable_dns,json=disableDns,proto3,oneof" json:"disable_dns,omitempty"` + DisableFirewall *bool `protobuf:"varint,23,opt,name=disable_firewall,json=disableFirewall,proto3,oneof" json:"disable_firewall,omitempty"` } func (x *LoginRequest) Reset() { @@ -290,6 +294,34 @@ func (x *LoginRequest) GetDnsRouteInterval() *durationpb.Duration { return nil } +func (x *LoginRequest) GetDisableClientRoutes() bool { + if x != nil && x.DisableClientRoutes != nil { + return *x.DisableClientRoutes + } + return false +} + +func (x *LoginRequest) GetDisableServerRoutes() bool { + if x != nil && x.DisableServerRoutes != nil { + return *x.DisableServerRoutes + } + return false +} + +func (x *LoginRequest) GetDisableDns() bool { + if x != nil && x.DisableDns != nil { + return *x.DisableDns + } + return false +} + +func (x *LoginRequest) GetDisableFirewall() bool { + if x != nil && x.DisableFirewall != nil { + return *x.DisableFirewall + } + return false +} + type LoginResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -908,7 +940,7 @@ type PeerState struct { BytesRx int64 `protobuf:"varint,13,opt,name=bytesRx,proto3" json:"bytesRx,omitempty"` BytesTx int64 `protobuf:"varint,14,opt,name=bytesTx,proto3" json:"bytesTx,omitempty"` RosenpassEnabled bool `protobuf:"varint,15,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` - Routes []string `protobuf:"bytes,16,rep,name=routes,proto3" json:"routes,omitempty"` + Networks []string `protobuf:"bytes,16,rep,name=networks,proto3" json:"networks,omitempty"` Latency *durationpb.Duration `protobuf:"bytes,17,opt,name=latency,proto3" json:"latency,omitempty"` RelayAddress string `protobuf:"bytes,18,opt,name=relayAddress,proto3" json:"relayAddress,omitempty"` } @@ -1043,9 +1075,9 @@ func (x *PeerState) GetRosenpassEnabled() bool { return false } -func (x *PeerState) GetRoutes() []string { +func (x *PeerState) GetNetworks() []string { if x != nil { - return x.Routes + return x.Networks } return nil } @@ -1076,7 +1108,7 @@ type LocalPeerState struct { Fqdn string `protobuf:"bytes,4,opt,name=fqdn,proto3" json:"fqdn,omitempty"` RosenpassEnabled bool `protobuf:"varint,5,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` RosenpassPermissive bool `protobuf:"varint,6,opt,name=rosenpassPermissive,proto3" 
json:"rosenpassPermissive,omitempty"` - Routes []string `protobuf:"bytes,7,rep,name=routes,proto3" json:"routes,omitempty"` + Networks []string `protobuf:"bytes,7,rep,name=networks,proto3" json:"networks,omitempty"` } func (x *LocalPeerState) Reset() { @@ -1153,9 +1185,9 @@ func (x *LocalPeerState) GetRosenpassPermissive() bool { return false } -func (x *LocalPeerState) GetRoutes() []string { +func (x *LocalPeerState) GetNetworks() []string { if x != nil { - return x.Routes + return x.Networks } return nil } @@ -1511,14 +1543,14 @@ func (x *FullStatus) GetDnsServers() []*NSGroupState { return nil } -type ListRoutesRequest struct { +type ListNetworksRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields } -func (x *ListRoutesRequest) Reset() { - *x = ListRoutesRequest{} +func (x *ListNetworksRequest) Reset() { + *x = ListNetworksRequest{} if protoimpl.UnsafeEnabled { mi := &file_daemon_proto_msgTypes[19] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1526,13 +1558,13 @@ func (x *ListRoutesRequest) Reset() { } } -func (x *ListRoutesRequest) String() string { +func (x *ListNetworksRequest) String() string { return protoimpl.X.MessageStringOf(x) } -func (*ListRoutesRequest) ProtoMessage() {} +func (*ListNetworksRequest) ProtoMessage() {} -func (x *ListRoutesRequest) ProtoReflect() protoreflect.Message { +func (x *ListNetworksRequest) ProtoReflect() protoreflect.Message { mi := &file_daemon_proto_msgTypes[19] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1544,21 +1576,21 @@ func (x *ListRoutesRequest) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use ListRoutesRequest.ProtoReflect.Descriptor instead. -func (*ListRoutesRequest) Descriptor() ([]byte, []int) { +// Deprecated: Use ListNetworksRequest.ProtoReflect.Descriptor instead. +func (*ListNetworksRequest) Descriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{19} } -type ListRoutesResponse struct { +type ListNetworksResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Routes []*Route `protobuf:"bytes,1,rep,name=routes,proto3" json:"routes,omitempty"` + Routes []*Network `protobuf:"bytes,1,rep,name=routes,proto3" json:"routes,omitempty"` } -func (x *ListRoutesResponse) Reset() { - *x = ListRoutesResponse{} +func (x *ListNetworksResponse) Reset() { + *x = ListNetworksResponse{} if protoimpl.UnsafeEnabled { mi := &file_daemon_proto_msgTypes[20] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1566,13 +1598,13 @@ func (x *ListRoutesResponse) Reset() { } } -func (x *ListRoutesResponse) String() string { +func (x *ListNetworksResponse) String() string { return protoimpl.X.MessageStringOf(x) } -func (*ListRoutesResponse) ProtoMessage() {} +func (*ListNetworksResponse) ProtoMessage() {} -func (x *ListRoutesResponse) ProtoReflect() protoreflect.Message { +func (x *ListNetworksResponse) ProtoReflect() protoreflect.Message { mi := &file_daemon_proto_msgTypes[20] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1584,30 +1616,30 @@ func (x *ListRoutesResponse) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use ListRoutesResponse.ProtoReflect.Descriptor instead. -func (*ListRoutesResponse) Descriptor() ([]byte, []int) { +// Deprecated: Use ListNetworksResponse.ProtoReflect.Descriptor instead. 
+func (*ListNetworksResponse) Descriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{20} } -func (x *ListRoutesResponse) GetRoutes() []*Route { +func (x *ListNetworksResponse) GetRoutes() []*Network { if x != nil { return x.Routes } return nil } -type SelectRoutesRequest struct { +type SelectNetworksRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - RouteIDs []string `protobuf:"bytes,1,rep,name=routeIDs,proto3" json:"routeIDs,omitempty"` - Append bool `protobuf:"varint,2,opt,name=append,proto3" json:"append,omitempty"` - All bool `protobuf:"varint,3,opt,name=all,proto3" json:"all,omitempty"` + NetworkIDs []string `protobuf:"bytes,1,rep,name=networkIDs,proto3" json:"networkIDs,omitempty"` + Append bool `protobuf:"varint,2,opt,name=append,proto3" json:"append,omitempty"` + All bool `protobuf:"varint,3,opt,name=all,proto3" json:"all,omitempty"` } -func (x *SelectRoutesRequest) Reset() { - *x = SelectRoutesRequest{} +func (x *SelectNetworksRequest) Reset() { + *x = SelectNetworksRequest{} if protoimpl.UnsafeEnabled { mi := &file_daemon_proto_msgTypes[21] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1615,13 +1647,13 @@ func (x *SelectRoutesRequest) Reset() { } } -func (x *SelectRoutesRequest) String() string { +func (x *SelectNetworksRequest) String() string { return protoimpl.X.MessageStringOf(x) } -func (*SelectRoutesRequest) ProtoMessage() {} +func (*SelectNetworksRequest) ProtoMessage() {} -func (x *SelectRoutesRequest) ProtoReflect() protoreflect.Message { +func (x *SelectNetworksRequest) ProtoReflect() protoreflect.Message { mi := &file_daemon_proto_msgTypes[21] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1633,40 +1665,40 @@ func (x *SelectRoutesRequest) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use SelectRoutesRequest.ProtoReflect.Descriptor instead. -func (*SelectRoutesRequest) Descriptor() ([]byte, []int) { +// Deprecated: Use SelectNetworksRequest.ProtoReflect.Descriptor instead. 
+func (*SelectNetworksRequest) Descriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{21} } -func (x *SelectRoutesRequest) GetRouteIDs() []string { +func (x *SelectNetworksRequest) GetNetworkIDs() []string { if x != nil { - return x.RouteIDs + return x.NetworkIDs } return nil } -func (x *SelectRoutesRequest) GetAppend() bool { +func (x *SelectNetworksRequest) GetAppend() bool { if x != nil { return x.Append } return false } -func (x *SelectRoutesRequest) GetAll() bool { +func (x *SelectNetworksRequest) GetAll() bool { if x != nil { return x.All } return false } -type SelectRoutesResponse struct { +type SelectNetworksResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields } -func (x *SelectRoutesResponse) Reset() { - *x = SelectRoutesResponse{} +func (x *SelectNetworksResponse) Reset() { + *x = SelectNetworksResponse{} if protoimpl.UnsafeEnabled { mi := &file_daemon_proto_msgTypes[22] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1674,13 +1706,13 @@ func (x *SelectRoutesResponse) Reset() { } } -func (x *SelectRoutesResponse) String() string { +func (x *SelectNetworksResponse) String() string { return protoimpl.X.MessageStringOf(x) } -func (*SelectRoutesResponse) ProtoMessage() {} +func (*SelectNetworksResponse) ProtoMessage() {} -func (x *SelectRoutesResponse) ProtoReflect() protoreflect.Message { +func (x *SelectNetworksResponse) ProtoReflect() protoreflect.Message { mi := &file_daemon_proto_msgTypes[22] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1692,8 +1724,8 @@ func (x *SelectRoutesResponse) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use SelectRoutesResponse.ProtoReflect.Descriptor instead. -func (*SelectRoutesResponse) Descriptor() ([]byte, []int) { +// Deprecated: Use SelectNetworksResponse.ProtoReflect.Descriptor instead. 
+func (*SelectNetworksResponse) Descriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{22} } @@ -1744,20 +1776,20 @@ func (x *IPList) GetIps() []string { return nil } -type Route struct { +type Network struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields ID string `protobuf:"bytes,1,opt,name=ID,proto3" json:"ID,omitempty"` - Network string `protobuf:"bytes,2,opt,name=network,proto3" json:"network,omitempty"` + Range string `protobuf:"bytes,2,opt,name=range,proto3" json:"range,omitempty"` Selected bool `protobuf:"varint,3,opt,name=selected,proto3" json:"selected,omitempty"` Domains []string `protobuf:"bytes,4,rep,name=domains,proto3" json:"domains,omitempty"` ResolvedIPs map[string]*IPList `protobuf:"bytes,5,rep,name=resolvedIPs,proto3" json:"resolvedIPs,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` } -func (x *Route) Reset() { - *x = Route{} +func (x *Network) Reset() { + *x = Network{} if protoimpl.UnsafeEnabled { mi := &file_daemon_proto_msgTypes[24] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1765,13 +1797,13 @@ func (x *Route) Reset() { } } -func (x *Route) String() string { +func (x *Network) String() string { return protoimpl.X.MessageStringOf(x) } -func (*Route) ProtoMessage() {} +func (*Network) ProtoMessage() {} -func (x *Route) ProtoReflect() protoreflect.Message { +func (x *Network) ProtoReflect() protoreflect.Message { mi := &file_daemon_proto_msgTypes[24] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1783,40 +1815,40 @@ func (x *Route) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use Route.ProtoReflect.Descriptor instead. -func (*Route) Descriptor() ([]byte, []int) { +// Deprecated: Use Network.ProtoReflect.Descriptor instead. 
+func (*Network) Descriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{24} } -func (x *Route) GetID() string { +func (x *Network) GetID() string { if x != nil { return x.ID } return "" } -func (x *Route) GetNetwork() string { +func (x *Network) GetRange() string { if x != nil { - return x.Network + return x.Range } return "" } -func (x *Route) GetSelected() bool { +func (x *Network) GetSelected() bool { if x != nil { return x.Selected } return false } -func (x *Route) GetDomains() []string { +func (x *Network) GetDomains() []string { if x != nil { return x.Domains } return nil } -func (x *Route) GetResolvedIPs() map[string]*IPList { +func (x *Network) GetResolvedIPs() map[string]*IPList { if x != nil { return x.ResolvedIPs } @@ -2103,6 +2135,434 @@ func (*SetLogLevelResponse) Descriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{30} } +// State represents a daemon state entry +type State struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` +} + +func (x *State) Reset() { + *x = State{} + if protoimpl.UnsafeEnabled { + mi := &file_daemon_proto_msgTypes[31] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *State) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*State) ProtoMessage() {} + +func (x *State) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[31] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use State.ProtoReflect.Descriptor instead. +func (*State) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{31} +} + +func (x *State) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +// ListStatesRequest is empty as it requires no parameters +type ListStatesRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *ListStatesRequest) Reset() { + *x = ListStatesRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_daemon_proto_msgTypes[32] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ListStatesRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ListStatesRequest) ProtoMessage() {} + +func (x *ListStatesRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[32] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ListStatesRequest.ProtoReflect.Descriptor instead. 
+func (*ListStatesRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{32} +} + +// ListStatesResponse contains a list of states +type ListStatesResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + States []*State `protobuf:"bytes,1,rep,name=states,proto3" json:"states,omitempty"` +} + +func (x *ListStatesResponse) Reset() { + *x = ListStatesResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_daemon_proto_msgTypes[33] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ListStatesResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ListStatesResponse) ProtoMessage() {} + +func (x *ListStatesResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[33] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ListStatesResponse.ProtoReflect.Descriptor instead. +func (*ListStatesResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{33} +} + +func (x *ListStatesResponse) GetStates() []*State { + if x != nil { + return x.States + } + return nil +} + +// CleanStateRequest for cleaning states +type CleanStateRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + StateName string `protobuf:"bytes,1,opt,name=state_name,json=stateName,proto3" json:"state_name,omitempty"` + All bool `protobuf:"varint,2,opt,name=all,proto3" json:"all,omitempty"` +} + +func (x *CleanStateRequest) Reset() { + *x = CleanStateRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_daemon_proto_msgTypes[34] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *CleanStateRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CleanStateRequest) ProtoMessage() {} + +func (x *CleanStateRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[34] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use CleanStateRequest.ProtoReflect.Descriptor instead. 
+func (*CleanStateRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{34} +} + +func (x *CleanStateRequest) GetStateName() string { + if x != nil { + return x.StateName + } + return "" +} + +func (x *CleanStateRequest) GetAll() bool { + if x != nil { + return x.All + } + return false +} + +// CleanStateResponse contains the result of the clean operation +type CleanStateResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + CleanedStates int32 `protobuf:"varint,1,opt,name=cleaned_states,json=cleanedStates,proto3" json:"cleaned_states,omitempty"` +} + +func (x *CleanStateResponse) Reset() { + *x = CleanStateResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_daemon_proto_msgTypes[35] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *CleanStateResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CleanStateResponse) ProtoMessage() {} + +func (x *CleanStateResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[35] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use CleanStateResponse.ProtoReflect.Descriptor instead. +func (*CleanStateResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{35} +} + +func (x *CleanStateResponse) GetCleanedStates() int32 { + if x != nil { + return x.CleanedStates + } + return 0 +} + +// DeleteStateRequest for deleting states +type DeleteStateRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + StateName string `protobuf:"bytes,1,opt,name=state_name,json=stateName,proto3" json:"state_name,omitempty"` + All bool `protobuf:"varint,2,opt,name=all,proto3" json:"all,omitempty"` +} + +func (x *DeleteStateRequest) Reset() { + *x = DeleteStateRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_daemon_proto_msgTypes[36] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *DeleteStateRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DeleteStateRequest) ProtoMessage() {} + +func (x *DeleteStateRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[36] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DeleteStateRequest.ProtoReflect.Descriptor instead. 
+func (*DeleteStateRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{36} +} + +func (x *DeleteStateRequest) GetStateName() string { + if x != nil { + return x.StateName + } + return "" +} + +func (x *DeleteStateRequest) GetAll() bool { + if x != nil { + return x.All + } + return false +} + +// DeleteStateResponse contains the result of the delete operation +type DeleteStateResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + DeletedStates int32 `protobuf:"varint,1,opt,name=deleted_states,json=deletedStates,proto3" json:"deleted_states,omitempty"` +} + +func (x *DeleteStateResponse) Reset() { + *x = DeleteStateResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_daemon_proto_msgTypes[37] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *DeleteStateResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DeleteStateResponse) ProtoMessage() {} + +func (x *DeleteStateResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[37] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DeleteStateResponse.ProtoReflect.Descriptor instead. +func (*DeleteStateResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{37} +} + +func (x *DeleteStateResponse) GetDeletedStates() int32 { + if x != nil { + return x.DeletedStates + } + return 0 +} + +type SetNetworkMapPersistenceRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Enabled bool `protobuf:"varint,1,opt,name=enabled,proto3" json:"enabled,omitempty"` +} + +func (x *SetNetworkMapPersistenceRequest) Reset() { + *x = SetNetworkMapPersistenceRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_daemon_proto_msgTypes[38] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SetNetworkMapPersistenceRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SetNetworkMapPersistenceRequest) ProtoMessage() {} + +func (x *SetNetworkMapPersistenceRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[38] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SetNetworkMapPersistenceRequest.ProtoReflect.Descriptor instead. 
+func (*SetNetworkMapPersistenceRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{38} +} + +func (x *SetNetworkMapPersistenceRequest) GetEnabled() bool { + if x != nil { + return x.Enabled + } + return false +} + +type SetNetworkMapPersistenceResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *SetNetworkMapPersistenceResponse) Reset() { + *x = SetNetworkMapPersistenceResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_daemon_proto_msgTypes[39] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SetNetworkMapPersistenceResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SetNetworkMapPersistenceResponse) ProtoMessage() {} + +func (x *SetNetworkMapPersistenceResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[39] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SetNetworkMapPersistenceResponse.ProtoReflect.Descriptor instead. +func (*SetNetworkMapPersistenceResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{39} +} + var File_daemon_proto protoreflect.FileDescriptor var file_daemon_proto_rawDesc = []byte{ @@ -2113,7 +2573,7 @@ var file_daemon_proto_rawDesc = []byte{ 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x64, 0x75, 0x72, 0x61, 0x74, - 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xb0, 0x08, 0x0a, 0x0c, 0x4c, 0x6f, + 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xd1, 0x0a, 0x0a, 0x0c, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x65, 0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, 0x65, 0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x12, 0x26, 0x0a, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, @@ -2168,297 +2628,367 @@ var file_daemon_proto_rawDesc = []byte{ 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x18, 0x13, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x48, 0x08, 0x52, 0x10, 0x64, 0x6e, 0x73, 0x52, 0x6f, - 0x75, 0x74, 0x65, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x88, 0x01, 0x01, 0x42, 0x13, - 0x0a, 0x11, 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, - 0x6c, 0x65, 0x64, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, - 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, - 0x61, 0x72, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x42, 0x17, 0x0a, 0x15, 0x5f, 0x6f, 0x70, 0x74, 0x69, - 0x6f, 0x6e, 0x61, 0x6c, 0x50, 0x72, 0x65, 0x53, 0x68, 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, - 0x42, 0x15, 0x0a, 0x13, 0x5f, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x41, 0x75, 0x74, 0x6f, - 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x42, 0x13, 0x0a, 0x11, 0x5f, 0x73, 0x65, 0x72, 0x76, - 0x65, 0x72, 0x53, 0x53, 0x48, 0x41, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x42, 0x16, 0x0a, 0x14, - 0x5f, 
0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, - 0x73, 0x69, 0x76, 0x65, 0x42, 0x11, 0x0a, 0x0f, 0x5f, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, - 0x4d, 0x6f, 0x6e, 0x69, 0x74, 0x6f, 0x72, 0x42, 0x13, 0x0a, 0x11, 0x5f, 0x64, 0x6e, 0x73, 0x52, - 0x6f, 0x75, 0x74, 0x65, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x22, 0xb5, 0x01, 0x0a, - 0x0d, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x24, - 0x0a, 0x0d, 0x6e, 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x6e, 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, 0x4c, - 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, - 0x12, 0x28, 0x0a, 0x0f, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x55, 0x52, 0x49, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x76, 0x65, 0x72, 0x69, 0x66, - 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x12, 0x38, 0x0a, 0x17, 0x76, 0x65, + 0x75, 0x74, 0x65, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x88, 0x01, 0x01, 0x12, 0x37, + 0x0a, 0x15, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, + 0x5f, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x14, 0x20, 0x01, 0x28, 0x08, 0x48, 0x09, 0x52, + 0x13, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x52, 0x6f, + 0x75, 0x74, 0x65, 0x73, 0x88, 0x01, 0x01, 0x12, 0x37, 0x0a, 0x15, 0x64, 0x69, 0x73, 0x61, 0x62, + 0x6c, 0x65, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x5f, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, + 0x18, 0x15, 0x20, 0x01, 0x28, 0x08, 0x48, 0x0a, 0x52, 0x13, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, + 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x88, 0x01, 0x01, + 0x12, 0x24, 0x0a, 0x0b, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x5f, 0x64, 0x6e, 0x73, 0x18, + 0x16, 0x20, 0x01, 0x28, 0x08, 0x48, 0x0b, 0x52, 0x0a, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, + 0x44, 0x6e, 0x73, 0x88, 0x01, 0x01, 0x12, 0x2e, 0x0a, 0x10, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, + 0x65, 0x5f, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x18, 0x17, 0x20, 0x01, 0x28, 0x08, + 0x48, 0x0c, 0x52, 0x0f, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, + 0x61, 0x6c, 0x6c, 0x88, 0x01, 0x01, 0x42, 0x13, 0x0a, 0x11, 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, + 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x42, 0x10, 0x0a, 0x0e, 0x5f, + 0x69, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x42, 0x10, 0x0a, + 0x0e, 0x5f, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x42, + 0x17, 0x0a, 0x15, 0x5f, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x50, 0x72, 0x65, 0x53, + 0x68, 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x42, 0x15, 0x0a, 0x13, 0x5f, 0x64, 0x69, 0x73, + 0x61, 0x62, 0x6c, 0x65, 0x41, 0x75, 0x74, 0x6f, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x42, + 0x13, 0x0a, 0x11, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x53, 0x53, 0x48, 0x41, 0x6c, 0x6c, + 0x6f, 0x77, 0x65, 0x64, 0x42, 0x16, 0x0a, 0x14, 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, + 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x42, 0x11, 0x0a, 0x0f, + 0x5f, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x6f, 0x6e, 0x69, 0x74, 0x6f, 0x72, 0x42, + 0x13, 0x0a, 0x11, 0x5f, 0x64, 
0x6e, 0x73, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x6e, 0x74, 0x65, + 0x72, 0x76, 0x61, 0x6c, 0x42, 0x18, 0x0a, 0x16, 0x5f, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, + 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x42, 0x18, + 0x0a, 0x16, 0x5f, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, + 0x72, 0x5f, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x42, 0x0e, 0x0a, 0x0c, 0x5f, 0x64, 0x69, 0x73, + 0x61, 0x62, 0x6c, 0x65, 0x5f, 0x64, 0x6e, 0x73, 0x42, 0x13, 0x0a, 0x11, 0x5f, 0x64, 0x69, 0x73, + 0x61, 0x62, 0x6c, 0x65, 0x5f, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x22, 0xb5, 0x01, + 0x0a, 0x0d, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, + 0x24, 0x0a, 0x0d, 0x6e, 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x6e, 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, + 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, + 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, + 0x65, 0x12, 0x28, 0x0a, 0x0f, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, + 0x6e, 0x55, 0x52, 0x49, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x76, 0x65, 0x72, 0x69, + 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x12, 0x38, 0x0a, 0x17, 0x76, + 0x65, 0x72, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x43, 0x6f, + 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x17, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x43, 0x6f, 0x6d, - 0x70, 0x6c, 0x65, 0x74, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x17, 0x76, 0x65, 0x72, - 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x43, 0x6f, 0x6d, 0x70, - 0x6c, 0x65, 0x74, 0x65, 0x22, 0x4d, 0x0a, 0x13, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, - 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x75, - 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, - 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, - 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, - 0x61, 0x6d, 0x65, 0x22, 0x16, 0x0a, 0x14, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, - 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x0b, 0x0a, 0x09, 0x55, - 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x0c, 0x0a, 0x0a, 0x55, 0x70, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x3d, 0x0a, 0x0d, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x67, 0x65, 0x74, 0x46, 0x75, - 0x6c, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x11, 0x67, 0x65, 0x74, 0x46, 0x75, 0x6c, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, - 0x74, 0x61, 0x74, 0x75, 0x73, 0x22, 0x82, 0x01, 0x0a, 0x0e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, - 0x75, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, - 0x12, 0x32, 0x0a, 0x0a, 0x66, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x46, 0x75, - 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x0a, 0x66, 0x75, 0x6c, 0x6c, 0x53, 0x74, - 0x61, 0x74, 0x75, 0x73, 0x12, 0x24, 0x0a, 0x0d, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x56, 0x65, - 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x64, 0x61, 0x65, - 0x6d, 0x6f, 0x6e, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x0d, 0x0a, 0x0b, 0x44, 0x6f, - 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x0e, 0x0a, 0x0c, 0x44, 0x6f, 0x77, - 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x12, 0x0a, 0x10, 0x47, 0x65, 0x74, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xb9, 0x03, - 0x0a, 0x11, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x55, 0x72, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x55, 0x72, 0x6c, 0x12, 0x1e, 0x0a, 0x0a, 0x63, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x46, 0x69, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x63, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x46, 0x69, 0x6c, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x6c, 0x6f, 0x67, - 0x46, 0x69, 0x6c, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6c, 0x6f, 0x67, 0x46, - 0x69, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, 0x72, 0x65, 0x64, - 0x4b, 0x65, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, - 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x61, 0x64, 0x6d, 0x69, 0x6e, - 0x55, 0x52, 0x4c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x61, 0x64, 0x6d, 0x69, 0x6e, - 0x55, 0x52, 0x4c, 0x12, 0x24, 0x0a, 0x0d, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, - 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x69, 0x6e, 0x74, 0x65, - 0x72, 0x66, 0x61, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x77, 0x69, 0x72, - 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x07, 0x20, 0x01, 0x28, 0x03, - 0x52, 0x0d, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x12, - 0x2e, 0x0a, 0x12, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x41, 0x75, 0x74, 0x6f, 0x43, 0x6f, - 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x64, 0x69, 0x73, - 0x61, 0x62, 0x6c, 0x65, 0x41, 0x75, 0x74, 0x6f, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x12, - 0x2a, 0x0a, 0x10, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x53, 0x53, 0x48, 0x41, 0x6c, 0x6c, 0x6f, - 0x77, 0x65, 0x64, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x73, 0x65, 0x72, 0x76, 0x65, - 0x72, 0x53, 0x53, 0x48, 0x41, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x12, 0x2a, 0x0a, 0x10, 0x72, - 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, - 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, - 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, - 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x18, 0x0c, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, - 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x22, 0xda, 0x05, 0x0a, 0x09, 0x50, 0x65, - 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 
0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, - 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, - 0x1e, 0x0a, 0x0a, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x03, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x0a, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, - 0x46, 0x0a, 0x10, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, - 0x61, 0x74, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, - 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, - 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x10, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, - 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x72, 0x65, 0x6c, 0x61, 0x79, - 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x65, - 0x64, 0x12, 0x34, 0x0a, 0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, - 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, - 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x36, 0x0a, 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74, - 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, - 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, - 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, - 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, - 0x71, 0x64, 0x6e, 0x12, 0x3c, 0x0a, 0x19, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, + 0x70, 0x6c, 0x65, 0x74, 0x65, 0x22, 0x4d, 0x0a, 0x13, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, + 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, + 0x75, 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, + 0x75, 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x68, 0x6f, 0x73, 0x74, + 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x68, 0x6f, 0x73, 0x74, + 0x6e, 0x61, 0x6d, 0x65, 0x22, 0x16, 0x0a, 0x14, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, + 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x0b, 0x0a, 0x09, + 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x0c, 0x0a, 0x0a, 0x55, 0x70, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x3d, 0x0a, 0x0d, 0x53, 0x74, 0x61, 0x74, 0x75, + 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x67, 0x65, 0x74, 0x46, + 0x75, 0x6c, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x11, 0x67, 0x65, 0x74, 0x46, 0x75, 0x6c, 0x6c, 0x50, 0x65, 0x65, 0x72, + 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x22, 0x82, 0x01, 0x0a, 0x0e, 0x53, 0x74, 0x61, 0x74, 0x75, + 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, + 0x74, 0x75, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, + 0x73, 0x12, 0x32, 0x0a, 0x0a, 0x66, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x46, + 0x75, 
0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x0a, 0x66, 0x75, 0x6c, 0x6c, 0x53, + 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x24, 0x0a, 0x0d, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x56, + 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x0d, 0x0a, 0x0b, 0x44, + 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x0e, 0x0a, 0x0c, 0x44, 0x6f, + 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x12, 0x0a, 0x10, 0x47, 0x65, + 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xb9, + 0x03, 0x0a, 0x11, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x55, 0x72, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x55, 0x72, 0x6c, 0x12, 0x1e, 0x0a, 0x0a, 0x63, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x46, 0x69, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, + 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x46, 0x69, 0x6c, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x6c, 0x6f, + 0x67, 0x46, 0x69, 0x6c, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6c, 0x6f, 0x67, + 0x46, 0x69, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, 0x72, 0x65, + 0x64, 0x4b, 0x65, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x70, 0x72, 0x65, 0x53, + 0x68, 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x61, 0x64, 0x6d, 0x69, + 0x6e, 0x55, 0x52, 0x4c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x61, 0x64, 0x6d, 0x69, + 0x6e, 0x55, 0x52, 0x4c, 0x12, 0x24, 0x0a, 0x0d, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, + 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x69, 0x6e, 0x74, + 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x77, 0x69, + 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x07, 0x20, 0x01, 0x28, + 0x03, 0x52, 0x0d, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x6f, 0x72, 0x74, + 0x12, 0x2e, 0x0a, 0x12, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x41, 0x75, 0x74, 0x6f, 0x43, + 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x64, 0x69, + 0x73, 0x61, 0x62, 0x6c, 0x65, 0x41, 0x75, 0x74, 0x6f, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, + 0x12, 0x2a, 0x0a, 0x10, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x53, 0x53, 0x48, 0x41, 0x6c, 0x6c, + 0x6f, 0x77, 0x65, 0x64, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x73, 0x65, 0x72, 0x76, + 0x65, 0x72, 0x53, 0x53, 0x48, 0x41, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x12, 0x2a, 0x0a, 0x10, + 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, + 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, + 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, + 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x18, + 0x0c, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, + 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x22, 0xde, 0x05, 0x0a, 0x09, 0x50, + 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 
0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, + 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, + 0x12, 0x1e, 0x0a, 0x0a, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, + 0x12, 0x46, 0x0a, 0x10, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, + 0x64, 0x61, 0x74, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, + 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, + 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x10, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, + 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x72, 0x65, 0x6c, 0x61, + 0x79, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x72, 0x65, 0x6c, 0x61, 0x79, + 0x65, 0x64, 0x12, 0x34, 0x0a, 0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, + 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, + 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x36, 0x0a, 0x16, 0x72, 0x65, 0x6d, 0x6f, + 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, + 0x70, 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, + 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, + 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, + 0x66, 0x71, 0x64, 0x6e, 0x12, 0x3c, 0x0a, 0x19, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, + 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, + 0x74, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x19, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, + 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, + 0x6e, 0x74, 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, - 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x19, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, - 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, - 0x74, 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, - 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, - 0x0b, 0x20, 0x01, 0x28, 0x09, 0x52, 0x1a, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, - 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, - 0x74, 0x12, 0x52, 0x0a, 0x16, 0x6c, 0x61, 0x73, 0x74, 0x57, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, - 0x72, 0x64, 0x48, 0x61, 0x6e, 0x64, 0x73, 0x68, 0x61, 0x6b, 0x65, 0x18, 0x0c, 0x20, 0x01, 0x28, - 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x16, 0x6c, - 0x61, 0x73, 0x74, 0x57, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x48, 0x61, 0x6e, 0x64, - 0x73, 0x68, 0x61, 0x6b, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, 0x78, - 0x18, 0x0d, 0x20, 0x01, 0x28, 0x03, 0x52, 0x07, 0x62, 
0x79, 0x74, 0x65, 0x73, 0x52, 0x78, 0x12, - 0x18, 0x0a, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54, 0x78, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x03, - 0x52, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54, 0x78, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, - 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x0f, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, - 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, - 0x10, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a, - 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, 0x18, 0x11, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, - 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, - 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, - 0x63, 0x79, 0x12, 0x22, 0x0a, 0x0c, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x41, 0x64, 0x64, 0x72, 0x65, - 0x73, 0x73, 0x18, 0x12, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x41, - 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x22, 0xec, 0x01, 0x0a, 0x0e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, - 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, - 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, - 0x79, 0x12, 0x28, 0x0a, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, - 0x66, 0x61, 0x63, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0f, 0x6b, 0x65, 0x72, 0x6e, - 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, - 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, - 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, - 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, - 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x30, 0x0a, 0x13, 0x72, + 0x18, 0x0b, 0x20, 0x01, 0x28, 0x09, 0x52, 0x1a, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, + 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, + 0x6e, 0x74, 0x12, 0x52, 0x0a, 0x16, 0x6c, 0x61, 0x73, 0x74, 0x57, 0x69, 0x72, 0x65, 0x67, 0x75, + 0x61, 0x72, 0x64, 0x48, 0x61, 0x6e, 0x64, 0x73, 0x68, 0x61, 0x6b, 0x65, 0x18, 0x0c, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x16, + 0x6c, 0x61, 0x73, 0x74, 0x57, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x48, 0x61, 0x6e, + 0x64, 0x73, 0x68, 0x61, 0x6b, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, + 0x78, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x03, 0x52, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, 0x78, + 0x12, 0x18, 0x0a, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54, 0x78, 0x18, 0x0e, 0x20, 0x01, 0x28, + 0x03, 0x52, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54, 0x78, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, + 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x0f, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, + 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x6e, 0x65, 0x74, 
0x77, 0x6f, 0x72, + 0x6b, 0x73, 0x18, 0x10, 0x20, 0x03, 0x28, 0x09, 0x52, 0x08, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, + 0x6b, 0x73, 0x12, 0x33, 0x0a, 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, 0x18, 0x11, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x07, + 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, 0x12, 0x22, 0x0a, 0x0c, 0x72, 0x65, 0x6c, 0x61, 0x79, + 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x12, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x72, + 0x65, 0x6c, 0x61, 0x79, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x22, 0xf0, 0x01, 0x0a, 0x0e, + 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, + 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, + 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, + 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, + 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, + 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, + 0x66, 0x71, 0x64, 0x6e, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, + 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, + 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, + 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, + 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, - 0x76, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, - 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x12, 0x16, 0x0a, - 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, - 0x6f, 0x75, 0x74, 0x65, 0x73, 0x22, 0x53, 0x0a, 0x0b, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, - 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, - 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x57, 0x0a, 0x0f, 0x4d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, + 0x76, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x18, 0x07, + 0x20, 0x03, 0x28, 0x09, 0x52, 0x08, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x22, 0x53, + 0x0a, 0x0b, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 
0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, - 0x72, 0x6f, 0x72, 0x22, 0x52, 0x0a, 0x0a, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x49, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, - 0x55, 0x52, 0x49, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, - 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x72, 0x0a, 0x0c, 0x4e, 0x53, 0x47, 0x72, 0x6f, - 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, - 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, - 0x73, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, - 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x65, - 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, - 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x04, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0xd2, 0x02, 0x0a, 0x0a, - 0x46, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x41, 0x0a, 0x0f, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0f, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x35, 0x0a, - 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x69, 0x67, 0x6e, - 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x12, 0x3e, 0x0a, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, - 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x52, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x12, 0x27, 0x0a, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x18, 0x04, 0x20, - 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x50, 0x65, 0x65, - 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2a, 0x0a, - 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x52, 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x12, 0x35, 0x0a, 0x0b, 0x64, 0x6e, 0x73, - 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x52, 0x0a, 0x64, 0x6e, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, - 0x22, 0x13, 0x0a, 0x11, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 
0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, - 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x06, 0x72, - 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, - 0x65, 0x73, 0x22, 0x5b, 0x0a, 0x13, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, - 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x72, 0x6f, 0x75, - 0x74, 0x65, 0x49, 0x44, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x08, 0x72, 0x6f, 0x75, - 0x74, 0x65, 0x49, 0x44, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x70, 0x70, 0x65, 0x6e, 0x64, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x61, 0x70, 0x70, 0x65, 0x6e, 0x64, 0x12, 0x10, 0x0a, - 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, - 0x16, 0x0a, 0x14, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1a, 0x0a, 0x06, 0x49, 0x50, 0x4c, 0x69, 0x73, - 0x74, 0x12, 0x10, 0x0a, 0x03, 0x69, 0x70, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x03, - 0x69, 0x70, 0x73, 0x22, 0xf9, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, - 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, - 0x07, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, - 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x65, 0x6c, 0x65, 0x63, - 0x74, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x73, 0x65, 0x6c, 0x65, 0x63, - 0x74, 0x65, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x04, - 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x40, 0x0a, - 0x0b, 0x72, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, 0x18, 0x05, 0x20, 0x03, - 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x6f, 0x75, 0x74, - 0x65, 0x2e, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, 0x45, 0x6e, 0x74, - 0x72, 0x79, 0x52, 0x0b, 0x72, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, 0x1a, - 0x4e, 0x0a, 0x10, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, 0x45, 0x6e, - 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x24, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x49, 0x50, - 0x4c, 0x69, 0x73, 0x74, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, - 0x6a, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x6e, 0x6f, 0x6e, 0x79, 0x6d, 0x69, - 0x7a, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x6e, 0x6f, 0x6e, 0x79, 0x6d, - 0x69, 0x7a, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x1e, 0x0a, 0x0a, 0x73, - 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, - 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x22, 0x29, 0x0a, 0x13, 0x44, - 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, - 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x22, 0x14, 0x0a, 0x12, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, - 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x3d, 0x0a, 0x13, - 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c, - 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x3c, 0x0a, 0x12, 0x53, + 0x72, 0x6f, 0x72, 0x22, 0x57, 0x0a, 0x0f, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, + 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, + 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x52, 0x0a, 0x0a, + 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, + 0x49, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x49, 0x12, 0x1c, 0x0a, 0x09, + 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, + 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, + 0x22, 0x72, 0x0a, 0x0c, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, + 0x12, 0x18, 0x0a, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, + 0x09, 0x52, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, + 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, + 0x61, 0x69, 0x6e, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x14, + 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, + 0x72, 0x72, 0x6f, 0x72, 0x22, 0xd2, 0x02, 0x0a, 0x0a, 0x46, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, + 0x74, 0x75, 0x73, 0x12, 0x41, 0x0a, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x35, 0x0a, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, + 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, + 0x52, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x3e, 0x0a, + 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 
0x2e, 0x4c, + 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0e, 0x6c, + 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x27, 0x0a, + 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, + 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2a, 0x0a, 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, + 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, 0x72, 0x65, 0x6c, 0x61, + 0x79, 0x73, 0x12, 0x35, 0x0a, 0x0b, 0x64, 0x6e, 0x73, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, + 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, + 0x2e, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0a, 0x64, + 0x6e, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x22, 0x15, 0x0a, 0x13, 0x4c, 0x69, 0x73, + 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x22, 0x3f, 0x0a, 0x14, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x27, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, + 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0f, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, + 0x73, 0x22, 0x61, 0x0a, 0x15, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, + 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1e, 0x0a, 0x0a, 0x6e, 0x65, + 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x49, 0x44, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, + 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x49, 0x44, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x70, + 0x70, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x61, 0x70, 0x70, 0x65, + 0x6e, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x18, 0x0a, 0x16, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, + 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1a, + 0x0a, 0x06, 0x49, 0x50, 0x4c, 0x69, 0x73, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x69, 0x70, 0x73, 0x18, + 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x03, 0x69, 0x70, 0x73, 0x22, 0xf9, 0x01, 0x0a, 0x07, 0x4e, + 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x14, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x1a, 0x0a, 0x08, + 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, + 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x73, 0x12, 0x42, 0x0a, 0x0b, 0x72, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, + 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x20, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, + 0x2e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x2e, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, + 0x64, 0x49, 0x50, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0b, 0x72, 0x65, 0x73, 0x6f, 0x6c, + 0x76, 0x65, 
0x64, 0x49, 0x50, 0x73, 0x1a, 0x4e, 0x0a, 0x10, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, + 0x65, 0x64, 0x49, 0x50, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, + 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x24, 0x0a, 0x05, + 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x49, 0x50, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x05, 0x76, 0x61, 0x6c, + 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x6a, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, + 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, + 0x61, 0x6e, 0x6f, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x09, 0x61, 0x6e, 0x6f, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, + 0x61, 0x74, 0x75, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, + 0x75, 0x73, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, + 0x66, 0x6f, 0x22, 0x29, 0x0a, 0x13, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, + 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, + 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x22, 0x14, 0x0a, + 0x12, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x22, 0x3d, 0x0a, 0x13, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, + 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, + 0x76, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, + 0x65, 0x6c, 0x22, 0x3c, 0x0a, 0x12, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, + 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, + 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, + 0x2e, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, + 0x22, 0x15, 0x0a, 0x13, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1b, 0x0a, 0x05, 0x53, 0x74, 0x61, 0x74, 0x65, + 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, + 0x6e, 0x61, 0x6d, 0x65, 0x22, 0x13, 0x0a, 0x11, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73, + 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, + 0x25, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, + 0x0d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, + 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x44, 0x0a, 0x11, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, + 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, + 0x74, 0x61, 0x74, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x09, 0x73, 0x74, 0x61, 0x74, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, + 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 
0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x3b, 0x0a, 0x12, + 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x12, 0x25, 0x0a, 0x0e, 0x63, 0x6c, 0x65, 0x61, 0x6e, 0x65, 0x64, 0x5f, 0x73, 0x74, + 0x61, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, 0x63, 0x6c, 0x65, 0x61, + 0x6e, 0x65, 0x64, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x45, 0x0a, 0x12, 0x44, 0x65, 0x6c, + 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, + 0x1d, 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x74, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x74, 0x61, 0x74, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, + 0x0a, 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, + 0x22, 0x3c, 0x0a, 0x13, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x0e, 0x64, 0x65, 0x6c, 0x65, 0x74, + 0x65, 0x64, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, + 0x0d, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x3b, + 0x0a, 0x1f, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, + 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x22, 0x0a, 0x20, 0x53, + 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, + 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2a, + 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, + 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, 0x49, + 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, 0x09, + 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, 0x52, + 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, 0x0a, + 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41, 0x43, + 0x45, 0x10, 0x07, 0x32, 0x93, 0x09, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, + 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, + 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, + 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, + 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, + 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 
0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, + 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, + 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, + 0x0c, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1b, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, + 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x51, 0x0a, 0x0e, 0x53, 0x65, + 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1d, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, + 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, + 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x53, 0x0a, + 0x10, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, + 0x73, 0x12, 0x1d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, + 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x1a, 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, + 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, + 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, + 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, + 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, + 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, + 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 
0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, + 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, - 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x15, 0x0a, 0x13, 0x53, 0x65, 0x74, - 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, - 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, - 0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, - 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, - 0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, - 0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41, - 0x43, 0x45, 0x10, 0x07, 0x32, 0xb8, 0x06, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, - 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, - 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, - 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, - 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, - 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, - 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, - 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, - 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, - 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, - 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, - 0x6e, 0x66, 
0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, - 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, - 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, - 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, - 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x00, 0x12, 0x4d, 0x0a, 0x0e, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, - 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, - 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, - 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, - 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, - 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, - 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x47, - 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, - 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, - 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, - 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, - 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, - 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, - 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x33, + 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, + 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, + 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x12, 0x19, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 
0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x43, 0x6c, 0x65, 0x61, 0x6e, + 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, + 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, + 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, + 0x0a, 0x0b, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x1a, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, + 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x6f, 0x0a, 0x18, 0x53, 0x65, 0x74, 0x4e, + 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, + 0x65, 0x6e, 0x63, 0x65, 0x12, 0x27, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, + 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, + 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x28, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, + 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -2474,90 +3004,108 @@ func file_daemon_proto_rawDescGZIP() []byte { } var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 32) +var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 41) var file_daemon_proto_goTypes = []interface{}{ - (LogLevel)(0), // 0: daemon.LogLevel - (*LoginRequest)(nil), // 1: daemon.LoginRequest - (*LoginResponse)(nil), // 2: daemon.LoginResponse - (*WaitSSOLoginRequest)(nil), // 3: daemon.WaitSSOLoginRequest - (*WaitSSOLoginResponse)(nil), // 4: daemon.WaitSSOLoginResponse - (*UpRequest)(nil), // 5: daemon.UpRequest - (*UpResponse)(nil), // 6: daemon.UpResponse - (*StatusRequest)(nil), // 7: daemon.StatusRequest - (*StatusResponse)(nil), // 8: daemon.StatusResponse - (*DownRequest)(nil), // 9: daemon.DownRequest - (*DownResponse)(nil), // 10: daemon.DownResponse - (*GetConfigRequest)(nil), // 11: daemon.GetConfigRequest - (*GetConfigResponse)(nil), // 12: daemon.GetConfigResponse - (*PeerState)(nil), // 13: daemon.PeerState - (*LocalPeerState)(nil), // 14: daemon.LocalPeerState - (*SignalState)(nil), // 15: daemon.SignalState - (*ManagementState)(nil), // 16: daemon.ManagementState - (*RelayState)(nil), // 17: daemon.RelayState - (*NSGroupState)(nil), // 18: daemon.NSGroupState - (*FullStatus)(nil), // 19: daemon.FullStatus - (*ListRoutesRequest)(nil), // 20: daemon.ListRoutesRequest - (*ListRoutesResponse)(nil), // 21: daemon.ListRoutesResponse - (*SelectRoutesRequest)(nil), // 22: daemon.SelectRoutesRequest - (*SelectRoutesResponse)(nil), // 23: daemon.SelectRoutesResponse - (*IPList)(nil), // 24: daemon.IPList - (*Route)(nil), // 25: daemon.Route - (*DebugBundleRequest)(nil), // 26: daemon.DebugBundleRequest - (*DebugBundleResponse)(nil), // 27: daemon.DebugBundleResponse - 
(*GetLogLevelRequest)(nil), // 28: daemon.GetLogLevelRequest - (*GetLogLevelResponse)(nil), // 29: daemon.GetLogLevelResponse - (*SetLogLevelRequest)(nil), // 30: daemon.SetLogLevelRequest - (*SetLogLevelResponse)(nil), // 31: daemon.SetLogLevelResponse - nil, // 32: daemon.Route.ResolvedIPsEntry - (*durationpb.Duration)(nil), // 33: google.protobuf.Duration - (*timestamppb.Timestamp)(nil), // 34: google.protobuf.Timestamp + (LogLevel)(0), // 0: daemon.LogLevel + (*LoginRequest)(nil), // 1: daemon.LoginRequest + (*LoginResponse)(nil), // 2: daemon.LoginResponse + (*WaitSSOLoginRequest)(nil), // 3: daemon.WaitSSOLoginRequest + (*WaitSSOLoginResponse)(nil), // 4: daemon.WaitSSOLoginResponse + (*UpRequest)(nil), // 5: daemon.UpRequest + (*UpResponse)(nil), // 6: daemon.UpResponse + (*StatusRequest)(nil), // 7: daemon.StatusRequest + (*StatusResponse)(nil), // 8: daemon.StatusResponse + (*DownRequest)(nil), // 9: daemon.DownRequest + (*DownResponse)(nil), // 10: daemon.DownResponse + (*GetConfigRequest)(nil), // 11: daemon.GetConfigRequest + (*GetConfigResponse)(nil), // 12: daemon.GetConfigResponse + (*PeerState)(nil), // 13: daemon.PeerState + (*LocalPeerState)(nil), // 14: daemon.LocalPeerState + (*SignalState)(nil), // 15: daemon.SignalState + (*ManagementState)(nil), // 16: daemon.ManagementState + (*RelayState)(nil), // 17: daemon.RelayState + (*NSGroupState)(nil), // 18: daemon.NSGroupState + (*FullStatus)(nil), // 19: daemon.FullStatus + (*ListNetworksRequest)(nil), // 20: daemon.ListNetworksRequest + (*ListNetworksResponse)(nil), // 21: daemon.ListNetworksResponse + (*SelectNetworksRequest)(nil), // 22: daemon.SelectNetworksRequest + (*SelectNetworksResponse)(nil), // 23: daemon.SelectNetworksResponse + (*IPList)(nil), // 24: daemon.IPList + (*Network)(nil), // 25: daemon.Network + (*DebugBundleRequest)(nil), // 26: daemon.DebugBundleRequest + (*DebugBundleResponse)(nil), // 27: daemon.DebugBundleResponse + (*GetLogLevelRequest)(nil), // 28: daemon.GetLogLevelRequest + (*GetLogLevelResponse)(nil), // 29: daemon.GetLogLevelResponse + (*SetLogLevelRequest)(nil), // 30: daemon.SetLogLevelRequest + (*SetLogLevelResponse)(nil), // 31: daemon.SetLogLevelResponse + (*State)(nil), // 32: daemon.State + (*ListStatesRequest)(nil), // 33: daemon.ListStatesRequest + (*ListStatesResponse)(nil), // 34: daemon.ListStatesResponse + (*CleanStateRequest)(nil), // 35: daemon.CleanStateRequest + (*CleanStateResponse)(nil), // 36: daemon.CleanStateResponse + (*DeleteStateRequest)(nil), // 37: daemon.DeleteStateRequest + (*DeleteStateResponse)(nil), // 38: daemon.DeleteStateResponse + (*SetNetworkMapPersistenceRequest)(nil), // 39: daemon.SetNetworkMapPersistenceRequest + (*SetNetworkMapPersistenceResponse)(nil), // 40: daemon.SetNetworkMapPersistenceResponse + nil, // 41: daemon.Network.ResolvedIPsEntry + (*durationpb.Duration)(nil), // 42: google.protobuf.Duration + (*timestamppb.Timestamp)(nil), // 43: google.protobuf.Timestamp } var file_daemon_proto_depIdxs = []int32{ - 33, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 42, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration 19, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus - 34, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp - 34, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp - 33, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration + 43, // 2: 
daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp + 43, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp + 42, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration 16, // 5: daemon.FullStatus.managementState:type_name -> daemon.ManagementState 15, // 6: daemon.FullStatus.signalState:type_name -> daemon.SignalState 14, // 7: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState 13, // 8: daemon.FullStatus.peers:type_name -> daemon.PeerState 17, // 9: daemon.FullStatus.relays:type_name -> daemon.RelayState 18, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState - 25, // 11: daemon.ListRoutesResponse.routes:type_name -> daemon.Route - 32, // 12: daemon.Route.resolvedIPs:type_name -> daemon.Route.ResolvedIPsEntry + 25, // 11: daemon.ListNetworksResponse.routes:type_name -> daemon.Network + 41, // 12: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry 0, // 13: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel 0, // 14: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel - 24, // 15: daemon.Route.ResolvedIPsEntry.value:type_name -> daemon.IPList - 1, // 16: daemon.DaemonService.Login:input_type -> daemon.LoginRequest - 3, // 17: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest - 5, // 18: daemon.DaemonService.Up:input_type -> daemon.UpRequest - 7, // 19: daemon.DaemonService.Status:input_type -> daemon.StatusRequest - 9, // 20: daemon.DaemonService.Down:input_type -> daemon.DownRequest - 11, // 21: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest - 20, // 22: daemon.DaemonService.ListRoutes:input_type -> daemon.ListRoutesRequest - 22, // 23: daemon.DaemonService.SelectRoutes:input_type -> daemon.SelectRoutesRequest - 22, // 24: daemon.DaemonService.DeselectRoutes:input_type -> daemon.SelectRoutesRequest - 26, // 25: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest - 28, // 26: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest - 30, // 27: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest - 2, // 28: daemon.DaemonService.Login:output_type -> daemon.LoginResponse - 4, // 29: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse - 6, // 30: daemon.DaemonService.Up:output_type -> daemon.UpResponse - 8, // 31: daemon.DaemonService.Status:output_type -> daemon.StatusResponse - 10, // 32: daemon.DaemonService.Down:output_type -> daemon.DownResponse - 12, // 33: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse - 21, // 34: daemon.DaemonService.ListRoutes:output_type -> daemon.ListRoutesResponse - 23, // 35: daemon.DaemonService.SelectRoutes:output_type -> daemon.SelectRoutesResponse - 23, // 36: daemon.DaemonService.DeselectRoutes:output_type -> daemon.SelectRoutesResponse - 27, // 37: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse - 29, // 38: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse - 31, // 39: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse - 28, // [28:40] is the sub-list for method output_type - 16, // [16:28] is the sub-list for method input_type - 16, // [16:16] is the sub-list for extension type_name - 16, // [16:16] is the sub-list for extension extendee - 0, // [0:16] is the sub-list for field type_name + 32, // 15: daemon.ListStatesResponse.states:type_name -> daemon.State + 24, // 16: 
daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList + 1, // 17: daemon.DaemonService.Login:input_type -> daemon.LoginRequest + 3, // 18: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest + 5, // 19: daemon.DaemonService.Up:input_type -> daemon.UpRequest + 7, // 20: daemon.DaemonService.Status:input_type -> daemon.StatusRequest + 9, // 21: daemon.DaemonService.Down:input_type -> daemon.DownRequest + 11, // 22: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest + 20, // 23: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest + 22, // 24: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest + 22, // 25: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest + 26, // 26: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest + 28, // 27: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest + 30, // 28: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest + 33, // 29: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest + 35, // 30: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest + 37, // 31: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest + 39, // 32: daemon.DaemonService.SetNetworkMapPersistence:input_type -> daemon.SetNetworkMapPersistenceRequest + 2, // 33: daemon.DaemonService.Login:output_type -> daemon.LoginResponse + 4, // 34: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse + 6, // 35: daemon.DaemonService.Up:output_type -> daemon.UpResponse + 8, // 36: daemon.DaemonService.Status:output_type -> daemon.StatusResponse + 10, // 37: daemon.DaemonService.Down:output_type -> daemon.DownResponse + 12, // 38: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse + 21, // 39: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse + 23, // 40: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse + 23, // 41: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse + 27, // 42: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse + 29, // 43: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse + 31, // 44: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse + 34, // 45: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse + 36, // 46: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse + 38, // 47: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse + 40, // 48: daemon.DaemonService.SetNetworkMapPersistence:output_type -> daemon.SetNetworkMapPersistenceResponse + 33, // [33:49] is the sub-list for method output_type + 17, // [17:33] is the sub-list for method input_type + 17, // [17:17] is the sub-list for extension type_name + 17, // [17:17] is the sub-list for extension extendee + 0, // [0:17] is the sub-list for field type_name } func init() { file_daemon_proto_init() } @@ -2795,7 +3343,7 @@ func file_daemon_proto_init() { } } file_daemon_proto_msgTypes[19].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ListRoutesRequest); i { + switch v := v.(*ListNetworksRequest); i { case 0: return &v.state case 1: @@ -2807,7 +3355,7 @@ func file_daemon_proto_init() { } } file_daemon_proto_msgTypes[20].Exporter = func(v interface{}, i int) interface{} { - switch v := 
v.(*ListRoutesResponse); i { + switch v := v.(*ListNetworksResponse); i { case 0: return &v.state case 1: @@ -2819,7 +3367,7 @@ func file_daemon_proto_init() { } } file_daemon_proto_msgTypes[21].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SelectRoutesRequest); i { + switch v := v.(*SelectNetworksRequest); i { case 0: return &v.state case 1: @@ -2831,7 +3379,7 @@ func file_daemon_proto_init() { } } file_daemon_proto_msgTypes[22].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SelectRoutesResponse); i { + switch v := v.(*SelectNetworksResponse); i { case 0: return &v.state case 1: @@ -2855,7 +3403,7 @@ func file_daemon_proto_init() { } } file_daemon_proto_msgTypes[24].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Route); i { + switch v := v.(*Network); i { case 0: return &v.state case 1: @@ -2938,6 +3486,114 @@ func file_daemon_proto_init() { return nil } } + file_daemon_proto_msgTypes[31].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*State); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_daemon_proto_msgTypes[32].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ListStatesRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_daemon_proto_msgTypes[33].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ListStatesResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_daemon_proto_msgTypes[34].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*CleanStateRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_daemon_proto_msgTypes[35].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*CleanStateResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_daemon_proto_msgTypes[36].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*DeleteStateRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_daemon_proto_msgTypes[37].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*DeleteStateResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_daemon_proto_msgTypes[38].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SetNetworkMapPersistenceRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_daemon_proto_msgTypes[39].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SetNetworkMapPersistenceResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } file_daemon_proto_msgTypes[0].OneofWrappers = []interface{}{} type x struct{} @@ -2946,7 +3602,7 @@ func file_daemon_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_daemon_proto_rawDesc, NumEnums: 1, - NumMessages: 32, + NumMessages: 41, NumExtensions: 0, 
NumServices: 1, }, diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 384bc0e62..ad3a4bc1a 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -28,14 +28,14 @@ service DaemonService { // GetConfig of the daemon. rpc GetConfig(GetConfigRequest) returns (GetConfigResponse) {} - // List available network routes - rpc ListRoutes(ListRoutesRequest) returns (ListRoutesResponse) {} + // List available networks + rpc ListNetworks(ListNetworksRequest) returns (ListNetworksResponse) {} // Select specific routes - rpc SelectRoutes(SelectRoutesRequest) returns (SelectRoutesResponse) {} + rpc SelectNetworks(SelectNetworksRequest) returns (SelectNetworksResponse) {} // Deselect specific routes - rpc DeselectRoutes(SelectRoutesRequest) returns (SelectRoutesResponse) {} + rpc DeselectNetworks(SelectNetworksRequest) returns (SelectNetworksResponse) {} // DebugBundle creates a debug bundle rpc DebugBundle(DebugBundleRequest) returns (DebugBundleResponse) {} @@ -45,7 +45,20 @@ service DaemonService { // SetLogLevel sets the log level of the daemon rpc SetLogLevel(SetLogLevelRequest) returns (SetLogLevelResponse) {} -}; + + // List all states + rpc ListStates(ListStatesRequest) returns (ListStatesResponse) {} + + // Clean specific state or all states + rpc CleanState(CleanStateRequest) returns (CleanStateResponse) {} + + // Delete specific state or all states + rpc DeleteState(DeleteStateRequest) returns (DeleteStateResponse) {} + + // SetNetworkMapPersistence enables or disables network map persistence + rpc SetNetworkMapPersistence(SetNetworkMapPersistenceRequest) returns (SetNetworkMapPersistenceResponse) {} +} + message LoginRequest { // setupKey wiretrustee setup key. @@ -94,6 +107,11 @@ message LoginRequest { optional bool networkMonitor = 18; optional google.protobuf.Duration dnsRouteInterval = 19; + + optional bool disable_client_routes = 20; + optional bool disable_server_routes = 21; + optional bool disable_dns = 22; + optional bool disable_firewall = 23; } message LoginResponse { @@ -177,7 +195,7 @@ message PeerState { int64 bytesRx = 13; int64 bytesTx = 14; bool rosenpassEnabled = 15; - repeated string routes = 16; + repeated string networks = 16; google.protobuf.Duration latency = 17; string relayAddress = 18; } @@ -190,7 +208,7 @@ message LocalPeerState { string fqdn = 4; bool rosenpassEnabled = 5; bool rosenpassPermissive = 6; - repeated string routes = 7; + repeated string networks = 7; } // SignalState contains the latest state of a signal connection @@ -231,20 +249,20 @@ message FullStatus { repeated NSGroupState dns_servers = 6; } -message ListRoutesRequest { +message ListNetworksRequest { } -message ListRoutesResponse { - repeated Route routes = 1; +message ListNetworksResponse { + repeated Network routes = 1; } -message SelectRoutesRequest { - repeated string routeIDs = 1; +message SelectNetworksRequest { + repeated string networkIDs = 1; bool append = 2; bool all = 3; } -message SelectRoutesResponse { +message SelectNetworksResponse { } message IPList { @@ -252,9 +270,9 @@ message IPList { } -message Route { +message Network { string ID = 1; - string network = 2; + string range = 2; bool selected = 3; repeated string domains = 4; map<string, IPList> resolvedIPs = 5; @@ -293,4 +311,46 @@ message SetLogLevelRequest { } message SetLogLevelResponse { -} \ No newline at end of file +} + +// State represents a daemon state entry +message State { + string name = 1; +} + +// ListStatesRequest is empty as it requires no parameters +message ListStatesRequest {} + +// 
ListStatesResponse contains a list of states +message ListStatesResponse { + repeated State states = 1; +} + +// CleanStateRequest for cleaning states +message CleanStateRequest { + string state_name = 1; + bool all = 2; +} + +// CleanStateResponse contains the result of the clean operation +message CleanStateResponse { + int32 cleaned_states = 1; +} + +// DeleteStateRequest for deleting states +message DeleteStateRequest { + string state_name = 1; + bool all = 2; +} + +// DeleteStateResponse contains the result of the delete operation +message DeleteStateResponse { + int32 deleted_states = 1; +} + + +message SetNetworkMapPersistenceRequest { + bool enabled = 1; +} + +message SetNetworkMapPersistenceResponse {} diff --git a/client/proto/daemon_grpc.pb.go b/client/proto/daemon_grpc.pb.go index e0bc117e5..39424aee9 100644 --- a/client/proto/daemon_grpc.pb.go +++ b/client/proto/daemon_grpc.pb.go @@ -31,18 +31,26 @@ type DaemonServiceClient interface { Down(ctx context.Context, in *DownRequest, opts ...grpc.CallOption) (*DownResponse, error) // GetConfig of the daemon. GetConfig(ctx context.Context, in *GetConfigRequest, opts ...grpc.CallOption) (*GetConfigResponse, error) - // List available network routes - ListRoutes(ctx context.Context, in *ListRoutesRequest, opts ...grpc.CallOption) (*ListRoutesResponse, error) + // List available networks + ListNetworks(ctx context.Context, in *ListNetworksRequest, opts ...grpc.CallOption) (*ListNetworksResponse, error) // Select specific routes - SelectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error) + SelectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error) // Deselect specific routes - DeselectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error) + DeselectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error) // DebugBundle creates a debug bundle DebugBundle(ctx context.Context, in *DebugBundleRequest, opts ...grpc.CallOption) (*DebugBundleResponse, error) // GetLogLevel gets the log level of the daemon GetLogLevel(ctx context.Context, in *GetLogLevelRequest, opts ...grpc.CallOption) (*GetLogLevelResponse, error) // SetLogLevel sets the log level of the daemon SetLogLevel(ctx context.Context, in *SetLogLevelRequest, opts ...grpc.CallOption) (*SetLogLevelResponse, error) + // List all states + ListStates(ctx context.Context, in *ListStatesRequest, opts ...grpc.CallOption) (*ListStatesResponse, error) + // Clean specific state or all states + CleanState(ctx context.Context, in *CleanStateRequest, opts ...grpc.CallOption) (*CleanStateResponse, error) + // Delete specific state or all states + DeleteState(ctx context.Context, in *DeleteStateRequest, opts ...grpc.CallOption) (*DeleteStateResponse, error) + // SetNetworkMapPersistence enables or disables network map persistence + SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error) } type daemonServiceClient struct { @@ -107,27 +115,27 @@ func (c *daemonServiceClient) GetConfig(ctx context.Context, in *GetConfigReques return out, nil } -func (c *daemonServiceClient) ListRoutes(ctx context.Context, in *ListRoutesRequest, opts ...grpc.CallOption) (*ListRoutesResponse, error) { - out := new(ListRoutesResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListRoutes", in, 
out, opts...) +func (c *daemonServiceClient) ListNetworks(ctx context.Context, in *ListNetworksRequest, opts ...grpc.CallOption) (*ListNetworksResponse, error) { + out := new(ListNetworksResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListNetworks", in, out, opts...) if err != nil { return nil, err } return out, nil } -func (c *daemonServiceClient) SelectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error) { - out := new(SelectRoutesResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/SelectRoutes", in, out, opts...) +func (c *daemonServiceClient) SelectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error) { + out := new(SelectNetworksResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/SelectNetworks", in, out, opts...) if err != nil { return nil, err } return out, nil } -func (c *daemonServiceClient) DeselectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error) { - out := new(SelectRoutesResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeselectRoutes", in, out, opts...) +func (c *daemonServiceClient) DeselectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error) { + out := new(SelectNetworksResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeselectNetworks", in, out, opts...) if err != nil { return nil, err } @@ -161,6 +169,42 @@ func (c *daemonServiceClient) SetLogLevel(ctx context.Context, in *SetLogLevelRe return out, nil } +func (c *daemonServiceClient) ListStates(ctx context.Context, in *ListStatesRequest, opts ...grpc.CallOption) (*ListStatesResponse, error) { + out := new(ListStatesResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListStates", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *daemonServiceClient) CleanState(ctx context.Context, in *CleanStateRequest, opts ...grpc.CallOption) (*CleanStateResponse, error) { + out := new(CleanStateResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/CleanState", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *daemonServiceClient) DeleteState(ctx context.Context, in *DeleteStateRequest, opts ...grpc.CallOption) (*DeleteStateResponse, error) { + out := new(DeleteStateResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeleteState", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *daemonServiceClient) SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error) { + out := new(SetNetworkMapPersistenceResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetNetworkMapPersistence", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + // DaemonServiceServer is the server API for DaemonService service. // All implementations must embed UnimplementedDaemonServiceServer // for forward compatibility @@ -178,18 +222,26 @@ type DaemonServiceServer interface { Down(context.Context, *DownRequest) (*DownResponse, error) // GetConfig of the daemon. 
GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error) - // List available network routes - ListRoutes(context.Context, *ListRoutesRequest) (*ListRoutesResponse, error) + // List available networks + ListNetworks(context.Context, *ListNetworksRequest) (*ListNetworksResponse, error) // Select specific routes - SelectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error) + SelectNetworks(context.Context, *SelectNetworksRequest) (*SelectNetworksResponse, error) // Deselect specific routes - DeselectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error) + DeselectNetworks(context.Context, *SelectNetworksRequest) (*SelectNetworksResponse, error) // DebugBundle creates a debug bundle DebugBundle(context.Context, *DebugBundleRequest) (*DebugBundleResponse, error) // GetLogLevel gets the log level of the daemon GetLogLevel(context.Context, *GetLogLevelRequest) (*GetLogLevelResponse, error) // SetLogLevel sets the log level of the daemon SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error) + // List all states + ListStates(context.Context, *ListStatesRequest) (*ListStatesResponse, error) + // Clean specific state or all states + CleanState(context.Context, *CleanStateRequest) (*CleanStateResponse, error) + // Delete specific state or all states + DeleteState(context.Context, *DeleteStateRequest) (*DeleteStateResponse, error) + // SetNetworkMapPersistence enables or disables network map persistence + SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error) mustEmbedUnimplementedDaemonServiceServer() } @@ -215,14 +267,14 @@ func (UnimplementedDaemonServiceServer) Down(context.Context, *DownRequest) (*Do func (UnimplementedDaemonServiceServer) GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method GetConfig not implemented") } -func (UnimplementedDaemonServiceServer) ListRoutes(context.Context, *ListRoutesRequest) (*ListRoutesResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method ListRoutes not implemented") +func (UnimplementedDaemonServiceServer) ListNetworks(context.Context, *ListNetworksRequest) (*ListNetworksResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method ListNetworks not implemented") } -func (UnimplementedDaemonServiceServer) SelectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method SelectRoutes not implemented") +func (UnimplementedDaemonServiceServer) SelectNetworks(context.Context, *SelectNetworksRequest) (*SelectNetworksResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method SelectNetworks not implemented") } -func (UnimplementedDaemonServiceServer) DeselectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method DeselectRoutes not implemented") +func (UnimplementedDaemonServiceServer) DeselectNetworks(context.Context, *SelectNetworksRequest) (*SelectNetworksResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method DeselectNetworks not implemented") } func (UnimplementedDaemonServiceServer) DebugBundle(context.Context, *DebugBundleRequest) (*DebugBundleResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method DebugBundle not implemented") @@ -233,6 +285,18 @@ func (UnimplementedDaemonServiceServer) 
GetLogLevel(context.Context, *GetLogLeve func (UnimplementedDaemonServiceServer) SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method SetLogLevel not implemented") } +func (UnimplementedDaemonServiceServer) ListStates(context.Context, *ListStatesRequest) (*ListStatesResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method ListStates not implemented") +} +func (UnimplementedDaemonServiceServer) CleanState(context.Context, *CleanStateRequest) (*CleanStateResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method CleanState not implemented") +} +func (UnimplementedDaemonServiceServer) DeleteState(context.Context, *DeleteStateRequest) (*DeleteStateResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method DeleteState not implemented") +} +func (UnimplementedDaemonServiceServer) SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method SetNetworkMapPersistence not implemented") +} func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {} // UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service. @@ -354,56 +418,56 @@ func _DaemonService_GetConfig_Handler(srv interface{}, ctx context.Context, dec return interceptor(ctx, in, info, handler) } -func _DaemonService_ListRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(ListRoutesRequest) +func _DaemonService_ListNetworks_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ListNetworksRequest) if err := dec(in); err != nil { return nil, err } if interceptor == nil { - return srv.(DaemonServiceServer).ListRoutes(ctx, in) + return srv.(DaemonServiceServer).ListNetworks(ctx, in) } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/ListRoutes", + FullMethod: "/daemon.DaemonService/ListNetworks", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(DaemonServiceServer).ListRoutes(ctx, req.(*ListRoutesRequest)) + return srv.(DaemonServiceServer).ListNetworks(ctx, req.(*ListNetworksRequest)) } return interceptor(ctx, in, info, handler) } -func _DaemonService_SelectRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(SelectRoutesRequest) +func _DaemonService_SelectNetworks_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(SelectNetworksRequest) if err := dec(in); err != nil { return nil, err } if interceptor == nil { - return srv.(DaemonServiceServer).SelectRoutes(ctx, in) + return srv.(DaemonServiceServer).SelectNetworks(ctx, in) } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/SelectRoutes", + FullMethod: "/daemon.DaemonService/SelectNetworks", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(DaemonServiceServer).SelectRoutes(ctx, req.(*SelectRoutesRequest)) + return srv.(DaemonServiceServer).SelectNetworks(ctx, req.(*SelectNetworksRequest)) } return interceptor(ctx, in, info, handler) } 
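// Illustrative sketch, not part of this change: once the daemon exposes the new
// state RPCs added above, an external caller could drive them through the
// generated client as below. The socket address, dial options, and ctx are
// assumptions for the example, not NetBird defaults.
//
//	conn, err := grpc.Dial("unix:///var/run/netbird.sock",
//		grpc.WithTransportCredentials(insecure.NewCredentials()))
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer conn.Close()
//	client := proto.NewDaemonServiceClient(conn)
//
//	// Enumerate known state names.
//	states, err := client.ListStates(ctx, &proto.ListStatesRequest{})
//
//	// Clean all states; CleanStateResponse.CleanedStates reports how many were cleaned.
//	resp, err := client.CleanState(ctx, &proto.CleanStateRequest{All: true})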
-func _DaemonService_DeselectRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(SelectRoutesRequest) +func _DaemonService_DeselectNetworks_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(SelectNetworksRequest) if err := dec(in); err != nil { return nil, err } if interceptor == nil { - return srv.(DaemonServiceServer).DeselectRoutes(ctx, in) + return srv.(DaemonServiceServer).DeselectNetworks(ctx, in) } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/DeselectRoutes", + FullMethod: "/daemon.DaemonService/DeselectNetworks", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(DaemonServiceServer).DeselectRoutes(ctx, req.(*SelectRoutesRequest)) + return srv.(DaemonServiceServer).DeselectNetworks(ctx, req.(*SelectNetworksRequest)) } return interceptor(ctx, in, info, handler) } @@ -462,6 +526,78 @@ func _DaemonService_SetLogLevel_Handler(srv interface{}, ctx context.Context, de return interceptor(ctx, in, info, handler) } +func _DaemonService_ListStates_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ListStatesRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).ListStates(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/ListStates", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).ListStates(ctx, req.(*ListStatesRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _DaemonService_CleanState_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(CleanStateRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).CleanState(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/CleanState", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).CleanState(ctx, req.(*CleanStateRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _DaemonService_DeleteState_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(DeleteStateRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).DeleteState(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/DeleteState", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).DeleteState(ctx, req.(*DeleteStateRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _DaemonService_SetNetworkMapPersistence_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(SetNetworkMapPersistenceRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return 
srv.(DaemonServiceServer).SetNetworkMapPersistence(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/SetNetworkMapPersistence", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).SetNetworkMapPersistence(ctx, req.(*SetNetworkMapPersistenceRequest)) + } + return interceptor(ctx, in, info, handler) +} + // DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -494,16 +630,16 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{ Handler: _DaemonService_GetConfig_Handler, }, { - MethodName: "ListRoutes", - Handler: _DaemonService_ListRoutes_Handler, + MethodName: "ListNetworks", + Handler: _DaemonService_ListNetworks_Handler, }, { - MethodName: "SelectRoutes", - Handler: _DaemonService_SelectRoutes_Handler, + MethodName: "SelectNetworks", + Handler: _DaemonService_SelectNetworks_Handler, }, { - MethodName: "DeselectRoutes", - Handler: _DaemonService_DeselectRoutes_Handler, + MethodName: "DeselectNetworks", + Handler: _DaemonService_DeselectNetworks_Handler, }, { MethodName: "DebugBundle", @@ -517,6 +653,22 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{ MethodName: "SetLogLevel", Handler: _DaemonService_SetLogLevel_Handler, }, + { + MethodName: "ListStates", + Handler: _DaemonService_ListStates_Handler, + }, + { + MethodName: "CleanState", + Handler: _DaemonService_CleanState_Handler, + }, + { + MethodName: "DeleteState", + Handler: _DaemonService_DeleteState_Handler, + }, + { + MethodName: "SetNetworkMapPersistence", + Handler: _DaemonService_SetNetworkMapPersistence_Handler, + }, }, Streams: []grpc.StreamDesc{}, Metadata: "daemon.proto", diff --git a/client/server/debug.go b/client/server/debug.go index 5ed43293b..3c4967b4e 100644 --- a/client/server/debug.go +++ b/client/server/debug.go @@ -5,32 +5,46 @@ package server import ( "archive/zip" "bufio" + "bytes" "context" + "encoding/json" + "errors" "fmt" "io" + "io/fs" "net" "net/netip" "os" + "path/filepath" "sort" "strings" "time" log "github.com/sirupsen/logrus" + "google.golang.org/protobuf/encoding/protojson" "github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" + "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/proto" + mgmProto "github.com/netbirdio/netbird/management/proto" ) const readmeContent = `Netbird debug bundle This debug bundle contains the following files: status.txt: Anonymized status information of the NetBird client. -client.log: Most recent, anonymized log file of the NetBird client. +client.log: Most recent, anonymized client log file of the NetBird client. +netbird.err: Most recent, anonymized stderr log file of the NetBird client. +netbird.out: Most recent, anonymized stdout log file of the NetBird client. routes.txt: Anonymized system routes, if --system-info flag was provided. interfaces.txt: Anonymized network interface information, if --system-info flag was provided. +iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided. +nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided. config.txt: Anonymized configuration information of the NetBird client. 
+network_map.json: Anonymized network map containing peer configurations, routes, DNS settings, and firewall rules. +state.json: Anonymized client state dump containing netbird states. Anonymization Process @@ -50,8 +64,32 @@ Domains All domain names (except for the netbird domains) are replaced with randomly generated strings ending in ".domain". Anonymized domains are consistent across all files in the bundle. Reoccuring domain names are replaced with the same anonymized domain. +Network Map +The network_map.json file contains the following anonymized information: +- Peer configurations (addresses, FQDNs, DNS settings) +- Remote and offline peer information (allowed IPs, FQDNs) +- Routes (network ranges, associated domains) +- DNS configuration (nameservers, domains, custom zones) +- Firewall rules (peer IPs, source/destination ranges) + +SSH keys in the network map are replaced with a placeholder value. All IP addresses and domains in the network map follow the same anonymization rules as described above. + +State File +The state.json file contains anonymized internal state information of the NetBird client, including: +- DNS settings and configuration +- Firewall rules +- Exclusion routes +- Route selection +- Other internal states that may be present + +The state file follows the same anonymization rules as other files: +- IP addresses (both individual and CIDR ranges) are anonymized while preserving their structure +- Domain names are consistently anonymized +- Technical identifiers and non-sensitive data remain unchanged + Routes For anonymized routes, the IP addresses are replaced as described above. The prefix length remains unchanged. Note that for prefixes, the anonymized IP might not be a network address, but the prefix length is still correct. + Network Interfaces The interfaces.txt file contains information about network interfaces, including: - Interface name @@ -70,17 +108,37 @@ The config.txt file contains anonymized configuration information of the NetBird - CustomDNSAddress Other non-sensitive configuration options are included without anonymization. + +Firewall Rules (Linux only) +The bundle includes two separate firewall rule files: + +iptables.txt: +- Complete iptables ruleset with packet counters using 'iptables -v -n -L' +- Includes all tables (filter, nat, mangle, raw, security) +- Shows packet and byte counters for each rule +- All IP addresses are anonymized +- Chain names, table names, and other non-sensitive information remain unchanged + +nftables.txt: +- Complete nftables ruleset obtained via 'nft -a list ruleset' +- Includes rule handle numbers and packet counters +- All tables, chains, and rules are included +- Shows packet and byte counters for each rule +- All IP addresses are anonymized +- Chain names, table names, and other non-sensitive information remain unchanged ` +const ( + clientLogFile = "client.log" + errorLogFile = "netbird.err" + stdoutLogFile = "netbird.out" +) + // DebugBundle creates a debug bundle and returns the location. 
func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) { s.mutex.Lock() defer s.mutex.Unlock() - if s.logFile == "console" { - return nil, fmt.Errorf("log file is set to console, cannot create debug bundle") - } - bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip") if err != nil { return nil, fmt.Errorf("create zip file: %w", err) @@ -119,21 +177,25 @@ func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleReques seedFromStatus(anonymizer, &status) if err := s.addConfig(req, anonymizer, archive); err != nil { - return fmt.Errorf("add config: %w", err) + log.Errorf("Failed to add config to debug bundle: %v", err) } if req.GetSystemInfo() { - if err := s.addRoutes(req, anonymizer, archive); err != nil { - return fmt.Errorf("add routes: %w", err) - } - - if err := s.addInterfaces(req, anonymizer, archive); err != nil { - return fmt.Errorf("add interfaces: %w", err) - } + s.addSystemInfo(req, anonymizer, archive) } - if err := s.addLogfile(req, anonymizer, archive); err != nil { - return fmt.Errorf("add log file: %w", err) + if err := s.addNetworkMap(req, anonymizer, archive); err != nil { + return fmt.Errorf("add network map: %w", err) + } + + if err := s.addStateFile(req, anonymizer, archive); err != nil { + log.Errorf("Failed to add state file to debug bundle: %v", err) + } + + if s.logFile != "console" { + if err := s.addLogfile(req, anonymizer, archive); err != nil { + return fmt.Errorf("add log file: %w", err) + } } if err := archive.Close(); err != nil { @@ -142,6 +204,20 @@ func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleReques return nil } +func (s *Server) addSystemInfo(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) { + if err := s.addRoutes(req, anonymizer, archive); err != nil { + log.Errorf("Failed to add routes to debug bundle: %v", err) + } + + if err := s.addInterfaces(req, anonymizer, archive); err != nil { + log.Errorf("Failed to add interfaces to debug bundle: %v", err) + } + + if err := s.addFirewallRules(req, anonymizer, archive); err != nil { + log.Errorf("Failed to add firewall rules to debug bundle: %v", err) + } +} + func (s *Server) addReadme(req *proto.DebugBundleRequest, archive *zip.Writer) error { if req.GetAnonymize() { readmeReader := strings.NewReader(readmeContent) @@ -220,15 +296,16 @@ func (s *Server) addCommonConfigFields(configContent *strings.Builder) { } func (s *Server) addRoutes(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error { - if routes, err := systemops.GetRoutesFromTable(); err != nil { - log.Errorf("Failed to get routes: %v", err) - } else { - // TODO: get routes including nexthop - routesContent := formatRoutes(routes, req.GetAnonymize(), anonymizer) - routesReader := strings.NewReader(routesContent) - if err := addFileToZip(archive, routesReader, "routes.txt"); err != nil { - return fmt.Errorf("add routes file to zip: %w", err) - } + routes, err := systemops.GetRoutesFromTable() + if err != nil { + return fmt.Errorf("get routes: %w", err) + } + + // TODO: get routes including nexthop + routesContent := formatRoutes(routes, req.GetAnonymize(), anonymizer) + routesReader := strings.NewReader(routesContent) + if err := addFileToZip(archive, routesReader, "routes.txt"); err != nil { + return fmt.Errorf("add routes file to zip: %w", err) } return nil } @@ -248,14 +325,106 @@ func (s *Server) addInterfaces(req *proto.DebugBundleRequest, anonymizer 
*anonym return nil } -func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) (err error) { - logFile, err := os.Open(s.logFile) +func (s *Server) addNetworkMap(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error { + networkMap, err := s.getLatestNetworkMap() if err != nil { - return fmt.Errorf("open log file: %w", err) + // Skip if network map is not available, but log it + log.Debugf("skipping empty network map in debug bundle: %v", err) + return nil + } + + if req.GetAnonymize() { + if err := anonymizeNetworkMap(networkMap, anonymizer); err != nil { + return fmt.Errorf("anonymize network map: %w", err) + } + } + + options := protojson.MarshalOptions{ + EmitUnpopulated: true, + UseProtoNames: true, + Indent: " ", + AllowPartial: true, + } + + jsonBytes, err := options.Marshal(networkMap) + if err != nil { + return fmt.Errorf("generate json: %w", err) + } + + if err := addFileToZip(archive, bytes.NewReader(jsonBytes), "network_map.json"); err != nil { + return fmt.Errorf("add network map to zip: %w", err) + } + + return nil +} + +func (s *Server) addStateFile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error { + path := statemanager.GetDefaultStatePath() + if path == "" { + return nil + } + + data, err := os.ReadFile(path) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + return nil + } + return fmt.Errorf("read state file: %w", err) + } + + if req.GetAnonymize() { + var rawStates map[string]json.RawMessage + if err := json.Unmarshal(data, &rawStates); err != nil { + return fmt.Errorf("unmarshal states: %w", err) + } + + if err := anonymizeStateFile(&rawStates, anonymizer); err != nil { + return fmt.Errorf("anonymize state file: %w", err) + } + + bs, err := json.MarshalIndent(rawStates, "", " ") + if err != nil { + return fmt.Errorf("marshal states: %w", err) + } + data = bs + } + + if err := addFileToZip(archive, bytes.NewReader(data), "state.json"); err != nil { + return fmt.Errorf("add state file to zip: %w", err) + } + + return nil +} + +func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error { + logDir := filepath.Dir(s.logFile) + + if err := s.addSingleLogfile(s.logFile, clientLogFile, req, anonymizer, archive); err != nil { + return fmt.Errorf("add client log file to zip: %w", err) + } + + errLogPath := filepath.Join(logDir, errorLogFile) + if err := s.addSingleLogfile(errLogPath, errorLogFile, req, anonymizer, archive); err != nil { + log.Warnf("Failed to add %s to zip: %v", errorLogFile, err) + } + + stdoutLogPath := filepath.Join(logDir, stdoutLogFile) + if err := s.addSingleLogfile(stdoutLogPath, stdoutLogFile, req, anonymizer, archive); err != nil { + log.Warnf("Failed to add %s to zip: %v", stdoutLogFile, err) + } + + return nil +} + +// addSingleLogfile adds a single log file to the archive +func (s *Server) addSingleLogfile(logPath, targetName string, req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error { + logFile, err := os.Open(logPath) + if err != nil { + return fmt.Errorf("open log file %s: %w", targetName, err) } defer func() { if err := logFile.Close(); err != nil { - log.Errorf("Failed to close original log file: %v", err) + log.Errorf("Failed to close log file %s: %v", targetName, err) } }() @@ -264,45 +433,55 @@ func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize var writer *io.PipeWriter 
logReader, writer = io.Pipe() - go s.anonymize(logFile, writer, anonymizer) + go anonymizeLog(logFile, writer, anonymizer) } else { logReader = logFile } - if err := addFileToZip(archive, logReader, "client.log"); err != nil { - return fmt.Errorf("add log file to zip: %w", err) + + if err := addFileToZip(archive, logReader, targetName); err != nil { + return fmt.Errorf("add %s to zip: %w", targetName, err) } return nil } -func (s *Server) anonymize(reader io.Reader, writer *io.PipeWriter, anonymizer *anonymize.Anonymizer) { - defer func() { - // always nil - _ = writer.Close() - }() +// getLatestNetworkMap returns the latest network map from the engine if network map persistence is enabled +func (s *Server) getLatestNetworkMap() (*mgmProto.NetworkMap, error) { + if s.connectClient == nil { + return nil, errors.New("connect client is not initialized") + } - scanner := bufio.NewScanner(reader) - for scanner.Scan() { - line := anonymizer.AnonymizeString(scanner.Text()) - if _, err := writer.Write([]byte(line + "\n")); err != nil { - writer.CloseWithError(fmt.Errorf("anonymize write: %w", err)) - return - } + engine := s.connectClient.Engine() + if engine == nil { + return nil, errors.New("engine is not initialized") } - if err := scanner.Err(); err != nil { - writer.CloseWithError(fmt.Errorf("anonymize scan: %w", err)) - return + + networkMap, err := engine.GetLatestNetworkMap() + if err != nil { + return nil, fmt.Errorf("get latest network map: %w", err) } + + if networkMap == nil { + return nil, errors.New("network map is not available") + } + + return networkMap, nil } // GetLogLevel gets the current logging level for the server. func (s *Server) GetLogLevel(_ context.Context, _ *proto.GetLogLevelRequest) (*proto.GetLogLevelResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + level := ParseLogLevel(log.GetLevel().String()) return &proto.GetLogLevelResponse{Level: level}, nil } // SetLogLevel sets the logging level for the server. func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (*proto.SetLogLevelResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + level, err := log.ParseLevel(req.Level.String()) if err != nil { return nil, fmt.Errorf("invalid log level: %w", err) @@ -313,6 +492,20 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) ( return &proto.SetLogLevelResponse{}, nil } +// SetNetworkMapPersistence sets the network map persistence for the server. 
+func (s *Server) SetNetworkMapPersistence(_ context.Context, req *proto.SetNetworkMapPersistenceRequest) (*proto.SetNetworkMapPersistenceResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + enabled := req.GetEnabled() + s.persistNetworkMap = enabled + if s.connectClient != nil { + s.connectClient.SetNetworkMapPersistence(enabled) + } + + return &proto.SetNetworkMapPersistenceResponse{}, nil +} + func addFileToZip(archive *zip.Writer, reader io.Reader, filename string) error { header := &zip.FileHeader{ Name: filename, @@ -458,6 +651,26 @@ func formatInterfaces(interfaces []net.Interface, anonymize bool, anonymizer *an return builder.String() } +func anonymizeLog(reader io.Reader, writer *io.PipeWriter, anonymizer *anonymize.Anonymizer) { + defer func() { + // always nil + _ = writer.Close() + }() + + scanner := bufio.NewScanner(reader) + for scanner.Scan() { + line := anonymizer.AnonymizeString(scanner.Text()) + if _, err := writer.Write([]byte(line + "\n")); err != nil { + writer.CloseWithError(fmt.Errorf("anonymize write: %w", err)) + return + } + } + if err := scanner.Err(); err != nil { + writer.CloseWithError(fmt.Errorf("anonymize scan: %w", err)) + return + } +} + func anonymizeNATExternalIPs(ips []string, anonymizer *anonymize.Anonymizer) []string { anonymizedIPs := make([]string, len(ips)) for i, ip := range ips { @@ -484,3 +697,248 @@ func anonymizeNATExternalIPs(ips []string, anonymizer *anonymize.Anonymizer) []s } return anonymizedIPs } + +func anonymizeNetworkMap(networkMap *mgmProto.NetworkMap, anonymizer *anonymize.Anonymizer) error { + if networkMap.PeerConfig != nil { + anonymizePeerConfig(networkMap.PeerConfig, anonymizer) + } + + for _, peer := range networkMap.RemotePeers { + anonymizeRemotePeer(peer, anonymizer) + } + + for _, peer := range networkMap.OfflinePeers { + anonymizeRemotePeer(peer, anonymizer) + } + + for _, r := range networkMap.Routes { + anonymizeRoute(r, anonymizer) + } + + if networkMap.DNSConfig != nil { + anonymizeDNSConfig(networkMap.DNSConfig, anonymizer) + } + + for _, rule := range networkMap.FirewallRules { + anonymizeFirewallRule(rule, anonymizer) + } + + for _, rule := range networkMap.RoutesFirewallRules { + anonymizeRouteFirewallRule(rule, anonymizer) + } + + return nil +} + +func anonymizePeerConfig(config *mgmProto.PeerConfig, anonymizer *anonymize.Anonymizer) { + if config == nil { + return + } + + if addr, err := netip.ParseAddr(config.Address); err == nil { + config.Address = anonymizer.AnonymizeIP(addr).String() + } + + if config.SshConfig != nil && len(config.SshConfig.SshPubKey) > 0 { + config.SshConfig.SshPubKey = []byte("ssh-placeholder-key") + } + + config.Dns = anonymizer.AnonymizeString(config.Dns) + config.Fqdn = anonymizer.AnonymizeDomain(config.Fqdn) +} + +func anonymizeRemotePeer(peer *mgmProto.RemotePeerConfig, anonymizer *anonymize.Anonymizer) { + if peer == nil { + return + } + + for i, ip := range peer.AllowedIps { + // Try to parse as prefix first (CIDR) + if prefix, err := netip.ParsePrefix(ip); err == nil { + anonIP := anonymizer.AnonymizeIP(prefix.Addr()) + peer.AllowedIps[i] = fmt.Sprintf("%s/%d", anonIP, prefix.Bits()) + } else if addr, err := netip.ParseAddr(ip); err == nil { + peer.AllowedIps[i] = anonymizer.AnonymizeIP(addr).String() + } + } + + peer.Fqdn = anonymizer.AnonymizeDomain(peer.Fqdn) + + if peer.SshConfig != nil && len(peer.SshConfig.SshPubKey) > 0 { + peer.SshConfig.SshPubKey = []byte("ssh-placeholder-key") + } +} + +func anonymizeRoute(route *mgmProto.Route, anonymizer 
*anonymize.Anonymizer) { + if route == nil { + return + } + + if prefix, err := netip.ParsePrefix(route.Network); err == nil { + anonIP := anonymizer.AnonymizeIP(prefix.Addr()) + route.Network = fmt.Sprintf("%s/%d", anonIP, prefix.Bits()) + } + + for i, domain := range route.Domains { + route.Domains[i] = anonymizer.AnonymizeDomain(domain) + } + + route.NetID = anonymizer.AnonymizeString(route.NetID) +} + +func anonymizeDNSConfig(config *mgmProto.DNSConfig, anonymizer *anonymize.Anonymizer) { + if config == nil { + return + } + + anonymizeNameServerGroups(config.NameServerGroups, anonymizer) + anonymizeCustomZones(config.CustomZones, anonymizer) +} + +func anonymizeNameServerGroups(groups []*mgmProto.NameServerGroup, anonymizer *anonymize.Anonymizer) { + for _, group := range groups { + anonymizeServers(group.NameServers, anonymizer) + anonymizeDomains(group.Domains, anonymizer) + } +} + +func anonymizeServers(servers []*mgmProto.NameServer, anonymizer *anonymize.Anonymizer) { + for _, server := range servers { + if addr, err := netip.ParseAddr(server.IP); err == nil { + server.IP = anonymizer.AnonymizeIP(addr).String() + } + } +} + +func anonymizeDomains(domains []string, anonymizer *anonymize.Anonymizer) { + for i, domain := range domains { + domains[i] = anonymizer.AnonymizeDomain(domain) + } +} + +func anonymizeCustomZones(zones []*mgmProto.CustomZone, anonymizer *anonymize.Anonymizer) { + for _, zone := range zones { + zone.Domain = anonymizer.AnonymizeDomain(zone.Domain) + anonymizeRecords(zone.Records, anonymizer) + } +} + +func anonymizeRecords(records []*mgmProto.SimpleRecord, anonymizer *anonymize.Anonymizer) { + for _, record := range records { + record.Name = anonymizer.AnonymizeDomain(record.Name) + anonymizeRData(record, anonymizer) + } +} + +func anonymizeRData(record *mgmProto.SimpleRecord, anonymizer *anonymize.Anonymizer) { + switch record.Type { + case 1, 28: // A or AAAA record + if addr, err := netip.ParseAddr(record.RData); err == nil { + record.RData = anonymizer.AnonymizeIP(addr).String() + } + default: + record.RData = anonymizer.AnonymizeString(record.RData) + } +} + +func anonymizeFirewallRule(rule *mgmProto.FirewallRule, anonymizer *anonymize.Anonymizer) { + if rule == nil { + return + } + + if addr, err := netip.ParseAddr(rule.PeerIP); err == nil { + rule.PeerIP = anonymizer.AnonymizeIP(addr).String() + } +} + +func anonymizeRouteFirewallRule(rule *mgmProto.RouteFirewallRule, anonymizer *anonymize.Anonymizer) { + if rule == nil { + return + } + + for i, sourceRange := range rule.SourceRanges { + if prefix, err := netip.ParsePrefix(sourceRange); err == nil { + anonIP := anonymizer.AnonymizeIP(prefix.Addr()) + rule.SourceRanges[i] = fmt.Sprintf("%s/%d", anonIP, prefix.Bits()) + } + } + + if prefix, err := netip.ParsePrefix(rule.Destination); err == nil { + anonIP := anonymizer.AnonymizeIP(prefix.Addr()) + rule.Destination = fmt.Sprintf("%s/%d", anonIP, prefix.Bits()) + } +} + +func anonymizeStateFile(rawStates *map[string]json.RawMessage, anonymizer *anonymize.Anonymizer) error { + for name, rawState := range *rawStates { + if string(rawState) == "null" { + continue + } + + var state map[string]any + if err := json.Unmarshal(rawState, &state); err != nil { + return fmt.Errorf("unmarshal state %s: %w", name, err) + } + + state = anonymizeValue(state, anonymizer).(map[string]any) + + bs, err := json.Marshal(state) + if err != nil { + return fmt.Errorf("marshal state %s: %w", name, err) + } + + (*rawStates)[name] = bs + } + + return nil +} + +func 
anonymizeValue(value any, anonymizer *anonymize.Anonymizer) any { + switch v := value.(type) { + case string: + return anonymizeString(v, anonymizer) + case map[string]any: + return anonymizeMap(v, anonymizer) + case []any: + return anonymizeSlice(v, anonymizer) + } + return value +} + +func anonymizeString(v string, anonymizer *anonymize.Anonymizer) string { + if prefix, err := netip.ParsePrefix(v); err == nil { + anonIP := anonymizer.AnonymizeIP(prefix.Addr()) + return fmt.Sprintf("%s/%d", anonIP, prefix.Bits()) + } + if ip, err := netip.ParseAddr(v); err == nil { + return anonymizer.AnonymizeIP(ip).String() + } + return anonymizer.AnonymizeString(v) +} + +func anonymizeMap(v map[string]any, anonymizer *anonymize.Anonymizer) map[string]any { + result := make(map[string]any, len(v)) + for key, val := range v { + newKey := anonymizeMapKey(key, anonymizer) + result[newKey] = anonymizeValue(val, anonymizer) + } + return result +} + +func anonymizeMapKey(key string, anonymizer *anonymize.Anonymizer) string { + if prefix, err := netip.ParsePrefix(key); err == nil { + anonIP := anonymizer.AnonymizeIP(prefix.Addr()) + return fmt.Sprintf("%s/%d", anonIP, prefix.Bits()) + } + if ip, err := netip.ParseAddr(key); err == nil { + return anonymizer.AnonymizeIP(ip).String() + } + return key +} + +func anonymizeSlice(v []any, anonymizer *anonymize.Anonymizer) []any { + for i, val := range v { + v[i] = anonymizeValue(val, anonymizer) + } + return v +} diff --git a/client/server/debug_linux.go b/client/server/debug_linux.go new file mode 100644 index 000000000..60bc40561 --- /dev/null +++ b/client/server/debug_linux.go @@ -0,0 +1,693 @@ +//go:build linux && !android + +package server + +import ( + "archive/zip" + "bytes" + "encoding/binary" + "fmt" + "os/exec" + "sort" + "strings" + + "github.com/google/nftables" + "github.com/google/nftables/expr" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/anonymize" + "github.com/netbirdio/netbird/client/proto" +) + +// addFirewallRules collects and adds firewall rules to the archive +func (s *Server) addFirewallRules(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error { + log.Info("Collecting firewall rules") + // Collect and add iptables rules + iptablesRules, err := collectIPTablesRules() + if err != nil { + log.Warnf("Failed to collect iptables rules: %v", err) + } else { + if req.GetAnonymize() { + iptablesRules = anonymizer.AnonymizeString(iptablesRules) + } + if err := addFileToZip(archive, strings.NewReader(iptablesRules), "iptables.txt"); err != nil { + log.Warnf("Failed to add iptables rules to bundle: %v", err) + } + } + + // Collect and add nftables rules + nftablesRules, err := collectNFTablesRules() + if err != nil { + log.Warnf("Failed to collect nftables rules: %v", err) + } else { + if req.GetAnonymize() { + nftablesRules = anonymizer.AnonymizeString(nftablesRules) + } + if err := addFileToZip(archive, strings.NewReader(nftablesRules), "nftables.txt"); err != nil { + log.Warnf("Failed to add nftables rules to bundle: %v", err) + } + } + + return nil +} + +// collectIPTablesRules collects rules using both iptables-save and verbose listing +func collectIPTablesRules() (string, error) { + var builder strings.Builder + + // First try using iptables-save + saveOutput, err := collectIPTablesSave() + if err != nil { + log.Warnf("Failed to collect iptables rules using iptables-save: %v", err) + } else { + builder.WriteString("=== iptables-save output ===\n") + 
builder.WriteString(saveOutput) + builder.WriteString("\n") + } + + // Then get verbose statistics for each table + builder.WriteString("=== iptables -v -n -L output ===\n") + + // Get list of tables + tables := []string{"filter", "nat", "mangle", "raw", "security"} + + for _, table := range tables { + builder.WriteString(fmt.Sprintf("*%s\n", table)) + + // Get verbose statistics for the entire table + stats, err := getTableStatistics(table) + if err != nil { + log.Warnf("Failed to get statistics for table %s: %v", table, err) + continue + } + builder.WriteString(stats) + builder.WriteString("\n") + } + + return builder.String(), nil +} + +// collectIPTablesSave uses iptables-save to get rule definitions +func collectIPTablesSave() (string, error) { + cmd := exec.Command("iptables-save") + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + return "", fmt.Errorf("execute iptables-save: %w (stderr: %s)", err, stderr.String()) + } + + rules := stdout.String() + if strings.TrimSpace(rules) == "" { + return "", fmt.Errorf("no iptables rules found") + } + + return rules, nil +} + +// getTableStatistics gets verbose statistics for an entire table using iptables command +func getTableStatistics(table string) (string, error) { + cmd := exec.Command("iptables", "-v", "-n", "-L", "-t", table) + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + return "", fmt.Errorf("execute iptables -v -n -L: %w (stderr: %s)", err, stderr.String()) + } + + return stdout.String(), nil +} + +// collectNFTablesRules attempts to collect nftables rules using either nft command or netlink +func collectNFTablesRules() (string, error) { + // First try using nft command + rules, err := collectNFTablesFromCommand() + if err != nil { + log.Debugf("Failed to collect nftables rules using nft command: %v, falling back to netlink", err) + // Fall back to netlink + rules, err = collectNFTablesFromNetlink() + if err != nil { + return "", fmt.Errorf("collect nftables rules using both nft and netlink failed: %w", err) + } + } + return rules, nil +} + +// collectNFTablesFromCommand attempts to collect rules using nft command +func collectNFTablesFromCommand() (string, error) { + cmd := exec.Command("nft", "-a", "list", "ruleset") + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + return "", fmt.Errorf("execute nft list ruleset: %w (stderr: %s)", err, stderr.String()) + } + + rules := stdout.String() + if strings.TrimSpace(rules) == "" { + return "", fmt.Errorf("no nftables rules found") + } + + return rules, nil +} + +// collectNFTablesFromNetlink collects rules using netlink library +func collectNFTablesFromNetlink() (string, error) { + conn, err := nftables.New() + if err != nil { + return "", fmt.Errorf("create nftables connection: %w", err) + } + + tables, err := conn.ListTables() + if err != nil { + return "", fmt.Errorf("list tables: %w", err) + } + + sortTables(tables) + return formatTables(conn, tables), nil +} + +func formatTables(conn *nftables.Conn, tables []*nftables.Table) string { + var builder strings.Builder + + for _, table := range tables { + builder.WriteString(fmt.Sprintf("table %s %s {\n", formatFamily(table.Family), table.Name)) + + chains, err := getAndSortTableChains(conn, table) + if err != nil { + log.Warnf("Failed to list chains for table %s: %v", table.Name, err) + continue + } + + // Format chains + for _, 
chain := range chains { + formatChain(conn, table, chain, &builder) + } + + // Format sets + if sets, err := conn.GetSets(table); err != nil { + log.Warnf("Failed to get sets for table %s: %v", table.Name, err) + } else if len(sets) > 0 { + builder.WriteString("\n") + for _, set := range sets { + builder.WriteString(formatSet(conn, set)) + } + } + + builder.WriteString("}\n") + } + + return builder.String() +} + +func getAndSortTableChains(conn *nftables.Conn, table *nftables.Table) ([]*nftables.Chain, error) { + chains, err := conn.ListChains() + if err != nil { + return nil, err + } + + var tableChains []*nftables.Chain + for _, chain := range chains { + if chain.Table.Name == table.Name && chain.Table.Family == table.Family { + tableChains = append(tableChains, chain) + } + } + + sort.Slice(tableChains, func(i, j int) bool { + return tableChains[i].Name < tableChains[j].Name + }) + + return tableChains, nil +} + +func formatChain(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, builder *strings.Builder) { + builder.WriteString(fmt.Sprintf("\tchain %s {\n", chain.Name)) + + if chain.Type != "" { + var policy string + if chain.Policy != nil { + policy = fmt.Sprintf("; policy %s", formatPolicy(*chain.Policy)) + } + builder.WriteString(fmt.Sprintf("\t\ttype %s hook %s priority %d%s\n", + formatChainType(chain.Type), + formatChainHook(chain.Hooknum), + chain.Priority, + policy)) + } + + rules, err := conn.GetRules(table, chain) + if err != nil { + log.Warnf("Failed to get rules for chain %s: %v", chain.Name, err) + } else { + sort.Slice(rules, func(i, j int) bool { + return rules[i].Position < rules[j].Position + }) + for _, rule := range rules { + builder.WriteString(formatRule(rule)) + } + } + + builder.WriteString("\t}\n") +} + +func sortTables(tables []*nftables.Table) { + sort.Slice(tables, func(i, j int) bool { + if tables[i].Family != tables[j].Family { + return tables[i].Family < tables[j].Family + } + return tables[i].Name < tables[j].Name + }) +} + +func formatFamily(family nftables.TableFamily) string { + switch family { + case nftables.TableFamilyIPv4: + return "ip" + case nftables.TableFamilyIPv6: + return "ip6" + case nftables.TableFamilyINet: + return "inet" + case nftables.TableFamilyARP: + return "arp" + case nftables.TableFamilyBridge: + return "bridge" + case nftables.TableFamilyNetdev: + return "netdev" + default: + return fmt.Sprintf("family-%d", family) + } +} + +func formatChainType(typ nftables.ChainType) string { + switch typ { + case nftables.ChainTypeFilter: + return "filter" + case nftables.ChainTypeNAT: + return "nat" + case nftables.ChainTypeRoute: + return "route" + default: + return fmt.Sprintf("type-%s", typ) + } +} + +func formatChainHook(hook *nftables.ChainHook) string { + if hook == nil { + return "none" + } + switch *hook { + case *nftables.ChainHookPrerouting: + return "prerouting" + case *nftables.ChainHookInput: + return "input" + case *nftables.ChainHookForward: + return "forward" + case *nftables.ChainHookOutput: + return "output" + case *nftables.ChainHookPostrouting: + return "postrouting" + default: + return fmt.Sprintf("hook-%d", *hook) + } +} + +func formatPolicy(policy nftables.ChainPolicy) string { + switch policy { + case nftables.ChainPolicyDrop: + return "drop" + case nftables.ChainPolicyAccept: + return "accept" + default: + return fmt.Sprintf("policy-%d", policy) + } +} + +func formatRule(rule *nftables.Rule) string { + var builder strings.Builder + builder.WriteString("\t\t") + + for i := 0; i < len(rule.Exprs); i++ { 
+ if i > 0 { + builder.WriteString(" ") + } + i = formatExprSequence(&builder, rule.Exprs, i) + } + + builder.WriteString("\n") + return builder.String() +} + +func formatExprSequence(builder *strings.Builder, exprs []expr.Any, i int) int { + curr := exprs[i] + + // Handle Meta + Cmp sequence + if meta, ok := curr.(*expr.Meta); ok && i+1 < len(exprs) { + if cmp, ok := exprs[i+1].(*expr.Cmp); ok { + if formatted := formatMetaWithCmp(meta, cmp); formatted != "" { + builder.WriteString(formatted) + return i + 1 + } + } + } + + // Handle Payload + Cmp sequence + if payload, ok := curr.(*expr.Payload); ok && i+1 < len(exprs) { + if cmp, ok := exprs[i+1].(*expr.Cmp); ok { + builder.WriteString(formatPayloadWithCmp(payload, cmp)) + return i + 1 + } + } + + builder.WriteString(formatExpr(curr)) + return i +} + +func formatMetaWithCmp(meta *expr.Meta, cmp *expr.Cmp) string { + switch meta.Key { + case expr.MetaKeyIIFNAME: + name := strings.TrimRight(string(cmp.Data), "\x00") + return fmt.Sprintf("iifname %s %q", formatCmpOp(cmp.Op), name) + case expr.MetaKeyOIFNAME: + name := strings.TrimRight(string(cmp.Data), "\x00") + return fmt.Sprintf("oifname %s %q", formatCmpOp(cmp.Op), name) + case expr.MetaKeyMARK: + if len(cmp.Data) == 4 { + val := binary.BigEndian.Uint32(cmp.Data) + return fmt.Sprintf("meta mark %s 0x%x", formatCmpOp(cmp.Op), val) + } + } + return "" +} + +func formatPayloadWithCmp(p *expr.Payload, cmp *expr.Cmp) string { + if p.Base == expr.PayloadBaseNetworkHeader { + switch p.Offset { + case 12: // Source IP + if p.Len == 4 { + return fmt.Sprintf("ip saddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data)) + } else if p.Len == 2 { + return fmt.Sprintf("ip saddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data)) + } + case 16: // Destination IP + if p.Len == 4 { + return fmt.Sprintf("ip daddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data)) + } else if p.Len == 2 { + return fmt.Sprintf("ip daddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data)) + } + } + } + return fmt.Sprintf("%d reg%d [%d:%d] %s %v", + p.Base, p.DestRegister, p.Offset, p.Len, + formatCmpOp(cmp.Op), cmp.Data) +} + +func formatIPBytes(data []byte) string { + if len(data) == 4 { + return fmt.Sprintf("%d.%d.%d.%d", data[0], data[1], data[2], data[3]) + } else if len(data) == 2 { + return fmt.Sprintf("%d.%d.0.0/16", data[0], data[1]) + } + return fmt.Sprintf("%v", data) +} + +func formatCmpOp(op expr.CmpOp) string { + switch op { + case expr.CmpOpEq: + return "==" + case expr.CmpOpNeq: + return "!=" + case expr.CmpOpLt: + return "<" + case expr.CmpOpLte: + return "<=" + case expr.CmpOpGt: + return ">" + case expr.CmpOpGte: + return ">=" + default: + return fmt.Sprintf("op-%d", op) + } +} + +// formatExpr formats an expression in nft-like syntax +func formatExpr(exp expr.Any) string { + switch e := exp.(type) { + case *expr.Meta: + return formatMeta(e) + case *expr.Cmp: + return formatCmp(e) + case *expr.Payload: + return formatPayload(e) + case *expr.Verdict: + return formatVerdict(e) + case *expr.Counter: + return fmt.Sprintf("counter packets %d bytes %d", e.Packets, e.Bytes) + case *expr.Masq: + return "masquerade" + case *expr.NAT: + return formatNat(e) + case *expr.Match: + return formatMatch(e) + case *expr.Queue: + return fmt.Sprintf("queue num %d", e.Num) + case *expr.Lookup: + return fmt.Sprintf("@%s", e.SetName) + case *expr.Bitwise: + return formatBitwise(e) + case *expr.Fib: + return formatFib(e) + case *expr.Target: + return fmt.Sprintf("jump %s", e.Name) // Properly format jump targets + case 
*expr.Immediate: + if e.Register == 1 { + return formatImmediateData(e.Data) + } + return fmt.Sprintf("immediate %v", e.Data) + default: + return fmt.Sprintf("<%T>", exp) + } +} + +func formatImmediateData(data []byte) string { + // For IP addresses (4 bytes) + if len(data) == 4 { + return fmt.Sprintf("%d.%d.%d.%d", data[0], data[1], data[2], data[3]) + } + return fmt.Sprintf("%v", data) +} + +func formatMeta(e *expr.Meta) string { + // Handle source register case first (meta mark set) + if e.SourceRegister { + return fmt.Sprintf("meta %s set reg %d", formatMetaKey(e.Key), e.Register) + } + + // For interface names, handle register load operation + switch e.Key { + case expr.MetaKeyIIFNAME, + expr.MetaKeyOIFNAME, + expr.MetaKeyBRIIIFNAME, + expr.MetaKeyBRIOIFNAME: + // Simply the key name with no register reference + return formatMetaKey(e.Key) + + case expr.MetaKeyMARK: + // For mark operations, we want just "mark" + return "mark" + } + + // For other meta keys, show as loading into register + return fmt.Sprintf("meta %s => reg %d", formatMetaKey(e.Key), e.Register) +} + +func formatMetaKey(key expr.MetaKey) string { + switch key { + case expr.MetaKeyLEN: + return "length" + case expr.MetaKeyPROTOCOL: + return "protocol" + case expr.MetaKeyPRIORITY: + return "priority" + case expr.MetaKeyMARK: + return "mark" + case expr.MetaKeyIIF: + return "iif" + case expr.MetaKeyOIF: + return "oif" + case expr.MetaKeyIIFNAME: + return "iifname" + case expr.MetaKeyOIFNAME: + return "oifname" + case expr.MetaKeyIIFTYPE: + return "iiftype" + case expr.MetaKeyOIFTYPE: + return "oiftype" + case expr.MetaKeySKUID: + return "skuid" + case expr.MetaKeySKGID: + return "skgid" + case expr.MetaKeyNFTRACE: + return "nftrace" + case expr.MetaKeyRTCLASSID: + return "rtclassid" + case expr.MetaKeySECMARK: + return "secmark" + case expr.MetaKeyNFPROTO: + return "nfproto" + case expr.MetaKeyL4PROTO: + return "l4proto" + case expr.MetaKeyBRIIIFNAME: + return "briifname" + case expr.MetaKeyBRIOIFNAME: + return "broifname" + case expr.MetaKeyPKTTYPE: + return "pkttype" + case expr.MetaKeyCPU: + return "cpu" + case expr.MetaKeyIIFGROUP: + return "iifgroup" + case expr.MetaKeyOIFGROUP: + return "oifgroup" + case expr.MetaKeyCGROUP: + return "cgroup" + case expr.MetaKeyPRANDOM: + return "prandom" + default: + return fmt.Sprintf("meta-%d", key) + } +} + +func formatCmp(e *expr.Cmp) string { + ops := map[expr.CmpOp]string{ + expr.CmpOpEq: "==", + expr.CmpOpNeq: "!=", + expr.CmpOpLt: "<", + expr.CmpOpLte: "<=", + expr.CmpOpGt: ">", + expr.CmpOpGte: ">=", + } + return fmt.Sprintf("%s %v", ops[e.Op], e.Data) +} + +func formatPayload(e *expr.Payload) string { + var proto string + switch e.Base { + case expr.PayloadBaseNetworkHeader: + proto = "ip" + case expr.PayloadBaseTransportHeader: + proto = "tcp" + default: + proto = fmt.Sprintf("payload-%d", e.Base) + } + return fmt.Sprintf("%s reg%d [%d:%d]", proto, e.DestRegister, e.Offset, e.Len) +} + +func formatVerdict(e *expr.Verdict) string { + switch e.Kind { + case expr.VerdictAccept: + return "accept" + case expr.VerdictDrop: + return "drop" + case expr.VerdictJump: + return fmt.Sprintf("jump %s", e.Chain) + case expr.VerdictGoto: + return fmt.Sprintf("goto %s", e.Chain) + case expr.VerdictReturn: + return "return" + default: + return fmt.Sprintf("verdict-%d", e.Kind) + } +} + +func formatNat(e *expr.NAT) string { + switch e.Type { + case expr.NATTypeSourceNAT: + return "snat" + case expr.NATTypeDestNAT: + return "dnat" + default: + return fmt.Sprintf("nat-%d", e.Type) + } +} + 
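// Illustrative example, not part of the collected output: for a rule whose
// expression list is, in order,
//
//	&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
//	&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte("wt0\x00")},
//	&expr.Counter{Packets: 12, Bytes: 960},
//	&expr.Verdict{Kind: expr.VerdictAccept},
//
// formatRule above renders the tab-indented line
//
//	iifname == "wt0" counter packets 12 bytes 960 accept
//
// The interface name "wt0" and the counter values are assumed example data.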
+func formatMatch(e *expr.Match) string { + return fmt.Sprintf("match %s rev %d", e.Name, e.Rev) +} + +func formatBitwise(e *expr.Bitwise) string { + return fmt.Sprintf("bitwise reg%d = reg%d & %v ^ %v", + e.DestRegister, e.SourceRegister, e.Mask, e.Xor) +} + +func formatFib(e *expr.Fib) string { + var flags []string + if e.FlagSADDR { + flags = append(flags, "saddr") + } + if e.FlagDADDR { + flags = append(flags, "daddr") + } + if e.FlagMARK { + flags = append(flags, "mark") + } + if e.FlagIIF { + flags = append(flags, "iif") + } + if e.FlagOIF { + flags = append(flags, "oif") + } + if e.ResultADDRTYPE { + flags = append(flags, "type") + } + return fmt.Sprintf("fib reg%d %s", e.Register, strings.Join(flags, ",")) +} + +func formatSet(conn *nftables.Conn, set *nftables.Set) string { + var builder strings.Builder + builder.WriteString(fmt.Sprintf("\tset %s {\n", set.Name)) + builder.WriteString(fmt.Sprintf("\t\ttype %s\n", formatSetKeyType(set.KeyType))) + if set.ID > 0 { + builder.WriteString(fmt.Sprintf("\t\t# handle %d\n", set.ID)) + } + + elements, err := conn.GetSetElements(set) + if err != nil { + log.Warnf("Failed to get elements for set %s: %v", set.Name, err) + } else if len(elements) > 0 { + builder.WriteString("\t\telements = {") + for i, elem := range elements { + if i > 0 { + builder.WriteString(", ") + } + builder.WriteString(fmt.Sprintf("%v", elem.Key)) + } + builder.WriteString("}\n") + } + + builder.WriteString("\t}\n") + return builder.String() +} + +func formatSetKeyType(keyType nftables.SetDatatype) string { + switch keyType { + case nftables.TypeInvalid: + return "invalid" + case nftables.TypeIPAddr: + return "ipv4_addr" + case nftables.TypeIP6Addr: + return "ipv6_addr" + case nftables.TypeEtherAddr: + return "ether_addr" + case nftables.TypeInetProto: + return "inet_proto" + case nftables.TypeInetService: + return "inet_service" + case nftables.TypeMark: + return "mark" + default: + return fmt.Sprintf("type-%v", keyType) + } +} diff --git a/client/server/debug_nonlinux.go b/client/server/debug_nonlinux.go new file mode 100644 index 000000000..c54ac9b6e --- /dev/null +++ b/client/server/debug_nonlinux.go @@ -0,0 +1,15 @@ +//go:build !linux || android + +package server + +import ( + "archive/zip" + + "github.com/netbirdio/netbird/client/anonymize" + "github.com/netbirdio/netbird/client/proto" +) + +// collectFirewallRules returns nothing on non-linux systems +func (s *Server) addFirewallRules(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error { + return nil +} diff --git a/client/server/debug_test.go b/client/server/debug_test.go new file mode 100644 index 000000000..ebd0bffbc --- /dev/null +++ b/client/server/debug_test.go @@ -0,0 +1,543 @@ +package server + +import ( + "encoding/json" + "net" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/anonymize" + mgmProto "github.com/netbirdio/netbird/management/proto" +) + +func TestAnonymizeStateFile(t *testing.T) { + testState := map[string]json.RawMessage{ + "null_state": json.RawMessage("null"), + "test_state": mustMarshal(map[string]any{ + // Test simple fields + "public_ip": "203.0.113.1", + "private_ip": "192.168.1.1", + "protected_ip": "100.64.0.1", + "well_known_ip": "8.8.8.8", + "ipv6_addr": "2001:db8::1", + "private_ipv6": "fd00::1", + "domain": "test.example.com", + "uri": "stun:stun.example.com:3478", + "uri_with_ip": "turn:203.0.113.1:3478", + "netbird_domain": 
"device.netbird.cloud", + + // Test CIDR ranges + "public_cidr": "203.0.113.0/24", + "private_cidr": "192.168.0.0/16", + "protected_cidr": "100.64.0.0/10", + "ipv6_cidr": "2001:db8::/32", + "private_ipv6_cidr": "fd00::/8", + + // Test nested structures + "nested": map[string]any{ + "ip": "203.0.113.2", + "domain": "nested.example.com", + "more_nest": map[string]any{ + "ip": "203.0.113.3", + "domain": "deep.example.com", + }, + }, + + // Test arrays + "string_array": []any{ + "203.0.113.4", + "test1.example.com", + "test2.example.com", + }, + "object_array": []any{ + map[string]any{ + "ip": "203.0.113.5", + "domain": "array1.example.com", + }, + map[string]any{ + "ip": "203.0.113.6", + "domain": "array2.example.com", + }, + }, + + // Test multiple occurrences of same value + "duplicate_ip": "203.0.113.1", // Same as public_ip + "duplicate_domain": "test.example.com", // Same as domain + + // Test URIs with various schemes + "stun_uri": "stun:stun.example.com:3478", + "turns_uri": "turns:turns.example.com:5349", + "http_uri": "http://web.example.com:80", + "https_uri": "https://secure.example.com:443", + + // Test strings that might look like IPs but aren't + "not_ip": "300.300.300.300", + "partial_ip": "192.168", + "ip_like_string": "1234.5678", + + // Test mixed content strings + "mixed_content": "Server at 203.0.113.1 (test.example.com) on port 80", + + // Test empty and special values + "empty_string": "", + "null_value": nil, + "numeric_value": 42, + "boolean_value": true, + }), + "route_state": mustMarshal(map[string]any{ + "routes": []any{ + map[string]any{ + "network": "203.0.113.0/24", + "gateway": "203.0.113.1", + "domains": []any{ + "route1.example.com", + "route2.example.com", + }, + }, + map[string]any{ + "network": "2001:db8::/32", + "gateway": "2001:db8::1", + "domains": []any{ + "route3.example.com", + "route4.example.com", + }, + }, + }, + // Test map with IP/CIDR keys + "refCountMap": map[string]any{ + "203.0.113.1/32": map[string]any{ + "Count": 1, + "Out": map[string]any{ + "IP": "192.168.0.1", + "Intf": map[string]any{ + "Name": "eth0", + "Index": 1, + }, + }, + }, + "2001:db8::1/128": map[string]any{ + "Count": 1, + "Out": map[string]any{ + "IP": "fe80::1", + "Intf": map[string]any{ + "Name": "eth0", + "Index": 1, + }, + }, + }, + "10.0.0.1/32": map[string]any{ // private IP should remain unchanged + "Count": 1, + "Out": map[string]any{ + "IP": "192.168.0.1", + }, + }, + }, + }), + } + + anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses()) + + // Pre-seed the domains we need to verify in the test assertions + anonymizer.AnonymizeDomain("test.example.com") + anonymizer.AnonymizeDomain("nested.example.com") + anonymizer.AnonymizeDomain("deep.example.com") + anonymizer.AnonymizeDomain("array1.example.com") + + err := anonymizeStateFile(&testState, anonymizer) + require.NoError(t, err) + + // Helper function to unmarshal and get nested values + var state map[string]any + err = json.Unmarshal(testState["test_state"], &state) + require.NoError(t, err) + + // Test null state remains unchanged + require.Equal(t, "null", string(testState["null_state"])) + + // Basic assertions + assert.NotEqual(t, "203.0.113.1", state["public_ip"]) + assert.Equal(t, "192.168.1.1", state["private_ip"]) // Private IP unchanged + assert.Equal(t, "100.64.0.1", state["protected_ip"]) // Protected IP unchanged + assert.Equal(t, "8.8.8.8", state["well_known_ip"]) // Well-known IP unchanged + assert.NotEqual(t, "2001:db8::1", state["ipv6_addr"]) + assert.Equal(t, "fd00::1", 
state["private_ipv6"]) // Private IPv6 unchanged + assert.NotEqual(t, "test.example.com", state["domain"]) + assert.True(t, strings.HasSuffix(state["domain"].(string), ".domain")) + assert.Equal(t, "device.netbird.cloud", state["netbird_domain"]) // Netbird domain unchanged + + // CIDR ranges + assert.NotEqual(t, "203.0.113.0/24", state["public_cidr"]) + assert.Contains(t, state["public_cidr"], "/24") // Prefix preserved + assert.Equal(t, "192.168.0.0/16", state["private_cidr"]) // Private CIDR unchanged + assert.Equal(t, "100.64.0.0/10", state["protected_cidr"]) // Protected CIDR unchanged + assert.NotEqual(t, "2001:db8::/32", state["ipv6_cidr"]) + assert.Contains(t, state["ipv6_cidr"], "/32") // IPv6 prefix preserved + + // Nested structures + nested := state["nested"].(map[string]any) + assert.NotEqual(t, "203.0.113.2", nested["ip"]) + assert.NotEqual(t, "nested.example.com", nested["domain"]) + moreNest := nested["more_nest"].(map[string]any) + assert.NotEqual(t, "203.0.113.3", moreNest["ip"]) + assert.NotEqual(t, "deep.example.com", moreNest["domain"]) + + // Arrays + strArray := state["string_array"].([]any) + assert.NotEqual(t, "203.0.113.4", strArray[0]) + assert.NotEqual(t, "test1.example.com", strArray[1]) + assert.True(t, strings.HasSuffix(strArray[1].(string), ".domain")) + + objArray := state["object_array"].([]any) + firstObj := objArray[0].(map[string]any) + assert.NotEqual(t, "203.0.113.5", firstObj["ip"]) + assert.NotEqual(t, "array1.example.com", firstObj["domain"]) + + // Duplicate values should be anonymized consistently + assert.Equal(t, state["public_ip"], state["duplicate_ip"]) + assert.Equal(t, state["domain"], state["duplicate_domain"]) + + // URIs + assert.NotContains(t, state["stun_uri"], "stun.example.com") + assert.NotContains(t, state["turns_uri"], "turns.example.com") + assert.NotContains(t, state["http_uri"], "web.example.com") + assert.NotContains(t, state["https_uri"], "secure.example.com") + + // Non-IP strings should remain unchanged + assert.Equal(t, "300.300.300.300", state["not_ip"]) + assert.Equal(t, "192.168", state["partial_ip"]) + assert.Equal(t, "1234.5678", state["ip_like_string"]) + + // Mixed content should have IPs and domains replaced + mixedContent := state["mixed_content"].(string) + assert.NotContains(t, mixedContent, "203.0.113.1") + assert.NotContains(t, mixedContent, "test.example.com") + assert.Contains(t, mixedContent, "Server at ") + assert.Contains(t, mixedContent, " on port 80") + + // Special values should remain unchanged + assert.Equal(t, "", state["empty_string"]) + assert.Nil(t, state["null_value"]) + assert.Equal(t, float64(42), state["numeric_value"]) + assert.Equal(t, true, state["boolean_value"]) + + // Check route state + var routeState map[string]any + err = json.Unmarshal(testState["route_state"], &routeState) + require.NoError(t, err) + + routes := routeState["routes"].([]any) + route1 := routes[0].(map[string]any) + assert.NotEqual(t, "203.0.113.0/24", route1["network"]) + assert.Contains(t, route1["network"], "/24") + assert.NotEqual(t, "203.0.113.1", route1["gateway"]) + domains := route1["domains"].([]any) + assert.True(t, strings.HasSuffix(domains[0].(string), ".domain")) + assert.True(t, strings.HasSuffix(domains[1].(string), ".domain")) + + // Check map keys are anonymized + refCountMap := routeState["refCountMap"].(map[string]any) + hasPublicIPKey := false + hasIPv6Key := false + hasPrivateIPKey := false + for key := range refCountMap { + if strings.Contains(key, "203.0.113.1") { + hasPublicIPKey = true + } + 
if strings.Contains(key, "2001:db8::1") { + hasIPv6Key = true + } + if key == "10.0.0.1/32" { + hasPrivateIPKey = true + } + } + assert.False(t, hasPublicIPKey, "public IP in key should be anonymized") + assert.False(t, hasIPv6Key, "IPv6 in key should be anonymized") + assert.True(t, hasPrivateIPKey, "private IP in key should remain unchanged") +} + +func mustMarshal(v any) json.RawMessage { + data, err := json.Marshal(v) + if err != nil { + panic(err) + } + return data +} + +func TestAnonymizeNetworkMap(t *testing.T) { + networkMap := &mgmProto.NetworkMap{ + PeerConfig: &mgmProto.PeerConfig{ + Address: "203.0.113.5", + Dns: "1.2.3.4", + Fqdn: "peer1.corp.example.com", + SshConfig: &mgmProto.SSHConfig{ + SshPubKey: []byte("ssh-rsa AAAAB3NzaC1..."), + }, + }, + RemotePeers: []*mgmProto.RemotePeerConfig{ + { + AllowedIps: []string{ + "203.0.113.1/32", + "2001:db8:1234::1/128", + "192.168.1.1/32", + "100.64.0.1/32", + "10.0.0.1/32", + }, + Fqdn: "peer2.corp.example.com", + SshConfig: &mgmProto.SSHConfig{ + SshPubKey: []byte("ssh-rsa AAAAB3NzaC2..."), + }, + }, + }, + Routes: []*mgmProto.Route{ + { + Network: "197.51.100.0/24", + Domains: []string{"prod.example.com", "staging.example.com"}, + NetID: "net-123abc", + }, + }, + DNSConfig: &mgmProto.DNSConfig{ + NameServerGroups: []*mgmProto.NameServerGroup{ + { + NameServers: []*mgmProto.NameServer{ + {IP: "8.8.8.8"}, + {IP: "1.1.1.1"}, + {IP: "203.0.113.53"}, + }, + Domains: []string{"example.com", "internal.example.com"}, + }, + }, + CustomZones: []*mgmProto.CustomZone{ + { + Domain: "custom.example.com", + Records: []*mgmProto.SimpleRecord{ + { + Name: "www.custom.example.com", + Type: 1, + RData: "203.0.113.10", + }, + { + Name: "internal.custom.example.com", + Type: 1, + RData: "192.168.1.10", + }, + }, + }, + }, + }, + } + + // Create anonymizer with test addresses + anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses()) + + // Anonymize the network map + err := anonymizeNetworkMap(networkMap, anonymizer) + require.NoError(t, err) + + // Test PeerConfig anonymization + peerCfg := networkMap.PeerConfig + require.NotEqual(t, "203.0.113.5", peerCfg.Address) + + // Verify DNS and FQDN are properly anonymized + require.NotEqual(t, "1.2.3.4", peerCfg.Dns) + require.NotEqual(t, "peer1.corp.example.com", peerCfg.Fqdn) + require.True(t, strings.HasSuffix(peerCfg.Fqdn, ".domain")) + + // Verify SSH key is replaced + require.Equal(t, []byte("ssh-placeholder-key"), peerCfg.SshConfig.SshPubKey) + + // Test RemotePeers anonymization + remotePeer := networkMap.RemotePeers[0] + + // Verify FQDN is anonymized + require.NotEqual(t, "peer2.corp.example.com", remotePeer.Fqdn) + require.True(t, strings.HasSuffix(remotePeer.Fqdn, ".domain")) + + // Check that public IPs are anonymized but private IPs are preserved + for _, allowedIP := range remotePeer.AllowedIps { + ip, _, err := net.ParseCIDR(allowedIP) + require.NoError(t, err) + + if ip.IsPrivate() || isInCGNATRange(ip) { + require.Contains(t, []string{ + "192.168.1.1/32", + "100.64.0.1/32", + "10.0.0.1/32", + }, allowedIP) + } else { + require.NotContains(t, []string{ + "203.0.113.1/32", + "2001:db8:1234::1/128", + }, allowedIP) + } + } + + // Test Routes anonymization + route := networkMap.Routes[0] + require.NotEqual(t, "197.51.100.0/24", route.Network) + for _, domain := range route.Domains { + require.True(t, strings.HasSuffix(domain, ".domain")) + require.NotContains(t, domain, "example.com") + } + + // Test DNS config anonymization + dnsConfig := networkMap.DNSConfig + nameServerGroup := 
dnsConfig.NameServerGroups[0] + + // Verify well-known DNS servers are preserved + require.Equal(t, "8.8.8.8", nameServerGroup.NameServers[0].IP) + require.Equal(t, "1.1.1.1", nameServerGroup.NameServers[1].IP) + + // Verify public DNS server is anonymized + require.NotEqual(t, "203.0.113.53", nameServerGroup.NameServers[2].IP) + + // Verify domains are anonymized + for _, domain := range nameServerGroup.Domains { + require.True(t, strings.HasSuffix(domain, ".domain")) + require.NotContains(t, domain, "example.com") + } + + // Test CustomZones anonymization + customZone := dnsConfig.CustomZones[0] + require.True(t, strings.HasSuffix(customZone.Domain, ".domain")) + require.NotContains(t, customZone.Domain, "example.com") + + // Verify records are properly anonymized + for _, record := range customZone.Records { + require.True(t, strings.HasSuffix(record.Name, ".domain")) + require.NotContains(t, record.Name, "example.com") + + ip := net.ParseIP(record.RData) + if ip != nil { + if !ip.IsPrivate() { + require.NotEqual(t, "203.0.113.10", record.RData) + } else { + require.Equal(t, "192.168.1.10", record.RData) + } + } + } +} + +// Helper function to check if IP is in CGNAT range +func isInCGNATRange(ip net.IP) bool { + cgnat := net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + return cgnat.Contains(ip) +} + +func TestAnonymizeFirewallRules(t *testing.T) { + // TODO: Add ipv6 + + // Example iptables-save output + iptablesSave := `# Generated by iptables-save v1.8.7 on Thu Dec 19 10:00:00 2024 +*filter +:INPUT ACCEPT [0:0] +:FORWARD ACCEPT [0:0] +:OUTPUT ACCEPT [0:0] +-A INPUT -s 192.168.1.0/24 -j ACCEPT +-A INPUT -s 44.192.140.1/32 -j DROP +-A FORWARD -s 10.0.0.0/8 -j DROP +-A FORWARD -s 44.192.140.0/24 -d 52.84.12.34/24 -j ACCEPT +COMMIT + +*nat +:PREROUTING ACCEPT [0:0] +:INPUT ACCEPT [0:0] +:OUTPUT ACCEPT [0:0] +:POSTROUTING ACCEPT [0:0] +-A POSTROUTING -s 192.168.100.0/24 -j MASQUERADE +-A PREROUTING -d 44.192.140.10/32 -p tcp -m tcp --dport 80 -j DNAT --to-destination 192.168.1.10:80 +COMMIT` + + // Example iptables -v -n -L output + iptablesVerbose := `Chain INPUT (policy ACCEPT 0 packets, 0 bytes) + pkts bytes target prot opt in out source destination + 0 0 ACCEPT all -- * * 192.168.1.0/24 0.0.0.0/0 + 100 1024 DROP all -- * * 44.192.140.1 0.0.0.0/0 + +Chain FORWARD (policy ACCEPT 0 packets, 0 bytes) + pkts bytes target prot opt in out source destination + 0 0 DROP all -- * * 10.0.0.0/8 0.0.0.0/0 + 25 256 ACCEPT all -- * * 44.192.140.0/24 52.84.12.34/24 + +Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes) + pkts bytes target prot opt in out source destination` + + // Example nftables output + nftablesRules := `table inet filter { + chain input { + type filter hook input priority filter; policy accept; + ip saddr 192.168.1.1 accept + ip saddr 44.192.140.1 drop + } + chain forward { + type filter hook forward priority filter; policy accept; + ip saddr 10.0.0.0/8 drop + ip saddr 44.192.140.0/24 ip daddr 52.84.12.34/24 accept + } + }` + + anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses()) + + // Test iptables-save anonymization + anonIptablesSave := anonymizer.AnonymizeString(iptablesSave) + + // Private IP addresses should remain unchanged + assert.Contains(t, anonIptablesSave, "192.168.1.0/24") + assert.Contains(t, anonIptablesSave, "10.0.0.0/8") + assert.Contains(t, anonIptablesSave, "192.168.100.0/24") + assert.Contains(t, anonIptablesSave, "192.168.1.10") + + // Public IP addresses should be anonymized to the default range + 
assert.NotContains(t, anonIptablesSave, "44.192.140.1") + assert.NotContains(t, anonIptablesSave, "44.192.140.0/24") + assert.NotContains(t, anonIptablesSave, "52.84.12.34") + assert.Contains(t, anonIptablesSave, "198.51.100.") // Default anonymous range + + // Structure should be preserved + assert.Contains(t, anonIptablesSave, "*filter") + assert.Contains(t, anonIptablesSave, ":INPUT ACCEPT [0:0]") + assert.Contains(t, anonIptablesSave, "COMMIT") + assert.Contains(t, anonIptablesSave, "-j MASQUERADE") + assert.Contains(t, anonIptablesSave, "--dport 80") + + // Test iptables verbose output anonymization + anonIptablesVerbose := anonymizer.AnonymizeString(iptablesVerbose) + + // Private IP addresses should remain unchanged + assert.Contains(t, anonIptablesVerbose, "192.168.1.0/24") + assert.Contains(t, anonIptablesVerbose, "10.0.0.0/8") + + // Public IP addresses should be anonymized to the default range + assert.NotContains(t, anonIptablesVerbose, "44.192.140.1") + assert.NotContains(t, anonIptablesVerbose, "44.192.140.0/24") + assert.NotContains(t, anonIptablesVerbose, "52.84.12.34") + assert.Contains(t, anonIptablesVerbose, "198.51.100.") // Default anonymous range + + // Structure and counters should be preserved + assert.Contains(t, anonIptablesVerbose, "Chain INPUT (policy ACCEPT 0 packets, 0 bytes)") + assert.Contains(t, anonIptablesVerbose, "100 1024 DROP") + assert.Contains(t, anonIptablesVerbose, "pkts bytes target") + + // Test nftables anonymization + anonNftables := anonymizer.AnonymizeString(nftablesRules) + + // Private IP addresses should remain unchanged + assert.Contains(t, anonNftables, "192.168.1.1") + assert.Contains(t, anonNftables, "10.0.0.0/8") + + // Public IP addresses should be anonymized to the default range + assert.NotContains(t, anonNftables, "44.192.140.1") + assert.NotContains(t, anonNftables, "44.192.140.0/24") + assert.NotContains(t, anonNftables, "52.84.12.34") + assert.Contains(t, anonNftables, "198.51.100.") // Default anonymous range + + // Structure should be preserved + assert.Contains(t, anonNftables, "table inet filter {") + assert.Contains(t, anonNftables, "chain input {") + assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;") +} diff --git a/client/server/route.go b/client/server/network.go similarity index 58% rename from client/server/route.go rename to client/server/network.go index d70e0dca3..aaf361524 100644 --- a/client/server/route.go +++ b/client/server/network.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/netip" + "slices" "sort" "golang.org/x/exp/maps" @@ -20,8 +21,8 @@ type selectRoute struct { Selected bool } -// ListRoutes returns a list of all available routes. -func (s *Server) ListRoutes(context.Context, *proto.ListRoutesRequest) (*proto.ListRoutesResponse, error) { +// ListNetworks returns a list of all available networks. 
+func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*proto.ListNetworksResponse, error) { s.mutex.Lock() defer s.mutex.Unlock() @@ -34,7 +35,7 @@ func (s *Server) ListRoutes(context.Context, *proto.ListRoutesRequest) (*proto.L return nil, fmt.Errorf("not connected") } - routesMap := engine.GetClientRoutesWithNetID() + routesMap := engine.GetRouteManager().GetClientRoutesWithNetID() routeSelector := engine.GetRouteManager().GetRouteSelector() var routes []*selectRoute @@ -67,37 +68,47 @@ func (s *Server) ListRoutes(context.Context, *proto.ListRoutesRequest) (*proto.L }) resolvedDomains := s.statusRecorder.GetResolvedDomainsStates() - var pbRoutes []*proto.Route + var pbRoutes []*proto.Network for _, route := range routes { - pbRoute := &proto.Route{ + pbRoute := &proto.Network{ ID: string(route.NetID), - Network: route.Network.String(), + Range: route.Network.String(), Domains: route.Domains.ToSafeStringList(), ResolvedIPs: map[string]*proto.IPList{}, Selected: route.Selected, } - for _, domain := range route.Domains { - if prefixes, exists := resolvedDomains[domain]; exists { - var ipStrings []string - for _, prefix := range prefixes { - ipStrings = append(ipStrings, prefix.Addr().String()) - } - pbRoute.ResolvedIPs[string(domain)] = &proto.IPList{ - Ips: ipStrings, + // Group resolved IPs by their parent domain + domainMap := map[domain.Domain][]string{} + + for resolvedDomain, info := range resolvedDomains { + // Check if this resolved domain's parent is in our route's domains + if slices.Contains(route.Domains, info.ParentDomain) { + ips := make([]string, 0, len(info.Prefixes)) + for _, prefix := range info.Prefixes { + ips = append(ips, prefix.Addr().String()) } + domainMap[resolvedDomain] = ips } } + + // Convert to proto format + for domain, ips := range domainMap { + pbRoute.ResolvedIPs[string(domain)] = &proto.IPList{ + Ips: ips, + } + } + pbRoutes = append(pbRoutes, pbRoute) } - return &proto.ListRoutesResponse{ + return &proto.ListNetworksResponse{ Routes: pbRoutes, }, nil } -// SelectRoutes selects specific routes based on the client request. -func (s *Server) SelectRoutes(_ context.Context, req *proto.SelectRoutesRequest) (*proto.SelectRoutesResponse, error) { +// SelectNetworks selects specific networks based on the client request. +func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequest) (*proto.SelectNetworksResponse, error) { s.mutex.Lock() defer s.mutex.Unlock() @@ -115,18 +126,19 @@ func (s *Server) SelectRoutes(_ context.Context, req *proto.SelectRoutesRequest) if req.GetAll() { routeSelector.SelectAllRoutes() } else { - routes := toNetIDs(req.GetRouteIDs()) - if err := routeSelector.SelectRoutes(routes, req.GetAppend(), maps.Keys(engine.GetClientRoutesWithNetID())); err != nil { + routes := toNetIDs(req.GetNetworkIDs()) + netIdRoutes := maps.Keys(routeManager.GetClientRoutesWithNetID()) + if err := routeSelector.SelectRoutes(routes, req.GetAppend(), netIdRoutes); err != nil { return nil, fmt.Errorf("select routes: %w", err) } } - routeManager.TriggerSelection(engine.GetClientRoutes()) + routeManager.TriggerSelection(routeManager.GetClientRoutes()) - return &proto.SelectRoutesResponse{}, nil + return &proto.SelectNetworksResponse{}, nil } -// DeselectRoutes deselects specific routes based on the client request. -func (s *Server) DeselectRoutes(_ context.Context, req *proto.SelectRoutesRequest) (*proto.SelectRoutesResponse, error) { +// DeselectNetworks deselects specific networks based on the client request. 
+func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRequest) (*proto.SelectNetworksResponse, error) { s.mutex.Lock() defer s.mutex.Unlock() @@ -144,14 +156,15 @@ func (s *Server) DeselectRoutes(_ context.Context, req *proto.SelectRoutesReques if req.GetAll() { routeSelector.DeselectAllRoutes() } else { - routes := toNetIDs(req.GetRouteIDs()) - if err := routeSelector.DeselectRoutes(routes, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil { + routes := toNetIDs(req.GetNetworkIDs()) + netIdRoutes := maps.Keys(routeManager.GetClientRoutesWithNetID()) + if err := routeSelector.DeselectRoutes(routes, netIdRoutes); err != nil { return nil, fmt.Errorf("deselect routes: %w", err) } } - routeManager.TriggerSelection(engine.GetClientRoutes()) + routeManager.TriggerSelection(routeManager.GetClientRoutes()) - return &proto.SelectRoutesResponse{}, nil + return &proto.SelectNetworksResponse{}, nil } func toNetIDs(routes []string) []route.NetID { diff --git a/client/server/server.go b/client/server/server.go index 106bdf32b..70d19bfab 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -68,6 +68,8 @@ type Server struct { relayProbe *internal.Probe wgProbe *internal.Probe lastProbe time.Time + + persistNetworkMap bool } type oauthAuthFlow struct { @@ -89,6 +91,8 @@ func New(ctx context.Context, configPath, logFile string) *Server { signalProbe: internal.NewProbe(), relayProbe: internal.NewProbe(), wgProbe: internal.NewProbe(), + + persistNetworkMap: true, } } @@ -196,6 +200,7 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Conf runOperation := func() error { log.Tracef("running client connection") s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder) + s.connectClient.SetNetworkMapPersistence(s.persistNetworkMap) probes := internal.ProbeHolder{ MgmProbe: s.mgmProbe, @@ -394,6 +399,23 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro s.latestConfigInput.DNSRouteInterval = &duration } + if msg.DisableClientRoutes != nil { + inputConfig.DisableClientRoutes = msg.DisableClientRoutes + s.latestConfigInput.DisableClientRoutes = msg.DisableClientRoutes + } + if msg.DisableServerRoutes != nil { + inputConfig.DisableServerRoutes = msg.DisableServerRoutes + s.latestConfigInput.DisableServerRoutes = msg.DisableServerRoutes + } + if msg.DisableDns != nil { + inputConfig.DisableDNS = msg.DisableDns + s.latestConfigInput.DisableDNS = msg.DisableDns + } + if msg.DisableFirewall != nil { + inputConfig.DisableFirewall = msg.DisableFirewall + s.latestConfigInput.DisableFirewall = msg.DisableFirewall + } + s.mutex.Unlock() if msg.OptionalPreSharedKey != nil { @@ -769,7 +791,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { pbFullStatus.LocalPeerState.Fqdn = fullStatus.LocalPeerState.FQDN pbFullStatus.LocalPeerState.RosenpassPermissive = fullStatus.RosenpassState.Permissive pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled - pbFullStatus.LocalPeerState.Routes = maps.Keys(fullStatus.LocalPeerState.Routes) + pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes) for _, peerState := range fullStatus.Peers { pbPeerState := &proto.PeerState{ @@ -788,7 +810,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { BytesRx: peerState.BytesRx, BytesTx: peerState.BytesTx, RosenpassEnabled: peerState.RosenpassEnabled, - Routes: maps.Keys(peerState.GetRoutes()), + Networks: 
maps.Keys(peerState.GetRoutes()), Latency: durationpb.New(peerState.Latency), } pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState) diff --git a/client/server/server_test.go b/client/server/server_test.go index 61bdaf660..128de8e02 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -20,6 +20,8 @@ import ( mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/signal/proto" signalServer "github.com/netbirdio/netbird/signal/server" @@ -110,7 +112,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve return nil, "", err } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanUp, err := server.NewTestStoreFromSQL(context.Background(), "", config.Datadir) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", config.Datadir) if err != nil { return nil, "", err } @@ -132,7 +134,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve } secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil) if err != nil { return nil, "", err } diff --git a/client/server/state.go b/client/server/state.go index 509782e86..222c7c7bd 100644 --- a/client/server/state.go +++ b/client/server/state.go @@ -5,12 +5,112 @@ import ( "fmt" "github.com/hashicorp/go-multierror" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/client/proto" ) -// restoreResidualConfig checks if the client was not shut down in a clean way and restores residual state if required. 
+// ListStates returns a list of all saved states +func (s *Server) ListStates(_ context.Context, _ *proto.ListStatesRequest) (*proto.ListStatesResponse, error) { + mgr := statemanager.New(statemanager.GetDefaultStatePath()) + + stateNames, err := mgr.GetSavedStateNames() + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get saved state names: %v", err) + } + + states := make([]*proto.State, 0, len(stateNames)) + for _, name := range stateNames { + states = append(states, &proto.State{ + Name: name, + }) + } + + return &proto.ListStatesResponse{ + States: states, + }, nil +} + +// CleanState handles cleaning of states (performing cleanup operations) +func (s *Server) CleanState(ctx context.Context, req *proto.CleanStateRequest) (*proto.CleanStateResponse, error) { + if s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting { + return nil, status.Errorf(codes.FailedPrecondition, "cannot clean state while connecting or connected, run 'netbird down' first.") + } + + if req.All { + // Reuse existing cleanup logic for all states + if err := restoreResidualState(ctx); err != nil { + return nil, status.Errorf(codes.Internal, "failed to clean all states: %v", err) + } + + // Get count of cleaned states + mgr := statemanager.New(statemanager.GetDefaultStatePath()) + stateNames, err := mgr.GetSavedStateNames() + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get state count: %v", err) + } + + return &proto.CleanStateResponse{ + CleanedStates: int32(len(stateNames)), + }, nil + } + + // Handle single state cleanup + mgr := statemanager.New(statemanager.GetDefaultStatePath()) + registerStates(mgr) + + if err := mgr.CleanupStateByName(req.StateName); err != nil { + return nil, status.Errorf(codes.Internal, "failed to clean state %s: %v", req.StateName, err) + } + + if err := mgr.PersistState(ctx); err != nil { + return nil, status.Errorf(codes.Internal, "failed to persist state changes: %v", err) + } + + return &proto.CleanStateResponse{ + CleanedStates: 1, + }, nil +} + +// DeleteState handles deletion of states without cleanup +func (s *Server) DeleteState(ctx context.Context, req *proto.DeleteStateRequest) (*proto.DeleteStateResponse, error) { + if s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting { + return nil, status.Errorf(codes.FailedPrecondition, "cannot clean state while connecting or connected, run 'netbird down' first.") + } + + mgr := statemanager.New(statemanager.GetDefaultStatePath()) + + var count int + var err error + + if req.All { + count, err = mgr.DeleteAllStates() + } else { + err = mgr.DeleteStateByName(req.StateName) + if err == nil { + count = 1 + } + } + + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to delete state: %v", err) + } + + // Persist the changes + if err := mgr.PersistState(ctx); err != nil { + return nil, status.Errorf(codes.Internal, "failed to persist state changes: %v", err) + } + + return &proto.DeleteStateResponse{ + DeletedStates: int32(count), + }, nil +} + +// restoreResidualState checks if the client was not shut down in a clean way and restores residual if required. // Otherwise, we might not be able to connect to the management server to retrieve new config. 
func restoreResidualState(ctx context.Context) error { path := statemanager.GetDefaultStatePath() @@ -24,6 +124,7 @@ func restoreResidualState(ctx context.Context) error { registerStates(mgr) var merr *multierror.Error + if err := mgr.PerformCleanup(); err != nil { merr = multierror.Append(merr, fmt.Errorf("perform cleanup: %w", err)) } diff --git a/client/system/info.go b/client/system/info.go index 2af2e637b..200d835df 100644 --- a/client/system/info.go +++ b/client/system/info.go @@ -61,6 +61,14 @@ type Info struct { Files []File // for posture checks } +// StaticInfo is an object that contains machine information that does not change +type StaticInfo struct { + SystemSerialNumber string + SystemProductName string + SystemManufacturer string + Environment Environment +} + // extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context func extractUserAgent(ctx context.Context) string { md, hasMeta := metadata.FromOutgoingContext(ctx) diff --git a/client/system/info_darwin.go b/client/system/info_darwin.go index 6f4ed173b..13b0a446b 100644 --- a/client/system/info_darwin.go +++ b/client/system/info_darwin.go @@ -10,13 +10,12 @@ import ( "os/exec" "runtime" "strings" + "time" "golang.org/x/sys/unix" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/system/detect_cloud" - "github.com/netbirdio/netbird/client/system/detect_platform" "github.com/netbirdio/netbird/version" ) @@ -41,11 +40,10 @@ func GetInfo(ctx context.Context) *Info { log.Warnf("failed to discover network addresses: %s", err) } - serialNum, prodName, manufacturer := sysInfo() - - env := Environment{ - Cloud: detect_cloud.Detect(ctx), - Platform: detect_platform.Detect(ctx), + start := time.Now() + si := updateStaticInfo() + if time.Since(start) > 1*time.Second { + log.Warnf("updateStaticInfo took %s", time.Since(start)) } gio := &Info{ @@ -57,10 +55,10 @@ func GetInfo(ctx context.Context) *Info { CPUs: runtime.NumCPU(), KernelVersion: release, NetworkAddresses: addrs, - SystemSerialNumber: serialNum, - SystemProductName: prodName, - SystemManufacturer: manufacturer, - Environment: env, + SystemSerialNumber: si.SystemSerialNumber, + SystemProductName: si.SystemProductName, + SystemManufacturer: si.SystemManufacturer, + Environment: si.Environment, } systemHostname, _ := os.Hostname() diff --git a/client/system/info_linux.go b/client/system/info_linux.go index b6a142bce..bfc77be19 100644 --- a/client/system/info_linux.go +++ b/client/system/info_linux.go @@ -1,5 +1,4 @@ //go:build !android -// +build !android package system @@ -16,30 +15,13 @@ import ( log "github.com/sirupsen/logrus" "github.com/zcalusic/sysinfo" - "github.com/netbirdio/netbird/client/system/detect_cloud" - "github.com/netbirdio/netbird/client/system/detect_platform" "github.com/netbirdio/netbird/version" ) -type SysInfoGetter interface { - GetSysInfo() SysInfo -} - -type SysInfoWrapper struct { - si sysinfo.SysInfo -} - -func (s SysInfoWrapper) GetSysInfo() SysInfo { - s.si.GetSysInfo() - return SysInfo{ - ChassisSerial: s.si.Chassis.Serial, - ProductSerial: s.si.Product.Serial, - BoardSerial: s.si.Board.Serial, - ProductName: s.si.Product.Name, - BoardName: s.si.Board.Name, - ProductVendor: s.si.Product.Vendor, - } -} +var ( + // it is override in tests + getSystemInfo = defaultSysInfoImplementation +) // GetInfo retrieves and parses the system information func GetInfo(ctx context.Context) *Info { @@ -65,12 +47,10 @@ func GetInfo(ctx context.Context) *Info { log.Warnf("failed to discover network 
addresses: %s", err) } - si := SysInfoWrapper{} - serialNum, prodName, manufacturer := sysInfo(si.GetSysInfo()) - - env := Environment{ - Cloud: detect_cloud.Detect(ctx), - Platform: detect_platform.Detect(ctx), + start := time.Now() + si := updateStaticInfo() + if time.Since(start) > 1*time.Second { + log.Warnf("updateStaticInfo took %s", time.Since(start)) } gio := &Info{ @@ -85,10 +65,10 @@ func GetInfo(ctx context.Context) *Info { UIVersion: extractUserAgent(ctx), KernelVersion: osInfo[1], NetworkAddresses: addrs, - SystemSerialNumber: serialNum, - SystemProductName: prodName, - SystemManufacturer: manufacturer, - Environment: env, + SystemSerialNumber: si.SystemSerialNumber, + SystemProductName: si.SystemProductName, + SystemManufacturer: si.SystemManufacturer, + Environment: si.Environment, } return gio @@ -108,9 +88,9 @@ func _getInfo() string { return out.String() } -func sysInfo(si SysInfo) (string, string, string) { +func sysInfo() (string, string, string) { isascii := regexp.MustCompile("^[[:ascii:]]+$") - + si := getSystemInfo() serials := []string{si.ChassisSerial, si.ProductSerial} serial := "" @@ -141,3 +121,16 @@ func sysInfo(si SysInfo) (string, string, string) { } return serial, name, manufacturer } + +func defaultSysInfoImplementation() SysInfo { + si := sysinfo.SysInfo{} + si.GetSysInfo() + return SysInfo{ + ChassisSerial: si.Chassis.Serial, + ProductSerial: si.Product.Serial, + BoardSerial: si.Board.Serial, + ProductName: si.Product.Name, + BoardName: si.Board.Name, + ProductVendor: si.Product.Vendor, + } +} diff --git a/client/system/info_windows.go b/client/system/info_windows.go index 68631fe16..28bd3d300 100644 --- a/client/system/info_windows.go +++ b/client/system/info_windows.go @@ -6,13 +6,12 @@ import ( "os" "runtime" "strings" + "time" log "github.com/sirupsen/logrus" "github.com/yusufpapurcu/wmi" "golang.org/x/sys/windows/registry" - "github.com/netbirdio/netbird/client/system/detect_cloud" - "github.com/netbirdio/netbird/client/system/detect_platform" "github.com/netbirdio/netbird/version" ) @@ -42,24 +41,10 @@ func GetInfo(ctx context.Context) *Info { log.Warnf("failed to discover network addresses: %s", err) } - serialNum, err := sysNumber() - if err != nil { - log.Warnf("failed to get system serial number: %s", err) - } - - prodName, err := sysProductName() - if err != nil { - log.Warnf("failed to get system product name: %s", err) - } - - manufacturer, err := sysManufacturer() - if err != nil { - log.Warnf("failed to get system manufacturer: %s", err) - } - - env := Environment{ - Cloud: detect_cloud.Detect(ctx), - Platform: detect_platform.Detect(ctx), + start := time.Now() + si := updateStaticInfo() + if time.Since(start) > 1*time.Second { + log.Warnf("updateStaticInfo took %s", time.Since(start)) } gio := &Info{ @@ -71,10 +56,10 @@ func GetInfo(ctx context.Context) *Info { CPUs: runtime.NumCPU(), KernelVersion: buildVersion, NetworkAddresses: addrs, - SystemSerialNumber: serialNum, - SystemProductName: prodName, - SystemManufacturer: manufacturer, - Environment: env, + SystemSerialNumber: si.SystemSerialNumber, + SystemProductName: si.SystemProductName, + SystemManufacturer: si.SystemManufacturer, + Environment: si.Environment, } systemHostname, _ := os.Hostname() @@ -85,6 +70,26 @@ func GetInfo(ctx context.Context) *Info { return gio } +func sysInfo() (serialNumber string, productName string, manufacturer string) { + var err error + serialNumber, err = sysNumber() + if err != nil { + log.Warnf("failed to get system serial number: %s", err) + } + + 
productName, err = sysProductName() + if err != nil { + log.Warnf("failed to get system product name: %s", err) + } + + manufacturer, err = sysManufacturer() + if err != nil { + log.Warnf("failed to get system manufacturer: %s", err) + } + + return serialNumber, productName, manufacturer +} + func getOSNameAndVersion() (string, string) { var dst []Win32_OperatingSystem query := wmi.CreateQuery(&dst, "") diff --git a/client/system/static_info.go b/client/system/static_info.go new file mode 100644 index 000000000..fabe65a68 --- /dev/null +++ b/client/system/static_info.go @@ -0,0 +1,46 @@ +//go:build (linux && !android) || windows || (darwin && !ios) + +package system + +import ( + "context" + "sync" + "time" + + "github.com/netbirdio/netbird/client/system/detect_cloud" + "github.com/netbirdio/netbird/client/system/detect_platform" +) + +var ( + staticInfo StaticInfo + once sync.Once +) + +func init() { + go func() { + _ = updateStaticInfo() + }() +} + +func updateStaticInfo() StaticInfo { + once.Do(func() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + wg := sync.WaitGroup{} + wg.Add(3) + go func() { + staticInfo.SystemSerialNumber, staticInfo.SystemProductName, staticInfo.SystemManufacturer = sysInfo() + wg.Done() + }() + go func() { + staticInfo.Environment.Cloud = detect_cloud.Detect(ctx) + wg.Done() + }() + go func() { + staticInfo.Environment.Platform = detect_platform.Detect(ctx) + wg.Done() + }() + wg.Wait() + }) + return staticInfo +} diff --git a/client/system/sysinfo_linux_test.go b/client/system/sysinfo_linux_test.go index f6a0b7058..ae89bfcf9 100644 --- a/client/system/sysinfo_linux_test.go +++ b/client/system/sysinfo_linux_test.go @@ -183,7 +183,10 @@ func Test_sysInfo(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotSerialNum, gotProdName, gotManufacturer := sysInfo(tt.sysInfo) + getSystemInfo = func() SysInfo { + return tt.sysInfo + } + gotSerialNum, gotProdName, gotManufacturer := sysInfo() if gotSerialNum != tt.wantSerialNum { t.Errorf("sysInfo() gotSerialNum = %v, want %v", gotSerialNum, tt.wantSerialNum) } diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index d046bab5f..49b0f53cf 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -58,7 +58,7 @@ func main() { var showSettings bool flag.BoolVar(&showSettings, "settings", false, "run settings windows") var showRoutes bool - flag.BoolVar(&showRoutes, "routes", false, "run routes windows") + flag.BoolVar(&showRoutes, "networks", false, "run networks windows") var errorMSG string flag.StringVar(&errorMSG, "error-msg", "", "displays a error message window") @@ -233,7 +233,7 @@ func newServiceClient(addr string, a fyne.App, showSettings bool, showRoutes boo s.showSettingsUI() return s } else if showRoutes { - s.showRoutesUI() + s.showNetworksUI() } return s @@ -549,7 +549,7 @@ func (s *serviceClient) onTrayReady() { s.mAdvancedSettings = s.mSettings.AddSubMenuItem("Advanced Settings", "Advanced settings of the application") s.loadSettings() - s.mRoutes = systray.AddMenuItem("Network Routes", "Open the routes management window") + s.mRoutes = systray.AddMenuItem("Networks", "Open the networks management window") s.mRoutes.Disable() systray.AddSeparator() @@ -572,6 +572,7 @@ func (s *serviceClient) onTrayReady() { s.update.SetOnUpdateListener(s.onUpdateAvailable) go func() { s.getSrvConfig() + time.Sleep(100 * time.Millisecond) // To prevent race condition caused by systray not being fully initialized and ignoring setIcon 
for { err := s.updateStatus() if err != nil { @@ -656,7 +657,7 @@ func (s *serviceClient) onTrayReady() { s.mRoutes.Disable() go func() { defer s.mRoutes.Enable() - s.runSelfCommand("routes", "true") + s.runSelfCommand("networks", "true") }() } if err != nil { diff --git a/client/ui/route.go b/client/ui/network.go similarity index 56% rename from client/ui/route.go rename to client/ui/network.go index 5b6b8fee0..e6f027f0e 100644 --- a/client/ui/route.go +++ b/client/ui/network.go @@ -19,32 +19,32 @@ import ( ) const ( - allRoutesText = "All routes" - overlappingRoutesText = "Overlapping routes" - exitNodeRoutesText = "Exit-node routes" - allRoutes filter = "all" - overlappingRoutes filter = "overlapping" - exitNodeRoutes filter = "exit-node" - getClientFMT = "get client: %v" + allNetworksText = "All networks" + overlappingNetworksText = "Overlapping networks" + exitNodeNetworksText = "Exit-node networks" + allNetworks filter = "all" + overlappingNetworks filter = "overlapping" + exitNodeNetworks filter = "exit-node" + getClientFMT = "get client: %v" ) type filter string -func (s *serviceClient) showRoutesUI() { - s.wRoutes = s.app.NewWindow("NetBird Routes") +func (s *serviceClient) showNetworksUI() { + s.wRoutes = s.app.NewWindow("Networks") allGrid := container.New(layout.NewGridLayout(3)) - go s.updateRoutes(allGrid, allRoutes) + go s.updateNetworks(allGrid, allNetworks) overlappingGrid := container.New(layout.NewGridLayout(3)) exitNodeGrid := container.New(layout.NewGridLayout(3)) routeCheckContainer := container.NewVBox() tabs := container.NewAppTabs( - container.NewTabItem(allRoutesText, allGrid), - container.NewTabItem(overlappingRoutesText, overlappingGrid), - container.NewTabItem(exitNodeRoutesText, exitNodeGrid), + container.NewTabItem(allNetworksText, allGrid), + container.NewTabItem(overlappingNetworksText, overlappingGrid), + container.NewTabItem(exitNodeNetworksText, exitNodeGrid), ) tabs.OnSelected = func(item *container.TabItem) { - s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) + s.updateNetworksBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) } tabs.OnUnselected = func(item *container.TabItem) { grid, _ := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodeGrid) @@ -58,17 +58,17 @@ func (s *serviceClient) showRoutesUI() { buttonBox := container.NewHBox( layout.NewSpacer(), widget.NewButton("Refresh", func() { - s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) + s.updateNetworksBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) }), widget.NewButton("Select all", func() { _, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodeGrid) - s.selectAllFilteredRoutes(f) - s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) + s.selectAllFilteredNetworks(f) + s.updateNetworksBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) }), widget.NewButton("Deselect All", func() { _, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodeGrid) - s.deselectAllFilteredRoutes(f) - s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) + s.deselectAllFilteredNetworks(f) + s.updateNetworksBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) }), layout.NewSpacer(), ) @@ -81,36 +81,36 @@ func (s *serviceClient) showRoutesUI() { s.startAutoRefresh(10*time.Second, tabs, allGrid, overlappingGrid, exitNodeGrid) } -func (s *serviceClient) updateRoutes(grid *fyne.Container, f filter) { +func (s 
*serviceClient) updateNetworks(grid *fyne.Container, f filter) { grid.Objects = nil grid.Refresh() idHeader := widget.NewLabelWithStyle(" ID", fyne.TextAlignLeading, fyne.TextStyle{Bold: true}) - networkHeader := widget.NewLabelWithStyle("Network/Domains", fyne.TextAlignLeading, fyne.TextStyle{Bold: true}) + networkHeader := widget.NewLabelWithStyle("Range/Domains", fyne.TextAlignLeading, fyne.TextStyle{Bold: true}) resolvedIPsHeader := widget.NewLabelWithStyle("Resolved IPs", fyne.TextAlignLeading, fyne.TextStyle{Bold: true}) grid.Add(idHeader) grid.Add(networkHeader) grid.Add(resolvedIPsHeader) - filteredRoutes, err := s.getFilteredRoutes(f) + filteredRoutes, err := s.getFilteredNetworks(f) if err != nil { return } - sortRoutesByIDs(filteredRoutes) + sortNetworksByIDs(filteredRoutes) for _, route := range filteredRoutes { r := route checkBox := widget.NewCheck(r.GetID(), func(checked bool) { - s.selectRoute(r.ID, checked) + s.selectNetwork(r.ID, checked) }) checkBox.Checked = route.Selected checkBox.Resize(fyne.NewSize(20, 20)) checkBox.Refresh() grid.Add(checkBox) - network := r.GetNetwork() + network := r.GetRange() domains := r.GetDomains() if len(domains) == 0 { @@ -129,10 +129,8 @@ func (s *serviceClient) updateRoutes(grid *fyne.Container, f filter) { grid.Add(domainsSelector) var resolvedIPsList []string - for _, domain := range domains { - if ipList, exists := r.GetResolvedIPs()[domain]; exists { - resolvedIPsList = append(resolvedIPsList, fmt.Sprintf("%s: %s", domain, strings.Join(ipList.GetIps(), ", "))) - } + for domain, ipList := range r.GetResolvedIPs() { + resolvedIPsList = append(resolvedIPsList, fmt.Sprintf("%s: %s", domain, strings.Join(ipList.GetIps(), ", "))) } if len(resolvedIPsList) == 0 { @@ -151,35 +149,35 @@ func (s *serviceClient) updateRoutes(grid *fyne.Container, f filter) { grid.Refresh() } -func (s *serviceClient) getFilteredRoutes(f filter) ([]*proto.Route, error) { - routes, err := s.fetchRoutes() +func (s *serviceClient) getFilteredNetworks(f filter) ([]*proto.Network, error) { + routes, err := s.fetchNetworks() if err != nil { log.Errorf(getClientFMT, err) s.showError(fmt.Errorf(getClientFMT, err)) return nil, err } switch f { - case overlappingRoutes: - return getOverlappingRoutes(routes), nil - case exitNodeRoutes: - return getExitNodeRoutes(routes), nil + case overlappingNetworks: + return getOverlappingNetworks(routes), nil + case exitNodeNetworks: + return getExitNodeNetworks(routes), nil default: } return routes, nil } -func getOverlappingRoutes(routes []*proto.Route) []*proto.Route { - var filteredRoutes []*proto.Route - existingRange := make(map[string][]*proto.Route) +func getOverlappingNetworks(routes []*proto.Network) []*proto.Network { + var filteredRoutes []*proto.Network + existingRange := make(map[string][]*proto.Network) for _, route := range routes { if len(route.Domains) > 0 { continue } - if r, exists := existingRange[route.GetNetwork()]; exists { + if r, exists := existingRange[route.GetRange()]; exists { r = append(r, route) - existingRange[route.GetNetwork()] = r + existingRange[route.GetRange()] = r } else { - existingRange[route.GetNetwork()] = []*proto.Route{route} + existingRange[route.GetRange()] = []*proto.Network{route} } } for _, r := range existingRange { @@ -190,29 +188,29 @@ func getOverlappingRoutes(routes []*proto.Route) []*proto.Route { return filteredRoutes } -func getExitNodeRoutes(routes []*proto.Route) []*proto.Route { - var filteredRoutes []*proto.Route +func getExitNodeNetworks(routes []*proto.Network) 
[]*proto.Network { + var filteredRoutes []*proto.Network for _, route := range routes { - if route.Network == "0.0.0.0/0" { + if route.Range == "0.0.0.0/0" { filteredRoutes = append(filteredRoutes, route) } } return filteredRoutes } -func sortRoutesByIDs(routes []*proto.Route) { +func sortNetworksByIDs(routes []*proto.Network) { sort.Slice(routes, func(i, j int) bool { return strings.ToLower(routes[i].GetID()) < strings.ToLower(routes[j].GetID()) }) } -func (s *serviceClient) fetchRoutes() ([]*proto.Route, error) { +func (s *serviceClient) fetchNetworks() ([]*proto.Network, error) { conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { return nil, fmt.Errorf(getClientFMT, err) } - resp, err := conn.ListRoutes(s.ctx, &proto.ListRoutesRequest{}) + resp, err := conn.ListNetworks(s.ctx, &proto.ListNetworksRequest{}) if err != nil { return nil, fmt.Errorf("failed to list routes: %v", err) } @@ -220,7 +218,7 @@ func (s *serviceClient) fetchRoutes() ([]*proto.Route, error) { return resp.Routes, nil } -func (s *serviceClient) selectRoute(id string, checked bool) { +func (s *serviceClient) selectNetwork(id string, checked bool) { conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { log.Errorf(getClientFMT, err) @@ -228,73 +226,73 @@ func (s *serviceClient) selectRoute(id string, checked bool) { return } - req := &proto.SelectRoutesRequest{ - RouteIDs: []string{id}, - Append: checked, + req := &proto.SelectNetworksRequest{ + NetworkIDs: []string{id}, + Append: checked, } if checked { - if _, err := conn.SelectRoutes(s.ctx, req); err != nil { - log.Errorf("failed to select route: %v", err) - s.showError(fmt.Errorf("failed to select route: %v", err)) + if _, err := conn.SelectNetworks(s.ctx, req); err != nil { + log.Errorf("failed to select network: %v", err) + s.showError(fmt.Errorf("failed to select network: %v", err)) return } log.Infof("Route %s selected", id) } else { - if _, err := conn.DeselectRoutes(s.ctx, req); err != nil { - log.Errorf("failed to deselect route: %v", err) - s.showError(fmt.Errorf("failed to deselect route: %v", err)) + if _, err := conn.DeselectNetworks(s.ctx, req); err != nil { + log.Errorf("failed to deselect network: %v", err) + s.showError(fmt.Errorf("failed to deselect network: %v", err)) return } - log.Infof("Route %s deselected", id) + log.Infof("Network %s deselected", id) } } -func (s *serviceClient) selectAllFilteredRoutes(f filter) { +func (s *serviceClient) selectAllFilteredNetworks(f filter) { conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { log.Errorf(getClientFMT, err) return } - req := s.getRoutesRequest(f, true) - if _, err := conn.SelectRoutes(s.ctx, req); err != nil { - log.Errorf("failed to select all routes: %v", err) - s.showError(fmt.Errorf("failed to select all routes: %v", err)) + req := s.getNetworksRequest(f, true) + if _, err := conn.SelectNetworks(s.ctx, req); err != nil { + log.Errorf("failed to select all networks: %v", err) + s.showError(fmt.Errorf("failed to select all networks: %v", err)) return } - log.Debug("All routes selected") + log.Debug("All networks selected") } -func (s *serviceClient) deselectAllFilteredRoutes(f filter) { +func (s *serviceClient) deselectAllFilteredNetworks(f filter) { conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { log.Errorf(getClientFMT, err) return } - req := s.getRoutesRequest(f, false) - if _, err := conn.DeselectRoutes(s.ctx, req); err != nil { - log.Errorf("failed to deselect all routes: %v", err) - s.showError(fmt.Errorf("failed to deselect all routes: 
%v", err)) + req := s.getNetworksRequest(f, false) + if _, err := conn.DeselectNetworks(s.ctx, req); err != nil { + log.Errorf("failed to deselect all networks: %v", err) + s.showError(fmt.Errorf("failed to deselect all networks: %v", err)) return } - log.Debug("All routes deselected") + log.Debug("All networks deselected") } -func (s *serviceClient) getRoutesRequest(f filter, appendRoute bool) *proto.SelectRoutesRequest { - req := &proto.SelectRoutesRequest{} - if f == allRoutes { +func (s *serviceClient) getNetworksRequest(f filter, appendRoute bool) *proto.SelectNetworksRequest { + req := &proto.SelectNetworksRequest{} + if f == allNetworks { req.All = true } else { - routes, err := s.getFilteredRoutes(f) + routes, err := s.getFilteredNetworks(f) if err != nil { return nil } for _, route := range routes { - req.RouteIDs = append(req.RouteIDs, route.GetID()) + req.NetworkIDs = append(req.NetworkIDs, route.GetID()) } req.Append = appendRoute } @@ -311,7 +309,7 @@ func (s *serviceClient) startAutoRefresh(interval time.Duration, tabs *container ticker := time.NewTicker(interval) go func() { for range ticker.C { - s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodesGrid) + s.updateNetworksBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodesGrid) } }() @@ -320,20 +318,20 @@ func (s *serviceClient) startAutoRefresh(interval time.Duration, tabs *container }) } -func (s *serviceClient) updateRoutesBasedOnDisplayTab(tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) { +func (s *serviceClient) updateNetworksBasedOnDisplayTab(tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) { grid, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodesGrid) s.wRoutes.Content().Refresh() - s.updateRoutes(grid, f) + s.updateNetworks(grid, f) } func getGridAndFilterFromTab(tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) (*fyne.Container, filter) { switch tabs.Selected().Text { - case overlappingRoutesText: - return overlappingGrid, overlappingRoutes - case exitNodeRoutesText: - return exitNodesGrid, exitNodeRoutes + case overlappingNetworksText: + return overlappingGrid, overlappingNetworks + case exitNodeNetworksText: + return exitNodesGrid, exitNodeNetworks default: - return allGrid, allRoutes + return allGrid, allNetworks } } diff --git a/dns/dns.go b/dns/dns.go index 18528c743..8dfdf8526 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -108,3 +108,9 @@ func GetParsedDomainLabel(name string) (string, error) { return validHost, nil } + +// NormalizeZone returns a normalized domain name without the wildcard prefix +func NormalizeZone(domain string) string { + d, _ := strings.CutPrefix(domain, "*.") + return d +} diff --git a/go.mod b/go.mod index 0a16753ea..74b160b50 100644 --- a/go.mod +++ b/go.mod @@ -19,18 +19,18 @@ require ( github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 github.com/vishvananda/netlink v1.2.1-beta.2 - golang.org/x/crypto v0.28.0 - golang.org/x/sys v0.26.0 + golang.org/x/crypto v0.31.0 + golang.org/x/sys v0.28.0 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 google.golang.org/grpc v1.64.1 - google.golang.org/protobuf v1.34.1 + google.golang.org/protobuf v1.34.2 gopkg.in/natefinch/lumberjack.v2 v2.0.0 ) require ( - fyne.io/fyne/v2 v2.5.0 + fyne.io/fyne/v2 v2.5.3 fyne.io/systray v1.11.0 github.com/TheJumpCloud/jcapi-go 
v3.0.0+incompatible github.com/c-robinson/iplib v1.0.3 @@ -60,7 +60,7 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254 + github.com/netbirdio/management-integrations/integrations v0.0.0-20250115083837-a09722b8d2a6 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 @@ -77,10 +77,11 @@ require ( github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 github.com/stretchr/testify v1.9.0 github.com/testcontainers/testcontainers-go v0.31.0 + github.com/testcontainers/testcontainers-go/modules/mysql v0.31.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0 github.com/things-go/go-socks5 v0.0.4 github.com/yusufpapurcu/wmi v1.2.4 - github.com/zcalusic/sysinfo v1.0.2 + github.com/zcalusic/sysinfo v1.1.3 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 go.opentelemetry.io/otel v1.26.0 go.opentelemetry.io/otel/exporters/prometheus v0.48.0 @@ -92,13 +93,14 @@ require ( golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a golang.org/x/net v0.30.0 golang.org/x/oauth2 v0.19.0 - golang.org/x/sync v0.8.0 - golang.org/x/term v0.25.0 + golang.org/x/sync v0.10.0 + golang.org/x/term v0.27.0 google.golang.org/api v0.177.0 gopkg.in/yaml.v3 v3.0.1 + gorm.io/driver/mysql v1.5.7 gorm.io/driver/postgres v1.5.7 gorm.io/driver/sqlite v1.5.3 - gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde + gorm.io/gorm v1.25.7 nhooyr.io/websocket v1.8.11 ) @@ -107,6 +109,7 @@ require ( cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect cloud.google.com/go/compute/metadata v0.3.0 // indirect dario.cat/mergo v1.0.0 // indirect + filippo.io/edwards25519 v1.1.0 // indirect github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect github.com/BurntSushi/toml v1.4.0 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect @@ -143,7 +146,7 @@ require ( github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fredbi/uri v1.1.0 // indirect github.com/fyne-io/gl-js v0.0.0-20220119005834-d2da28d9ccfe // indirect - github.com/fyne-io/glfw-js v0.0.0-20240101223322-6e1efdc71b7a // indirect + github.com/fyne-io/glfw-js v0.0.0-20241126112943-313d8a0fe1d0 // indirect github.com/fyne-io/image v0.0.0-20220602074514-4956b0afb3d2 // indirect github.com/go-gl/gl v0.0.0-20211210172815-726fda9656d6 // indirect github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect @@ -151,8 +154,9 @@ require ( github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.3.0 // indirect github.com/go-redis/redis/v8 v8.11.5 // indirect - github.com/go-text/render v0.1.0 // indirect - github.com/go-text/typesetting v0.1.0 // indirect + github.com/go-sql-driver/mysql v1.8.1 // indirect + github.com/go-text/render v0.2.0 // indirect + github.com/go-text/typesetting v0.2.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/google/btree v1.1.2 // indirect @@ -202,11 +206,12 @@ require ( github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/common v0.53.0 // indirect github.com/prometheus/procfs v0.15.0 // indirect - github.com/rymdport/portal v0.2.2 // indirect + github.com/rymdport/portal v0.3.0 // indirect github.com/shoenig/go-m1cpu v0.1.6 // indirect 
github.com/spf13/cast v1.5.0 // indirect github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/tklauser/go-sysconf v0.3.14 // indirect github.com/tklauser/numcpus v0.8.0 // indirect github.com/vishvananda/netns v0.0.4 // indirect @@ -219,12 +224,12 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/mod v0.17.0 // indirect - golang.org/x/text v0.19.0 // indirect + golang.org/x/text v0.21.0 // indirect golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20240515191416-fc5f0ca64291 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect @@ -236,7 +241,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024 replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 -replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73 +replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9 replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 diff --git a/go.sum b/go.sum index a4d7ea7f9..f6d4590ee 100644 --- a/go.sum +++ b/go.sum @@ -48,8 +48,10 @@ cunicu.li/go-rosenpass v0.4.0/go.mod h1:MPbjH9nxV4l3vEagKVdFNwHOketqgS5/To1VYJpl dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= -fyne.io/fyne/v2 v2.5.0 h1:lEjEIso0Vi4sJXYngIMoXOM6aUjqnPjK7pBpxRxG9aI= -fyne.io/fyne/v2 v2.5.0/go.mod h1:9D4oT3NWeG+MLi/lP7ItZZyujHC/qqMJpoGTAYX5Uqc= +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +fyne.io/fyne/v2 v2.5.3 h1:k6LjZx6EzRZhClsuzy6vucLZBstdH2USDGHSGWq8ly8= +fyne.io/fyne/v2 v2.5.3/go.mod h1:0GOXKqyvNwk3DLmsFu9v0oYM0ZcD1ysGnlHCerKoAmo= fyne.io/systray v1.11.0 h1:D9HISlxSkx+jHSniMBR6fCFOUjk1x/OOOJLa9lJYAKg= fyne.io/systray v1.11.0/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs= github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9vkmnHYOMsOr4WLk+Vo07yKIzd94sVoIqshQ4bU= @@ -202,8 +204,8 @@ github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nos github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/fyne-io/gl-js v0.0.0-20220119005834-d2da28d9ccfe h1:A/wiwvQ0CAjPkuJytaD+SsXkPU0asQ+guQEIg1BJGX4= github.com/fyne-io/gl-js v0.0.0-20220119005834-d2da28d9ccfe/go.mod h1:d4clgH0/GrRwWjRzJJQXxT/h1TyuNSfF/X64zb/3Ggg= -github.com/fyne-io/glfw-js v0.0.0-20240101223322-6e1efdc71b7a h1:ybgRdYvAHTn93HW79bLiBiJwVL4jVeyGQRZMgImoeWs= -github.com/fyne-io/glfw-js v0.0.0-20240101223322-6e1efdc71b7a/go.mod 
h1:gsGA2dotD4v0SR6PmPCYvS9JuOeMwAtmfvDE7mbYXMY= +github.com/fyne-io/glfw-js v0.0.0-20241126112943-313d8a0fe1d0 h1:/1YRWFv9bAWkoo3SuxpFfzpXH0D/bQnTjNXyF4ih7Os= +github.com/fyne-io/glfw-js v0.0.0-20241126112943-313d8a0fe1d0/go.mod h1:gsGA2dotD4v0SR6PmPCYvS9JuOeMwAtmfvDE7mbYXMY= github.com/fyne-io/image v0.0.0-20220602074514-4956b0afb3d2 h1:hnLq+55b7Zh7/2IRzWCpiTcAvjv/P8ERF+N7+xXbZhk= github.com/fyne-io/image v0.0.0-20220602074514-4956b0afb3d2/go.mod h1:eO7W361vmlPOrykIg+Rsh1SZ3tQBaOsfzZhsIOb/Lm0= github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= @@ -238,15 +240,18 @@ github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7 github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow= github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= +github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= -github.com/go-text/render v0.1.0 h1:osrmVDZNHuP1RSu3pNG7Z77Sd2xSbcb/xWytAj9kyVs= -github.com/go-text/render v0.1.0/go.mod h1:jqEuNMenrmj6QRnkdpeaP0oKGFLDNhDkVKwGjsWWYU4= -github.com/go-text/typesetting v0.1.0 h1:vioSaLPYcHwPEPLT7gsjCGDCoYSbljxoHJzMnKwVvHw= -github.com/go-text/typesetting v0.1.0/go.mod h1:d22AnmeKq/on0HNv73UFriMKc4Ez6EqZAofLhAzpSzI= -github.com/go-text/typesetting-utils v0.0.0-20240329101916-eee87fb235a3 h1:levTnuLLUmpavLGbJYLJA7fQnKeS7P1eCdAlM+vReXk= -github.com/go-text/typesetting-utils v0.0.0-20240329101916-eee87fb235a3/go.mod h1:DDxDdQEnB70R8owOx3LVpEFvpMK9eeH1o2r0yZhFI9o= +github.com/go-text/render v0.2.0 h1:LBYoTmp5jYiJ4NPqDc2pz17MLmA3wHw1dZSVGcOdeAc= +github.com/go-text/render v0.2.0/go.mod h1:CkiqfukRGKJA5vZZISkjSYrcdtgKQWRa2HIzvwNN5SU= +github.com/go-text/typesetting v0.2.0 h1:fbzsgbmk04KiWtE+c3ZD4W2nmCRzBqrqQOvYlwAOdho= +github.com/go-text/typesetting v0.2.0/go.mod h1:2+owI/sxa73XA581LAzVuEBZ3WEEV2pXeDswCH/3i1I= +github.com/go-text/typesetting-utils v0.0.0-20240317173224-1986cbe96c66 h1:GUrm65PQPlhFSKjLPGOZNPNxLCybjzjYBzjfoBGaDUY= +github.com/go-text/typesetting-utils v0.0.0-20240317173224-1986cbe96c66/go.mod h1:DDxDdQEnB70R8owOx3LVpEFvpMK9eeH1o2r0yZhFI9o= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= @@ -521,14 +526,14 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= 
-github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254 h1:L8mNd3tBxMdnQNxMNJ+/EiwHwizNOMy8/nHLVGNfjpg= -github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= +github.com/netbirdio/management-integrations/integrations v0.0.0-20250115083837-a09722b8d2a6 h1:I/ODkZ8rSDOzlJbhEjD2luSI71zl+s5JgNvFHY0+mBU= +github.com/netbirdio/management-integrations/integrations v0.0.0-20250115083837-a09722b8d2a6/go.mod h1:izUUs1NT7ja+PwSX3kJ7ox8Kkn478tboBJSjL4kU6J0= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= -github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73 h1:jayg97LH/jJlvpIHVxueTfa+tfQ+FY8fy2sIhCwkz0g= -github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= +github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9 h1:Pu/7EukijT09ynHUOzQYW7cC3M/BKU8O4qyN/TvTGoY= +github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM= github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= @@ -616,8 +621,8 @@ github.com/rs/xid v1.3.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= -github.com/rymdport/portal v0.2.2 h1:P2Q/4k673zxdFAsbD8EESZ7psfuO6/4jNu6EDrDICkM= -github.com/rymdport/portal v0.2.2/go.mod h1:kFF4jslnJ8pD5uCi17brj/ODlfIidOxlgUDTO5ncnC4= +github.com/rymdport/portal v0.3.0 h1:QRHcwKwx3kY5JTQcsVhmhC3TGqGQb9LFghVNUy8AdB8= +github.com/rymdport/portal v0.3.0/go.mod h1:kFF4jslnJ8pD5uCi17brj/ODlfIidOxlgUDTO5ncnC4= github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= github.com/shirou/gopsutil/v3 v3.24.4 h1:dEHgzZXt4LMNm+oYELpzl9YCqV65Yr/6SfrvgRBtXeU= github.com/shirou/gopsutil/v3 v3.24.4/go.mod h1:lTd2mdiOspcqLgAnr9/nGi71NkeMpWKdmhuxm9GusH8= @@ -662,6 +667,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v0.0.0-20151208002404-e3a8ff8ce365/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= 
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= @@ -680,6 +686,8 @@ github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8 github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/testcontainers/testcontainers-go v0.31.0 h1:W0VwIhcEVhRflwL9as3dhY6jXjVCA27AkmbnZ+UTh3U= github.com/testcontainers/testcontainers-go v0.31.0/go.mod h1:D2lAoA0zUFiSY+eAflqK5mcUx/A5hrrORaEQrd0SefI= +github.com/testcontainers/testcontainers-go/modules/mysql v0.31.0 h1:790+S8ewZYCbG+o8IiFlZ8ZZ33XbNO6zV9qhU6xhlRk= +github.com/testcontainers/testcontainers-go/modules/mysql v0.31.0/go.mod h1:REFmO+lSG9S6uSBEwIMZCxeI36uhScjTwChYADeO3JA= github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0 h1:isAwFS3KNKRbJMbWv+wolWqOFUECmjYZ+sIRZCIBc/E= github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0/go.mod h1:ZNYY8vumNCEG9YI59A9d6/YaMY49uwRhmeU563EzFGw= github.com/things-go/go-socks5 v0.0.4 h1:jMQjIc+qhD4z9cITOMnBiwo9dDmpGuXmBlkRFrl/qD0= @@ -708,8 +716,8 @@ github.com/yuin/goldmark v1.7.1 h1:3bajkSilaCbjdKVsKdZjZCLBNPL9pYzrCakKaf4U49U= github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= -github.com/zcalusic/sysinfo v1.0.2 h1:nwTTo2a+WQ0NXwo0BGRojOJvJ/5XKvQih+2RrtWqfxc= -github.com/zcalusic/sysinfo v1.0.2/go.mod h1:kluzTYflRWo6/tXVMJPdEjShsbPpsFRyy+p1mBQPC30= +github.com/zcalusic/sysinfo v1.1.3 h1:u/AVENkuoikKuIZ4sUEJ6iibpmQP6YpGD8SSMCrqAF0= +github.com/zcalusic/sysinfo v1.1.3/go.mod h1:NX+qYnWGtJVPV0yWldff9uppNKU4h40hJIRPf/pGLv4= github.com/zeebo/assert v1.1.0 h1:hU1L1vLTHsnO8x8c9KAR5GmM5QscxHg5RNU5z5qbUWY= github.com/zeebo/assert v1.1.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= github.com/zeebo/blake3 v0.2.3 h1:TFoLXsjeXqRNFxSbk35Dk4YtszE/MQQGK10BH4ptoTg= @@ -774,8 +782,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= -golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= -golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= +golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= +golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -901,8 +909,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= -golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 
+golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -974,8 +982,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= -golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -983,8 +991,8 @@ golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= -golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24= -golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M= +golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q= +golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -999,8 +1007,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= -golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -1151,8 +1159,8 @@ google.golang.org/genproto v0.0.0-20210402141018-6c239bbf2bb1/go.mod h1:9lPAdzaE google.golang.org/genproto v0.0.0-20210602131652-f16073e35f0c/go.mod 
h1:UODoCrxHCcBojKKwX1terBiRUaqAsFqJiF615XL43r0= google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 h1:OpXbo8JnN8+jZGPrL4SSfaDjSCjupr8lXyBAbexEm/U= google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434/go.mod h1:FfiGhwUm6CJviekPrc0oJ+7h29e+DmWU6UtjX0ZvI7Y= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240515191416-fc5f0ca64291 h1:AgADTJarZTBqgjiUzRgfaBchgYB3/WFTC80GPwsMcRI= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240515191416-fc5f0ca64291/go.mod h1:EfXuqaE1J41VCDicxHzUDm+8rk+7ZdXzHV0IhO/I6s0= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 h1:pPJltXNxVzT4pK9yD8vR9X75DaWYYmLGMsEvBfFQZzQ= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= @@ -1189,8 +1197,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= -google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= +google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -1224,12 +1232,14 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo= +gorm.io/driver/mysql v1.5.7/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM= gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4chKpA= gorm.io/driver/sqlite v1.5.3 h1:7/0dUgX28KAcopdfbRWWl68Rflh6osa4rDh+m51KL2g= gorm.io/driver/sqlite v1.5.3/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4= -gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde h1:9DShaph9qhkIYw7QF91I/ynrr4cOO2PZra2PFD7Mfeg= -gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +gorm.io/gorm v1.25.7 h1:VsD6acwRjz2zFxGO50gPO6AkNs7KKnvfzUjHQhZDz/A= +gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY= gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 
h1:qDCwdCWECGnwQSQC01Dpnp09fRHxJs9PbktotUqG+hs= diff --git a/infrastructure_files/configure.sh b/infrastructure_files/configure.sh index ff33004b2..d02e4f40c 100755 --- a/infrastructure_files/configure.sh +++ b/infrastructure_files/configure.sh @@ -53,6 +53,18 @@ if [[ "$NETBIRD_STORE_CONFIG_ENGINE" == "postgres" ]]; then export NETBIRD_STORE_ENGINE_POSTGRES_DSN fi +# Check if MySQL is set as the store engine +if [[ "$NETBIRD_STORE_CONFIG_ENGINE" == "mysql" ]]; then + # Exit if 'NETBIRD_STORE_ENGINE_MYSQL_DSN' is not set + if [[ -z "$NETBIRD_STORE_ENGINE_MYSQL_DSN" ]]; then + echo "Warning: NETBIRD_STORE_CONFIG_ENGINE=mysql but NETBIRD_STORE_ENGINE_MYSQL_DSN is not set." + echo "Please add the following line to your setup.env file:" + echo 'NETBIRD_STORE_ENGINE_MYSQL_DSN=":@tcp(127.0.0.1:3306)/"' + exit 1 + fi + export NETBIRD_STORE_ENGINE_MYSQL_DSN +fi + # local development or tests if [[ $NETBIRD_DOMAIN == "localhost" || $NETBIRD_DOMAIN == "127.0.0.1" ]]; then export NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN="netbird.selfhosted" diff --git a/infrastructure_files/docker-compose.yml.tmpl b/infrastructure_files/docker-compose.yml.tmpl index ba68b3f8d..b7904fb5b 100644 --- a/infrastructure_files/docker-compose.yml.tmpl +++ b/infrastructure_files/docker-compose.yml.tmpl @@ -96,6 +96,7 @@ services: max-file: "2" environment: - NETBIRD_STORE_ENGINE_POSTGRES_DSN=$NETBIRD_STORE_ENGINE_POSTGRES_DSN + - NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN # Coturn coturn: diff --git a/infrastructure_files/docker-compose.yml.tmpl.traefik b/infrastructure_files/docker-compose.yml.tmpl.traefik index c4415d848..71471c3ef 100644 --- a/infrastructure_files/docker-compose.yml.tmpl.traefik +++ b/infrastructure_files/docker-compose.yml.tmpl.traefik @@ -50,6 +50,24 @@ services: - traefik.http.services.netbird-signal.loadbalancer.server.port=80 - traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c + # Relay + relay: + image: netbirdio/relay:$NETBIRD_RELAY_TAG + restart: unless-stopped + environment: + - NB_LOG_LEVEL=info + - NB_LISTEN_ADDRESS=:$NETBIRD_RELAY_PORT + - NB_EXPOSED_ADDRESS=$NETBIRD_RELAY_DOMAIN:$NETBIRD_RELAY_PORT + # todo: change to a secure secret + - NB_AUTH_SECRET=$NETBIRD_RELAY_AUTH_SECRET + ports: + - $NETBIRD_RELAY_PORT:$NETBIRD_RELAY_PORT + logging: + driver: "json-file" + options: + max-size: "500m" + max-file: "2" + # Management management: image: netbirdio/management:$NETBIRD_MANAGEMENT_TAG @@ -83,6 +101,7 @@ services: - traefik.http.services.netbird-management.loadbalancer.server.scheme=h2c environment: - NETBIRD_STORE_ENGINE_POSTGRES_DSN=$NETBIRD_STORE_ENGINE_POSTGRES_DSN + - NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN # Coturn coturn: diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index 0b2b65142..9b80058c2 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -530,7 +530,7 @@ renderCaddyfile() { { debug servers :80,:443 { - protocols h1 h2c + protocols h1 h2c h2 h3 } } @@ -788,6 +788,7 @@ services: networks: [ netbird ] ports: - '443:443' + - '443:443/udp' - '80:80' - '8080:8080' volumes: diff --git a/management/client/client_test.go b/management/client/client_test.go index 100b3fcaa..8bd8af8d2 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -11,6 +11,8 @@ import ( "github.com/stretchr/testify/require" 
"github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/client/system" @@ -57,7 +59,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { t.Fatal(err) } s := grpc.NewServer() - store, cleanUp, err := mgmt.NewTestStoreFromSQL(context.Background(), "../server/testdata/store.sql", t.TempDir()) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../server/testdata/store.sql", t.TempDir()) if err != nil { t.Fatal(err) } @@ -76,7 +78,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { } secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) - mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil) + mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil) if err != nil { t.Fatal(err) } diff --git a/management/cmd/management.go b/management/cmd/management.go index 719d1a78c..1c8fca8dc 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -41,11 +41,20 @@ import ( "github.com/netbirdio/netbird/management/server" nbContext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" - httpapi "github.com/netbirdio/netbird/management/server/http" + "github.com/netbirdio/netbird/management/server/groups" + nbhttp "github.com/netbirdio/netbird/management/server/http" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/metrics" + "github.com/netbirdio/netbird/management/server/networks" + "github.com/netbirdio/netbird/management/server/networks/resources" + "github.com/netbirdio/netbird/management/server/networks/routers" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/version" ) @@ -149,7 +158,7 @@ var ( if err != nil { return err } - store, err := server.NewStore(ctx, config.StoreConfig.Engine, config.Datadir, appMetrics) + store, err := store.NewStore(ctx, config.StoreConfig.Engine, config.Datadir, appMetrics) if err != nil { return fmt.Errorf("failed creating Store: %s: %v", config.Datadir, err) } @@ -257,14 +266,22 @@ var ( return fmt.Errorf("failed creating JWT validator: %v", err) } - httpAPIAuthCfg := httpapi.AuthCfg{ + httpAPIAuthCfg := configs.AuthCfg{ Issuer: config.HttpConfig.AuthIssuer, Audience: config.HttpConfig.AuthAudience, UserIDClaim: config.HttpConfig.AuthUserIDClaim, KeysLocation: config.HttpConfig.AuthKeysLocation, } - httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator) + userManager := users.NewManager(store) + settingsManager := settings.NewManager(store) + permissionsManager := permissions.NewManager(userManager, settingsManager) + groupsManager := 
groups.NewManager(store, permissionsManager, accountManager) + resourcesManager := resources.NewManager(store, permissionsManager, groupsManager, accountManager) + routersManager := routers.NewManager(store, permissionsManager, accountManager) + networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, accountManager) + + httpAPIHandler, err := nbhttp.NewAPIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator) if err != nil { return fmt.Errorf("failed creating HTTP API handler: %v", err) } @@ -273,7 +290,7 @@ var ( ephemeralManager.LoadInitialPeers(ctx) gRPCAPIHandler := grpc.NewServer(gRPCOpts...) - srv, err := server.NewServer(ctx, config, accountManager, peersUpdateManager, secretsManager, appMetrics, ephemeralManager) + srv, err := server.NewServer(ctx, config, accountManager, settingsManager, peersUpdateManager, secretsManager, appMetrics, ephemeralManager) if err != nil { return fmt.Errorf("failed creating gRPC API handler: %v", err) } @@ -399,7 +416,7 @@ func notifyStop(ctx context.Context, msg string) { } } -func getInstallationID(ctx context.Context, store server.Store) (string, error) { +func getInstallationID(ctx context.Context, store store.Store) (string, error) { installationID := store.GetInstallationID() if installationID != "" { return installationID, nil diff --git a/management/cmd/migration_up.go b/management/cmd/migration_up.go index 7aa11f0c9..183fc554d 100644 --- a/management/cmd/migration_up.go +++ b/management/cmd/migration_up.go @@ -9,7 +9,7 @@ import ( "github.com/spf13/cobra" "github.com/netbirdio/netbird/formatter" - "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/util" ) @@ -32,7 +32,7 @@ var upCmd = &cobra.Command{ //nolint ctx := context.WithValue(cmd.Context(), formatter.ExecutionContextKey, formatter.SystemSource) - if err := server.MigrateFileStoreToSqlite(ctx, mgmtDataDir); err != nil { + if err := store.MigrateFileStoreToSqlite(ctx, mgmtDataDir); err != nil { return err } log.WithContext(ctx).Info("Migration finished successfully") diff --git a/management/proto/management.pb.go b/management/proto/management.pb.go index 672b2a102..b4ff16e6d 100644 --- a/management/proto/management.pb.go +++ b/management/proto/management.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v4.23.4 +// protoc v4.24.3 // source: management.proto package proto @@ -29,6 +29,7 @@ const ( RuleProtocol_TCP RuleProtocol = 2 RuleProtocol_UDP RuleProtocol = 3 RuleProtocol_ICMP RuleProtocol = 4 + RuleProtocol_CUSTOM RuleProtocol = 5 ) // Enum value maps for RuleProtocol. @@ -39,6 +40,7 @@ var ( 2: "TCP", 3: "UDP", 4: "ICMP", + 5: "CUSTOM", } RuleProtocol_value = map[string]int32{ "UNKNOWN": 0, @@ -46,6 +48,7 @@ var ( "TCP": 2, "UDP": 3, "ICMP": 4, + "CUSTOM": 5, } ) @@ -1393,7 +1396,8 @@ type PeerConfig struct { // SSHConfig of the peer. 
SshConfig *SSHConfig `protobuf:"bytes,3,opt,name=sshConfig,proto3" json:"sshConfig,omitempty"` // Peer fully qualified domain name - Fqdn string `protobuf:"bytes,4,opt,name=fqdn,proto3" json:"fqdn,omitempty"` + Fqdn string `protobuf:"bytes,4,opt,name=fqdn,proto3" json:"fqdn,omitempty"` + RoutingPeerDnsResolutionEnabled bool `protobuf:"varint,5,opt,name=RoutingPeerDnsResolutionEnabled,proto3" json:"RoutingPeerDnsResolutionEnabled,omitempty"` } func (x *PeerConfig) Reset() { @@ -1456,6 +1460,13 @@ func (x *PeerConfig) GetFqdn() string { return "" } +func (x *PeerConfig) GetRoutingPeerDnsResolutionEnabled() bool { + if x != nil { + return x.RoutingPeerDnsResolutionEnabled + } + return false +} + // NetworkMap represents a network state of the peer with the corresponding configuration parameters to establish peer-to-peer connections type NetworkMap struct { state protoimpl.MessageState @@ -2780,6 +2791,10 @@ type RouteFirewallRule struct { PortInfo *PortInfo `protobuf:"bytes,5,opt,name=portInfo,proto3" json:"portInfo,omitempty"` // IsDynamic indicates if the route is a DNS route. IsDynamic bool `protobuf:"varint,6,opt,name=isDynamic,proto3" json:"isDynamic,omitempty"` + // Domains is a list of domains for which the rule is applicable. + Domains []string `protobuf:"bytes,7,rep,name=domains,proto3" json:"domains,omitempty"` + // CustomProtocol is a custom protocol ID. + CustomProtocol uint32 `protobuf:"varint,8,opt,name=customProtocol,proto3" json:"customProtocol,omitempty"` } func (x *RouteFirewallRule) Reset() { @@ -2856,6 +2871,20 @@ func (x *RouteFirewallRule) GetIsDynamic() bool { return false } +func (x *RouteFirewallRule) GetDomains() []string { + if x != nil { + return x.Domains + } + return nil +} + +func (x *RouteFirewallRule) GetCustomProtocol() uint32 { + if x != nil { + return x.CustomProtocol + } + return 0 +} + type PortInfo_Range struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -3075,7 +3104,7 @@ var file_management_proto_rawDesc = []byte{ 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x73, 0x65, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x75, 0x73, 0x65, 0x72, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, - 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x81, 0x01, 0x0a, 0x0a, 0x50, 0x65, 0x65, 0x72, + 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0xcb, 0x01, 0x0a, 0x0a, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x64, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x64, @@ -3083,250 +3112,260 @@ var file_management_proto_rawDesc = []byte{ 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0xf3, 0x04, 0x0a, 0x0a, - 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x53, 0x65, - 0x72, 0x69, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x53, 0x65, 0x72, 0x69, - 0x61, 0x6c, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x18, 0x02, 0x20, 0x01, 
0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, - 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x3e, 0x0a, 0x0b, 0x72, 0x65, - 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, - 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, - 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0b, 0x72, - 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, - 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, - 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, - 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x29, 0x0a, 0x06, 0x52, 0x6f, - 0x75, 0x74, 0x65, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x52, - 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, - 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x40, 0x0a, 0x0c, 0x6f, 0x66, - 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x0b, - 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, - 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0c, - 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x3e, 0x0a, 0x0d, - 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x08, 0x20, - 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0d, 0x46, - 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x32, 0x0a, 0x14, - 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, - 0x6d, 0x70, 0x74, 0x79, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x66, 0x69, 0x72, 0x65, - 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, - 0x12, 0x4f, 0x0a, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, - 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1d, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, - 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x13, 0x72, 0x6f, + 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x48, 0x0a, 0x1f, 0x52, + 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x50, 0x65, 0x65, 0x72, 0x44, 0x6e, 0x73, 0x52, 0x65, 0x73, + 0x6f, 0x6c, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x1f, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x50, 0x65, 0x65, + 0x72, 0x44, 0x6e, 0x73, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, + 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0xf3, 0x04, 
0x0a, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, + 0x6b, 0x4d, 0x61, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x12, 0x36, 0x0a, 0x0a, + 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, + 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x12, 0x3e, 0x0a, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, + 0x65, 0x72, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, + 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, + 0x65, 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, + 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, + 0x6d, 0x70, 0x74, 0x79, 0x12, 0x29, 0x0a, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x05, + 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, + 0x33, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x06, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x12, 0x40, 0x0a, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, + 0x65, 0x65, 0x72, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, + 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, + 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x3e, 0x0a, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, + 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, + 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, + 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, + 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x09, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, + 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x4f, 0x0a, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, - 0x73, 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, - 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, - 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, - 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, - 0x79, 0x22, 0x97, 0x01, 0x0a, 0x10, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 
0x50, 0x65, 0x65, 0x72, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, - 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, - 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, - 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, - 0x70, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, - 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0x49, 0x0a, 0x09, 0x53, - 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, - 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, - 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, - 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, - 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, - 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, - 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76, - 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x48, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, - 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, - 0x69, 0x64, 0x65, 0x72, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, - 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x22, 0x16, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, - 0x0a, 0x06, 0x48, 0x4f, 0x53, 0x54, 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, 0x1c, 0x50, 0x4b, - 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, - 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, 0x15, 0x50, 0x4b, - 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, - 0x6c, 0x6f, 0x77, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, - 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, - 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0xea, 0x02, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, - 
0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, - 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, - 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, - 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x43, 0x6c, - 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, - 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, - 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, - 0x0a, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, - 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, - 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, - 0x0a, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, - 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, - 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, - 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, - 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, 0x41, 0x75, - 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, - 0x69, 0x6e, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, - 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, - 0x12, 0x22, 0x0a, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, - 0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, - 0x55, 0x52, 0x4c, 0x73, 0x22, 0xed, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, - 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, - 0x0a, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, - 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, - 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, - 0x65, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, - 0x0a, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, - 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, - 0x72, 0x61, 0x64, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, - 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, - 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, - 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, - 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, - 0x75, 0x74, 0x65, 0x18, 
0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, - 0x6f, 0x75, 0x74, 0x65, 0x22, 0xb4, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, - 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, - 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65, - 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, - 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, - 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, - 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, - 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, - 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x22, 0x58, 0x0a, 0x0a, 0x43, - 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, - 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, - 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, - 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, - 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, - 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, - 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, - 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, - 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, - 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, - 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, - 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, - 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, - 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, - 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, - 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, - 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, - 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, - 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, - 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 
0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, - 0x64, 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, - 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, - 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, - 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xd9, 0x01, 0x0a, 0x0c, - 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, - 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, - 0x65, 0x72, 0x49, 0x50, 0x12, 0x37, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, - 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, - 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, - 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, - 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, - 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, - 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, - 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0x38, 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, - 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, - 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, - 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, - 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, - 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, - 0x73, 0x22, 0x96, 0x01, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, - 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, - 0x70, 0x6f, 0x72, 0x74, 0x12, 0x32, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, - 0x00, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x1a, 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, - 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, - 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, - 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x8f, 0x02, 0x0a, 0x11, 0x52, - 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, - 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 
0x6e, 0x67, 0x65, 0x73, - 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, - 0x6e, 0x67, 0x65, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, - 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, - 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, - 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, - 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, - 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, - 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, - 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, - 0x0a, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x2a, 0x40, 0x0a, 0x0c, - 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, - 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, - 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, - 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x2a, 0x20, - 0x0a, 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, - 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, - 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, - 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, - 0x4f, 0x50, 0x10, 0x01, 0x32, 0x90, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, - 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, - 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, - 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, - 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, + 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, + 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x3e, 0x0a, 0x1a, 0x72, + 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, + 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, + 
0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, + 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x97, 0x01, 0x0a, 0x10, + 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, + 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, + 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x12, 0x33, 0x0a, 0x09, + 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0x49, 0x0a, 0x09, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, + 0x65, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, + 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, + 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, + 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x48, + 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, + 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x65, + 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, + 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x52, 0x08, + 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, + 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, + 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, + 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x16, 0x0a, 0x08, + 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, 0x0a, 0x06, 0x48, 0x4f, 0x53, 0x54, + 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, 0x1c, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, + 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, 0x15, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, + 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x42, 0x0a, + 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x50, 
0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x22, 0xea, 0x02, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, + 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, + 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, + 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, + 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x65, 0x76, 0x69, + 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, + 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, 0x0a, 0x0d, 0x54, 0x6f, 0x6b, 0x65, + 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x14, + 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x53, + 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, + 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, + 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, + 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x09, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x22, 0x0a, 0x0c, 0x52, 0x65, + 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, + 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x22, 0xed, + 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x4e, 0x65, 0x74, 0x77, + 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, + 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, + 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, + 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65, 0x72, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a, 0x06, 0x4d, 0x65, 0x74, 0x72, + 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, + 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x18, 0x06, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, + 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 
0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, + 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, + 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x18, 0x09, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x22, 0xb4, + 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, + 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, + 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, + 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, + 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, + 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, + 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, + 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, + 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x22, 0x58, 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, + 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, + 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, + 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, + 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, + 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, + 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, + 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, + 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, + 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, + 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, + 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, + 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, + 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, + 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 
0x65, 0x61, 0x72, 0x63, + 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, + 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, + 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48, 0x0a, 0x0a, 0x4e, + 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, + 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, + 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, + 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xd9, 0x01, 0x0a, 0x0c, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, + 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x12, 0x37, + 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x0e, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, + 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, + 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, + 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, + 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, + 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, + 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, + 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, + 0x74, 0x22, 0x38, 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, + 0x65, 0x73, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, + 0x68, 0x65, 0x63, 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, + 0x20, 0x03, 0x28, 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a, 0x08, + 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x32, + 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, + 0x6e, 0x66, 0x6f, 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, 0x61, 0x6e, + 0x67, 0x65, 0x1a, 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, + 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, + 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, + 0x65, 0x6e, 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, + 
0x74, 0x69, 0x6f, 0x6e, 0x22, 0xd1, 0x02, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, + 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, + 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, + 0x52, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x2e, + 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, + 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, + 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, + 0x66, 0x6f, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, + 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, 0x79, + 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, + 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, + 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, + 0x12, 0x26, 0x0a, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, + 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2a, 0x4c, 0x0a, 0x0c, 0x52, 0x75, 0x6c, 0x65, + 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, + 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, + 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, + 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x12, 0x0a, 0x0a, 0x06, 0x43, 0x55, + 0x53, 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20, 0x0a, 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, + 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, + 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, + 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, + 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x32, 0x90, 0x04, 0x0a, + 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, + 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, + 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, - 0x4d, 0x65, 0x73, 0x73, 
0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, - 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, - 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, - 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, - 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, - 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, - 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, - 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, - 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, - 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, - 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, - 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, - 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, + 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, + 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, + 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, + 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, + 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, + 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, + 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, + 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, + 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, + 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 
0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, - 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, - 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, + 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, + 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, + 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, + 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, + 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, + 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, + 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x33, } var ( diff --git a/management/proto/management.proto b/management/proto/management.proto index fe6a828b1..5f4e0df46 100644 --- a/management/proto/management.proto +++ b/management/proto/management.proto @@ -222,6 +222,8 @@ message PeerConfig { SSHConfig sshConfig = 3; // Peer fully qualified domain name string fqdn = 4; + + bool RoutingPeerDnsResolutionEnabled = 5; } // NetworkMap represents a network state of the peer with the corresponding configuration parameters to establish peer-to-peer connections @@ -396,6 +398,7 @@ enum RuleProtocol { TCP = 2; UDP = 3; ICMP = 4; + CUSTOM = 5; } enum RuleDirection { @@ -459,5 +462,11 @@ message RouteFirewallRule { // IsDynamic indicates if the route is a DNS route. bool isDynamic = 6; + + // Domains is a list of domains for which the rule is applicable. + repeated string domains = 7; + + // CustomProtocol is a custom protocol ID. 
+ uint32 customProtocol = 8; } diff --git a/management/server/account.go b/management/server/account.go index 59c9c7fb0..5c23dba04 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -19,8 +19,6 @@ import ( "github.com/eko/gocache/v3/cache" cacheStore "github.com/eko/gocache/v3/store" - "github.com/hashicorp/go-multierror" - "github.com/miekg/dns" gocache "github.com/patrickmn/go-cache" "github.com/rs/xid" log "github.com/sirupsen/logrus" @@ -29,31 +27,27 @@ import ( "github.com/netbirdio/netbird/base62" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" - "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/geolocation" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integrated_validator" - "github.com/netbirdio/netbird/management/server/integration_reference" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" ) const ( - PublicCategory = "public" - PrivateCategory = "private" - UnknownCategory = "unknown" - CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days - CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days - DefaultPeerLoginExpiration = 24 * time.Hour - DefaultPeerInactivityExpiration = 10 * time.Minute - emptyUserID = "empty user ID in claims" - errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" + CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days + CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days + peerSchedulerRetryInterval = 3 * time.Second + emptyUserID = "empty user ID in claims" + errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" ) type userLoggedInOnce bool @@ -66,57 +60,57 @@ func cacheEntryExpiration() time.Duration { } type AccountManager interface { - GetOrCreateAccountByUser(ctx context.Context, userId, domain string) (*Account, error) - GetAccount(ctx context.Context, accountID string) (*Account, error) - CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration, - autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) - SaveSetupKey(ctx context.Context, accountID string, key *SetupKey, userID string) (*SetupKey, error) - CreateUser(ctx context.Context, accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error) + GetOrCreateAccountByUser(ctx context.Context, userId, domain string) (*types.Account, error) + GetAccount(ctx context.Context, accountID string) (*types.Account, error) + CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType types.SetupKeyType, expiresIn time.Duration, + autoGroups []string, usageLimit int, userID string, ephemeral bool) (*types.SetupKey, error) + SaveSetupKey(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error) + CreateUser(ctx context.Context, accountID, 
initiatorUserID string, key *types.UserInfo) (*types.UserInfo, error) DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error - ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) - SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error) - SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) - SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) - GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) - GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) + ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) + SaveUser(ctx context.Context, accountID, initiatorUserID string, update *types.User) (*types.UserInfo, error) + SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *types.User, addIfNotExists bool) (*types.UserInfo, error) + SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*types.User, addIfNotExists bool) ([]*types.UserInfo, error) + GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) + GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) AccountExists(ctx context.Context, accountID string) (bool, error) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error - GetAccountInfoFromPAT(ctx context.Context, token string) (*User, *PersonalAccessToken, string, string, error) + GetAccountInfoFromPAT(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, string, string, error) DeleteAccount(ctx context.Context, accountID, userID string) error MarkPATUsed(ctx context.Context, tokenID string) error - GetUserByID(ctx context.Context, id string) (*User, error) - GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) - ListUsers(ctx context.Context, accountID string) ([]*User, error) + GetUserByID(ctx context.Context, id string) (*types.User, error) + GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) + ListUsers(ctx context.Context, accountID string) ([]*types.User, error) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error DeletePeer(ctx context.Context, accountID, peerID, userID string) error UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) - GetNetworkMap(ctx context.Context, peerID string) (*NetworkMap, error) - GetPeerNetwork(ctx context.Context, peerID string) (*Network, error) - AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) - CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, 
expiresIn int) (*PersonalAccessTokenGenerated, error) + GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) + GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error) + AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error - GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) - GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) - GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error) - GetGroup(ctx context.Context, accountId, groupID, userID string) (*nbgroup.Group, error) - GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) - GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) - SaveGroup(ctx context.Context, accountID, userID string, group *nbgroup.Group) error - SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error + GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error) + GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error) + GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*types.UserInfo, error) + GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error) + GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) + GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) + SaveGroup(ctx context.Context, accountID, userID string, group *types.Group) error + SaveGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error DeleteGroup(ctx context.Context, accountId, userId, groupID string) error DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error - GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) - GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) - SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error) + GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*types.Group, error) + GetPolicy(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) + SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error - ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) + ListPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, 
networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) SaveRoute(ctx context.Context, accountID, userID string, route *route.Route) error @@ -130,12 +124,12 @@ type AccountManager interface { GetDNSDomain() string StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) - GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) - SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error + GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) + SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) - UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error) - LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API - SyncPeer(ctx context.Context, sync PeerSync, accountID string) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API + UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) + LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API + SyncPeer(ctx context.Context, sync PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API GetAllConnectedPeers() (map[string]struct{}, error) HasConnectedChannel(peerID string) bool GetExternalCacheManager() ExternalCacheManager @@ -147,17 +141,18 @@ type AccountManager interface { UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) - SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) + SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) - GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error) + GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error + UpdateAccountPeers(ctx context.Context, accountID string) } type DefaultAccountManager struct { - Store Store + Store store.Store // cacheMux and cacheLoading helps to 
make sure that only a single cache reload runs at a time per accountID cacheMux sync.Mutex // cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded @@ -168,7 +163,7 @@ type DefaultAccountManager struct { externalCacheManager ExternalCacheManager ctx context.Context eventStore activity.Store - geo *geolocation.Geolocation + geo geolocation.Geolocation requestBuffer *AccountRequestBuffer @@ -192,763 +187,40 @@ type DefaultAccountManager struct { metrics telemetry.AppMetrics } -// Settings represents Account settings structure that can be modified via API and Dashboard -type Settings struct { - // PeerLoginExpirationEnabled globally enables or disables peer login expiration - PeerLoginExpirationEnabled bool - - // PeerLoginExpiration is a setting that indicates when peer login expires. - // Applies to all peers that have Peer.LoginExpirationEnabled set to true. - PeerLoginExpiration time.Duration - - // PeerInactivityExpirationEnabled globally enables or disables peer inactivity expiration - PeerInactivityExpirationEnabled bool - - // PeerInactivityExpiration is a setting that indicates when peer inactivity expires. - // Applies to all peers that have Peer.PeerInactivityExpirationEnabled set to true. - PeerInactivityExpiration time.Duration - - // RegularUsersViewBlocked allows to block regular users from viewing even their own peers and some UI elements - RegularUsersViewBlocked bool - - // GroupsPropagationEnabled allows to propagate auto groups from the user to the peer - GroupsPropagationEnabled bool - - // JWTGroupsEnabled allows extract groups from JWT claim, which name defined in the JWTGroupsClaimName - // and add it to account groups. - JWTGroupsEnabled bool - - // JWTGroupsClaimName from which we extract groups name to add it to account groups - JWTGroupsClaimName string - - // JWTAllowGroups list of groups to which users are allowed access - JWTAllowGroups []string `gorm:"serializer:json"` - - // Extra is a dictionary of Account settings - Extra *account.ExtraSettings `gorm:"embedded;embeddedPrefix:extra_"` -} - -// Copy copies the Settings struct -func (s *Settings) Copy() *Settings { - settings := &Settings{ - PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled, - PeerLoginExpiration: s.PeerLoginExpiration, - JWTGroupsEnabled: s.JWTGroupsEnabled, - JWTGroupsClaimName: s.JWTGroupsClaimName, - GroupsPropagationEnabled: s.GroupsPropagationEnabled, - JWTAllowGroups: s.JWTAllowGroups, - RegularUsersViewBlocked: s.RegularUsersViewBlocked, - - PeerInactivityExpirationEnabled: s.PeerInactivityExpirationEnabled, - PeerInactivityExpiration: s.PeerInactivityExpiration, - } - if s.Extra != nil { - settings.Extra = s.Extra.Copy() - } - return settings -} - -// Account represents a unique account of the system -type Account struct { - // we have to name column to aid as it collides with Network.Id when work with associations - Id string `gorm:"primaryKey"` - - // User.Id it was created by - CreatedBy string - CreatedAt time.Time - Domain string `gorm:"index"` - DomainCategory string - IsDomainPrimaryAccount bool - SetupKeys map[string]*SetupKey `gorm:"-"` - SetupKeysG []SetupKey `json:"-" gorm:"foreignKey:AccountID;references:id"` - Network *Network `gorm:"embedded;embeddedPrefix:network_"` - Peers map[string]*nbpeer.Peer `gorm:"-"` - PeersG []nbpeer.Peer `json:"-" gorm:"foreignKey:AccountID;references:id"` - Users map[string]*User `gorm:"-"` - UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"` - Groups 
map[string]*nbgroup.Group `gorm:"-"` - GroupsG []nbgroup.Group `json:"-" gorm:"foreignKey:AccountID;references:id"` - Policies []*Policy `gorm:"foreignKey:AccountID;references:id"` - Routes map[route.ID]*route.Route `gorm:"-"` - RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"` - NameServerGroups map[string]*nbdns.NameServerGroup `gorm:"-"` - NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"` - DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"` - PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"` - // Settings is a dictionary of Account settings - Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` -} - -// Subclass used in gorm to only load settings and not whole account -type AccountSettings struct { - Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` -} - -// Subclass used in gorm to only load network and not whole account -type AccountNetwork struct { - Network *Network `gorm:"embedded;embeddedPrefix:network_"` -} - -// AccountDNSSettings used in gorm to only load dns settings and not whole account -type AccountDNSSettings struct { - DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"` -} - -type UserPermissions struct { - DashboardView string `json:"dashboard_view"` -} - -type UserInfo struct { - ID string `json:"id"` - Email string `json:"email"` - Name string `json:"name"` - Role string `json:"role"` - AutoGroups []string `json:"auto_groups"` - Status string `json:"-"` - IsServiceUser bool `json:"is_service_user"` - IsBlocked bool `json:"is_blocked"` - NonDeletable bool `json:"non_deletable"` - LastLogin time.Time `json:"last_login"` - Issued string `json:"issued"` - IntegrationReference integration_reference.IntegrationReference `json:"-"` - Permissions UserPermissions `json:"permissions"` -} - -// getRoutesToSync returns the enabled routes for the peer ID and the routes -// from the ACL peers that have distribution groups associated with the peer ID. -// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. -func (a *Account) getRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer) []*route.Route { - routes, peerDisabledRoutes := a.getRoutingPeerRoutes(ctx, peerID) - peerRoutesMembership := make(lookupMap) - for _, r := range append(routes, peerDisabledRoutes...) { - peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{} - } - - groupListMap := a.getPeerGroups(peerID) - for _, peer := range aclPeers { - activeRoutes, _ := a.getRoutingPeerRoutes(ctx, peer.ID) - groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, groupListMap) - filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership) - routes = append(routes, filteredRoutes...) 
- } - - return routes -} - -// filterRoutesFromPeersOfSameHAGroup filters and returns a list of routes that don't share the same HA route membership -func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships lookupMap) []*route.Route { - var filteredRoutes []*route.Route - for _, r := range routes { - _, found := peerMemberships[string(r.GetHAUniqueID())] - if !found { - filteredRoutes = append(filteredRoutes, r) - } - } - return filteredRoutes -} - -// filterRoutesByGroups returns a list with routes that have distribution groups in the group's map -func (a *Account) filterRoutesByGroups(routes []*route.Route, groupListMap lookupMap) []*route.Route { - var filteredRoutes []*route.Route - for _, r := range routes { - for _, groupID := range r.Groups { - _, found := groupListMap[groupID] - if found { - filteredRoutes = append(filteredRoutes, r) - break - } - } - } - return filteredRoutes -} - -// getRoutingPeerRoutes returns the enabled and disabled lists of routes that the given routing peer serves -// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. -// If the given is not a routing peer, then the lists are empty. -func (a *Account) getRoutingPeerRoutes(ctx context.Context, peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) { - - peer := a.GetPeer(peerID) - if peer == nil { - log.WithContext(ctx).Errorf("peer %s that doesn't exist under account %s", peerID, a.Id) - return enabledRoutes, disabledRoutes - } - - // currently we support only linux routing peers - if peer.Meta.GoOS != "linux" { - return enabledRoutes, disabledRoutes - } - - seenRoute := make(map[route.ID]struct{}) - - takeRoute := func(r *route.Route, id string) { - if _, ok := seenRoute[r.ID]; ok { - return - } - seenRoute[r.ID] = struct{}{} - - if r.Enabled { - r.Peer = peer.Key - enabledRoutes = append(enabledRoutes, r) - return - } - disabledRoutes = append(disabledRoutes, r) - } - - for _, r := range a.Routes { - for _, groupID := range r.PeerGroups { - group := a.GetGroup(groupID) - if group == nil { - log.WithContext(ctx).Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id) - continue - } - for _, id := range group.Peers { - if id != peerID { - continue - } - - newPeerRoute := r.Copy() - newPeerRoute.Peer = id - newPeerRoute.PeerGroups = nil - newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) // we have to provide unique route id when distribute network map - takeRoute(newPeerRoute, id) - break - } - } - if r.Peer == peerID { - takeRoute(r.Copy(), peerID) - } - } - - return enabledRoutes, disabledRoutes -} - -// GetRoutesByPrefixOrDomains return list of routes by account and route prefix -func (a *Account) GetRoutesByPrefixOrDomains(prefix netip.Prefix, domains domain.List) []*route.Route { - var routes []*route.Route - for _, r := range a.Routes { - dynamic := r.IsDynamic() - if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() || - !dynamic && r.Network.String() == prefix.String() { - routes = append(routes, r) - } - } - - return routes -} - -// GetGroup returns a group by ID if exists, nil otherwise -func (a *Account) GetGroup(groupID string) *nbgroup.Group { - return a.Groups[groupID] -} - -// GetPeerNetworkMap returns the networkmap for the given peer ID. 
-func (a *Account) GetPeerNetworkMap( - ctx context.Context, - peerID string, - peersCustomZone nbdns.CustomZone, - validatedPeersMap map[string]struct{}, - metrics *telemetry.AccountManagerMetrics, -) *NetworkMap { - start := time.Now() - - peer := a.Peers[peerID] - if peer == nil { - return &NetworkMap{ - Network: a.Network.Copy(), - } - } - - if _, ok := validatedPeersMap[peerID]; !ok { - return &NetworkMap{ - Network: a.Network.Copy(), - } - } - - aclPeers, firewallRules := a.getPeerConnectionResources(ctx, peerID, validatedPeersMap) - // exclude expired peers - var peersToConnect []*nbpeer.Peer - var expiredPeers []*nbpeer.Peer - for _, p := range aclPeers { - expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration) - if a.Settings.PeerLoginExpirationEnabled && expired { - expiredPeers = append(expiredPeers, p) - continue - } - peersToConnect = append(peersToConnect, p) - } - - routesUpdate := a.getRoutesToSync(ctx, peerID, peersToConnect) - routesFirewallRules := a.getPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap) - - dnsManagementStatus := a.getPeerDNSManagementStatus(peerID) - dnsUpdate := nbdns.Config{ - ServiceEnable: dnsManagementStatus, - } - - if dnsManagementStatus { - var zones []nbdns.CustomZone - - if peersCustomZone.Domain != "" { - zones = append(zones, peersCustomZone) - } - dnsUpdate.CustomZones = zones - dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID) - } - - nm := &NetworkMap{ - Peers: peersToConnect, - Network: a.Network.Copy(), - Routes: routesUpdate, - DNSConfig: dnsUpdate, - OfflinePeers: expiredPeers, - FirewallRules: firewallRules, - RoutesFirewallRules: routesFirewallRules, - } - - if metrics != nil { - objectCount := int64(len(peersToConnect) + len(expiredPeers) + len(routesUpdate) + len(firewallRules)) - metrics.CountNetworkMapObjects(objectCount) - metrics.CountGetPeerNetworkMapDuration(time.Since(start)) - - if objectCount > 5000 { - log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects, "+ - "peers to connect: %d, expired peers: %d, routes: %d, firewall rules: %d", - a.Id, objectCount, len(peersToConnect), len(expiredPeers), len(routesUpdate), len(firewallRules)) - } - } - - return nm -} - -func (a *Account) GetPeersCustomZone(ctx context.Context, dnsDomain string) nbdns.CustomZone { - var merr *multierror.Error - - if dnsDomain == "" { - log.WithContext(ctx).Error("no dns domain is set, returning empty zone") - return nbdns.CustomZone{} - } - - customZone := nbdns.CustomZone{ - Domain: dns.Fqdn(dnsDomain), - Records: make([]nbdns.SimpleRecord, 0, len(a.Peers)), - } - - domainSuffix := "." 
+ dnsDomain - - var sb strings.Builder - for _, peer := range a.Peers { - if peer.DNSLabel == "" { - merr = multierror.Append(merr, fmt.Errorf("peer %s has an empty DNS label", peer.Name)) - continue - } - - sb.Grow(len(peer.DNSLabel) + len(domainSuffix)) - sb.WriteString(peer.DNSLabel) - sb.WriteString(domainSuffix) - - customZone.Records = append(customZone.Records, nbdns.SimpleRecord{ - Name: sb.String(), - Type: int(dns.TypeA), - Class: nbdns.DefaultClass, - TTL: defaultTTL, - RData: peer.IP.String(), - }) - - sb.Reset() - } - - go func() { - if merr != nil { - log.WithContext(ctx).Errorf("error generating custom zone for account %s: %v", a.Id, merr) - } - }() - - return customZone -} - -// GetExpiredPeers returns peers that have been expired -func (a *Account) GetExpiredPeers() []*nbpeer.Peer { - var peers []*nbpeer.Peer - for _, peer := range a.GetPeersWithExpiration() { - expired, _ := peer.LoginExpired(a.Settings.PeerLoginExpiration) - if expired { - peers = append(peers, peer) - } - } - - return peers -} - -// GetNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. -// If there is no peer that expires this function returns false and a duration of 0. -// This function only considers peers that haven't been expired yet and that are connected. -func (a *Account) GetNextPeerExpiration() (time.Duration, bool) { - peersWithExpiry := a.GetPeersWithExpiration() - if len(peersWithExpiry) == 0 { - return 0, false - } - var nextExpiry *time.Duration - for _, peer := range peersWithExpiry { - // consider only connected peers because others will require login on connecting to the management server - if peer.Status.LoginExpired || !peer.Status.Connected { - continue - } - _, duration := peer.LoginExpired(a.Settings.PeerLoginExpiration) - if nextExpiry == nil || duration < *nextExpiry { - // if expiration is below 1s return 1s duration - // this avoids issues with ticker that can't be set to < 0 - if duration < time.Second { - return time.Second, true - } - nextExpiry = &duration - } - } - - if nextExpiry == nil { - return 0, false - } - - return *nextExpiry, true -} - -// GetPeersWithExpiration returns a list of peers that have Peer.LoginExpirationEnabled set to true and that were added by a user -func (a *Account) GetPeersWithExpiration() []*nbpeer.Peer { - peers := make([]*nbpeer.Peer, 0) - for _, peer := range a.Peers { - if peer.LoginExpirationEnabled && peer.AddedWithSSOLogin() { - peers = append(peers, peer) - } - } - return peers -} - -// GetInactivePeers returns peers that have been expired by inactivity -func (a *Account) GetInactivePeers() []*nbpeer.Peer { - var peers []*nbpeer.Peer - for _, inactivePeer := range a.GetPeersWithInactivity() { - inactive, _ := inactivePeer.SessionExpired(a.Settings.PeerInactivityExpiration) - if inactive { - peers = append(peers, inactivePeer) - } - } - return peers -} - -// GetNextInactivePeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. -// If there is no peer that expires this function returns false and a duration of 0. -// This function only considers peers that haven't been expired yet and that are not connected. 
-func (a *Account) GetNextInactivePeerExpiration() (time.Duration, bool) { - peersWithExpiry := a.GetPeersWithInactivity() - if len(peersWithExpiry) == 0 { - return 0, false - } - var nextExpiry *time.Duration - for _, peer := range peersWithExpiry { - if peer.Status.LoginExpired || peer.Status.Connected { - continue - } - _, duration := peer.SessionExpired(a.Settings.PeerInactivityExpiration) - if nextExpiry == nil || duration < *nextExpiry { - // if expiration is below 1s return 1s duration - // this avoids issues with ticker that can't be set to < 0 - if duration < time.Second { - return time.Second, true - } - nextExpiry = &duration - } - } - - if nextExpiry == nil { - return 0, false - } - - return *nextExpiry, true -} - -// GetPeersWithInactivity eturns a list of peers that have Peer.InactivityExpirationEnabled set to true and that were added by a user -func (a *Account) GetPeersWithInactivity() []*nbpeer.Peer { - peers := make([]*nbpeer.Peer, 0) - for _, peer := range a.Peers { - if peer.InactivityExpirationEnabled && peer.AddedWithSSOLogin() { - peers = append(peers, peer) - } - } - return peers -} - -// GetPeers returns a list of all Account peers -func (a *Account) GetPeers() []*nbpeer.Peer { - var peers []*nbpeer.Peer - for _, peer := range a.Peers { - peers = append(peers, peer) - } - return peers -} - -// UpdateSettings saves new account settings -func (a *Account) UpdateSettings(update *Settings) *Account { - a.Settings = update.Copy() - return a -} - -// UpdatePeer saves new or replaces existing peer -func (a *Account) UpdatePeer(update *nbpeer.Peer) { - a.Peers[update.ID] = update -} - -// DeletePeer deletes peer from the account cleaning up all the references -func (a *Account) DeletePeer(peerID string) { - // delete peer from groups - for _, g := range a.Groups { - for i, pk := range g.Peers { - if pk == peerID { - g.Peers = append(g.Peers[:i], g.Peers[i+1:]...) - break - } - } - } - - for _, r := range a.Routes { - if r.Peer == peerID { - r.Enabled = false - r.Peer = "" - } - } - - delete(a.Peers, peerID) - a.Network.IncSerial() -} - -// FindPeerByPubKey looks for a Peer by provided WireGuard public key in the Account or returns error if it wasn't found. -// It will return an object copy of the peer. -func (a *Account) FindPeerByPubKey(peerPubKey string) (*nbpeer.Peer, error) { - for _, peer := range a.Peers { - if peer.Key == peerPubKey { - return peer.Copy(), nil - } - } - - return nil, status.Errorf(status.NotFound, "peer with the public key %s not found", peerPubKey) -} - -// FindUserPeers returns a list of peers that user owns (created) -func (a *Account) FindUserPeers(userID string) ([]*nbpeer.Peer, error) { - peers := make([]*nbpeer.Peer, 0) - for _, peer := range a.Peers { - if peer.UserID == userID { - peers = append(peers, peer) - } - } - - return peers, nil -} - -// FindUser looks for a given user in the Account or returns error if user wasn't found. -func (a *Account) FindUser(userID string) (*User, error) { - user := a.Users[userID] - if user == nil { - return nil, status.Errorf(status.NotFound, "user %s not found", userID) - } - - return user, nil -} - -// FindGroupByName looks for a given group in the Account by name or returns error if the group wasn't found. 
-func (a *Account) FindGroupByName(groupName string) (*nbgroup.Group, error) { - for _, group := range a.Groups { - if group.Name == groupName { - return group, nil - } - } - return nil, status.Errorf(status.NotFound, "group %s not found", groupName) -} - -// FindSetupKey looks for a given SetupKey in the Account or returns error if it wasn't found. -func (a *Account) FindSetupKey(setupKey string) (*SetupKey, error) { - key := a.SetupKeys[setupKey] - if key == nil { - return nil, status.Errorf(status.NotFound, "setup key not found") - } - - return key, nil -} - -// GetPeerGroupsList return with the list of groups ID. -func (a *Account) GetPeerGroupsList(peerID string) []string { - var grps []string - for groupID, group := range a.Groups { - for _, id := range group.Peers { - if id == peerID { - grps = append(grps, groupID) - break - } - } - } - return grps -} - -func (a *Account) getPeerDNSManagementStatus(peerID string) bool { - peerGroups := a.getPeerGroups(peerID) - enabled := true - for _, groupID := range a.DNSSettings.DisabledManagementGroups { - _, found := peerGroups[groupID] - if found { - enabled = false - break - } - } - return enabled -} - -func (a *Account) getPeerGroups(peerID string) lookupMap { - groupList := make(lookupMap) - for groupID, group := range a.Groups { - for _, id := range group.Peers { - if id == peerID { - groupList[groupID] = struct{}{} - break - } - } - } - return groupList -} - -func (a *Account) getTakenIPs() []net.IP { - var takenIps []net.IP - for _, existingPeer := range a.Peers { - takenIps = append(takenIps, existingPeer.IP) - } - - return takenIps -} - -func (a *Account) getPeerDNSLabels() lookupMap { - existingLabels := make(lookupMap) - for _, peer := range a.Peers { - if peer.DNSLabel != "" { - existingLabels[peer.DNSLabel] = struct{}{} - } - } - return existingLabels -} - -func (a *Account) Copy() *Account { - peers := map[string]*nbpeer.Peer{} - for id, peer := range a.Peers { - peers[id] = peer.Copy() - } - - users := map[string]*User{} - for id, user := range a.Users { - users[id] = user.Copy() - } - - setupKeys := map[string]*SetupKey{} - for id, key := range a.SetupKeys { - setupKeys[id] = key.Copy() - } - - groups := map[string]*nbgroup.Group{} - for id, group := range a.Groups { - groups[id] = group.Copy() - } - - policies := []*Policy{} - for _, policy := range a.Policies { - policies = append(policies, policy.Copy()) - } - - routes := map[route.ID]*route.Route{} - for id, r := range a.Routes { - routes[id] = r.Copy() - } - - nsGroups := map[string]*nbdns.NameServerGroup{} - for id, nsGroup := range a.NameServerGroups { - nsGroups[id] = nsGroup.Copy() - } - - dnsSettings := a.DNSSettings.Copy() - - var settings *Settings - if a.Settings != nil { - settings = a.Settings.Copy() - } - - postureChecks := []*posture.Checks{} - for _, postureCheck := range a.PostureChecks { - postureChecks = append(postureChecks, postureCheck.Copy()) - } - - return &Account{ - Id: a.Id, - CreatedBy: a.CreatedBy, - CreatedAt: a.CreatedAt, - Domain: a.Domain, - DomainCategory: a.DomainCategory, - IsDomainPrimaryAccount: a.IsDomainPrimaryAccount, - SetupKeys: setupKeys, - Network: a.Network.Copy(), - Peers: peers, - Users: users, - Groups: groups, - Policies: policies, - Routes: routes, - NameServerGroups: nsGroups, - DNSSettings: dnsSettings, - PostureChecks: postureChecks, - Settings: settings, - } -} - -func (a *Account) GetGroupAll() (*nbgroup.Group, error) { - for _, g := range a.Groups { - if g.Name == "All" { - return g, nil - } - } - return nil, 
fmt.Errorf("no group ALL found") -} - -// GetPeer looks up a Peer by ID -func (a *Account) GetPeer(peerID string) *nbpeer.Peer { - return a.Peers[peerID] -} - // getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups. // Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups, // newly groups to create and an error if any occurred. -func (am *DefaultAccountManager) getJWTGroupsChanges(user *User, groups []*nbgroup.Group, groupNames []string) (bool, []string, []*nbgroup.Group, error) { - existedGroupsByName := make(map[string]*nbgroup.Group) +func (am *DefaultAccountManager) getJWTGroupsChanges(user *types.User, groups []*types.Group, groupNames []string) (bool, []string, []*types.Group, error) { + existedGroupsByName := make(map[string]*types.Group) for _, group := range groups { existedGroupsByName[group.Name] = group } newUserAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, groups) - groupsToAdd := difference(groupNames, maps.Keys(jwtGroupsMap)) - groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupNames) + groupsToAdd := util.Difference(groupNames, maps.Keys(jwtGroupsMap)) + groupsToRemove := util.Difference(maps.Keys(jwtGroupsMap), groupNames) // If no groups are added or removed, we should not sync account if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { return false, nil, nil, nil } - newGroupsToCreate := make([]*nbgroup.Group, 0) + newGroupsToCreate := make([]*types.Group, 0) var modified bool for _, name := range groupsToAdd { group, exists := existedGroupsByName[name] if !exists { - group = &nbgroup.Group{ + group = &types.Group{ ID: xid.New().String(), AccountID: user.AccountID, Name: name, - Issued: nbgroup.GroupIssuedJWT, + Issued: types.GroupIssuedJWT, } newGroupsToCreate = append(newGroupsToCreate, group) } - if group.Issued == nbgroup.GroupIssuedJWT { + if group.Issued == types.GroupIssuedJWT { newUserAutoGroups = append(newUserAutoGroups, group.ID) modified = true } @@ -965,68 +237,16 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *User, groups []*nbgro return modified, newUserAutoGroups, newGroupsToCreate, nil } -// UserGroupsAddToPeers adds groups to all peers of user -func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) { - userPeers := make(map[string]struct{}) - for pid, peer := range a.Peers { - if peer.UserID == userID { - userPeers[pid] = struct{}{} - } - } - - for _, gid := range groups { - group, ok := a.Groups[gid] - if !ok { - continue - } - - groupPeers := make(map[string]struct{}) - for _, pid := range group.Peers { - groupPeers[pid] = struct{}{} - } - - for pid := range userPeers { - groupPeers[pid] = struct{}{} - } - - group.Peers = group.Peers[:0] - for pid := range groupPeers { - group.Peers = append(group.Peers, pid) - } - } -} - -// UserGroupsRemoveFromPeers removes groups from all peers of user -func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) { - for _, gid := range groups { - group, ok := a.Groups[gid] - if !ok || group.Name == "All" { - continue - } - update := make([]string, 0, len(group.Peers)) - for _, pid := range group.Peers { - peer, ok := a.Peers[pid] - if !ok { - continue - } - if peer.UserID != userID { - update = append(update, pid) - } - } - group.Peers = update - } -} - // BuildManager creates a new DefaultAccountManager with a provided Store func BuildManager( ctx context.Context, - store Store, + store store.Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager, 
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, - geo *geolocation.Geolocation, + geo geolocation.Geolocation, userDeleteFromIDPEnabled bool, integratedPeerValidator integrated_validator.IntegratedValidator, metrics telemetry.AppMetrics, @@ -1122,7 +342,7 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager { // Only users with role UserRoleAdmin can update the account. // User that performs the update has to belong to the account. // Returns an updated Account -func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error) { +func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) { halfYearLimit := 180 * 24 * time.Hour if newSettings.PeerLoginExpiration > halfYearLimit { return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days") @@ -1171,11 +391,27 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco am.checkAndSchedulePeerLoginExpiration(ctx, accountID) } + updateAccountPeers := false + if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled { + if newSettings.RoutingPeerDNSResolutionEnabled { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountRoutingPeerDNSResolutionEnabled, nil) + } else { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountRoutingPeerDNSResolutionDisabled, nil) + } + updateAccountPeers = true + account.Network.Serial++ + } + err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID) if err != nil { return nil, err } + err = am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID) + if err != nil { + return nil, fmt.Errorf("groups propagation failed: %w", err) + } + updatedAccount := account.UpdateSettings(newSettings) err = am.Store.SaveAccount(ctx, account) @@ -1183,24 +419,45 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return nil, err } + if updateAccountPeers { + go am.UpdateAccountPeers(ctx, accountID) + } + return updatedAccount, nil } -func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) error { - if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled { - event := activity.AccountPeerInactivityExpirationEnabled - if !newSettings.PeerInactivityExpirationEnabled { - event = activity.AccountPeerInactivityExpirationDisabled - am.peerInactivityExpiry.Cancel(ctx, []string{accountID}) +func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error { + if oldSettings.GroupsPropagationEnabled != newSettings.GroupsPropagationEnabled { + if newSettings.GroupsPropagationEnabled { + am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationEnabled, nil) + // Todo: retroactively add user groups to all peers } else { - am.checkAndSchedulePeerInactivityExpiration(ctx, accountID) + am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationDisabled, nil) } - am.StoreEvent(ctx, userID, accountID, accountID, event, nil) } - if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration { - am.StoreEvent(ctx, userID, accountID, 
accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil) - am.checkAndSchedulePeerInactivityExpiration(ctx, accountID) + return nil +} + +func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error { + if newSettings.PeerInactivityExpirationEnabled { + if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration { + oldSettings.PeerInactivityExpiration = newSettings.PeerInactivityExpiration + + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil) + am.checkAndSchedulePeerInactivityExpiration(ctx, accountID) + } + } else { + if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled { + event := activity.AccountPeerInactivityExpirationEnabled + if !newSettings.PeerInactivityExpirationEnabled { + event = activity.AccountPeerInactivityExpirationDisabled + am.peerInactivityExpiry.Cancel(ctx, []string{accountID}) + } else { + am.checkAndSchedulePeerInactivityExpiration(ctx, accountID) + } + am.StoreEvent(ctx, userID, accountID, accountID, event, nil) + } } return nil @@ -1208,9 +465,12 @@ func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context. func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { return func() (time.Duration, bool) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + expiredPeers, err := am.getExpiredPeers(ctx, accountID) if err != nil { - return 0, false + return peerSchedulerRetryInterval, true } var peerIDs []string @@ -1222,7 +482,7 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc if err := am.expireAndUpdatePeers(ctx, accountID, expiredPeers); err != nil { log.WithContext(ctx).Errorf("failed updating account peers while expiring peers for account %s", accountID) - return 0, false + return peerSchedulerRetryInterval, true } return am.getNextPeerExpiration(ctx, accountID) @@ -1239,10 +499,13 @@ func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context // peerInactivityExpirationJob marks login expired for all inactive peers and returns the minimum duration in which the next peer of the account will expire by inactivity if found func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { return func() (time.Duration, bool) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + inactivePeers, err := am.getInactivePeers(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed getting inactive peers for account %s", accountID) - return 0, false + return peerSchedulerRetryInterval, true } var peerIDs []string @@ -1254,7 +517,7 @@ func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context if err := am.expireAndUpdatePeers(ctx, accountID, inactivePeers); err != nil { log.Errorf("failed updating account peers while expiring peers for account %s", accountID) - return 0, false + return peerSchedulerRetryInterval, true } return am.getNextInactivePeerExpiration(ctx, accountID) @@ -1271,7 +534,7 @@ func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx co // newAccount creates a new Account with a generated ID and generated default setup keys. 
// If ID is already in use (due to collision) we try one more time before returning error -func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (*Account, error) { +func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (*types.Account, error) { for i := 0; i < 2; i++ { accountId := xid.New().String() @@ -1351,7 +614,7 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u return status.Errorf(status.PermissionDenied, "user is not allowed to delete account") } - if user.Role != UserRoleOwner { + if user.Role != types.UserRoleOwner { return status.Errorf(status.PermissionDenied, "user is not allowed to delete account. Only account owner can delete account") } for _, otherUser := range account.Users { @@ -1389,7 +652,7 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u // AccountExists checks if an account exists. func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) { - return am.Store.AccountExists(ctx, LockingStrengthShare, accountID) + return am.Store.AccountExists(ctx, store.LockingStrengthShare, accountID) } // GetAccountIDByUserID retrieves the account ID based on the userID provided. @@ -1401,7 +664,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI return "", status.Errorf(status.NotFound, "no valid userID provided") } - accountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, userID) + accountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) @@ -1504,7 +767,7 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(ctx context.Context, e // lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, accountID string) (*idp.UserData, error) { - accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID) + accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, err } @@ -1515,10 +778,10 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s if user.IsServiceUser { continue } - if user.Issued == UserIssuedIntegration { + if user.Issued == types.UserIssuedIntegration { continue } - users[user.Id] = userLoggedInOnce(!user.LastLogin.IsZero()) + users[user.Id] = userLoggedInOnce(!user.GetLastLogin().IsZero()) } log.WithContext(ctx).Debugf("looking up user %s of account %s in cache", userID, accountID) userData, err := am.lookupCache(ctx, users, accountID) @@ -1534,7 +797,7 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s // add extra check on external cache manager. 
We may get to this point when the user is not yet findable in IDP, // or it didn't have its metadata updated with am.addAccountIDToIDPAppMeta - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, accountID) return nil, err @@ -1685,7 +948,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlockAccount() - accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, accountID) + accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) return err @@ -1695,7 +958,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx return nil } - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, claims.UserId) if err != nil { log.WithContext(ctx).Errorf("error getting user: %v", err) return err @@ -1782,7 +1045,7 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, newUser := NewRegularUser(claims.UserId) newUser.AccountID = domainAccountID - err := am.Store.SaveUser(ctx, LockingStrengthUpdate, newUser) + err := am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser) if err != nil { return "", err } @@ -1834,22 +1097,22 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str // MarkPATUsed marks a personal access token as used func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string) error { - return am.Store.MarkPATUsed(ctx, LockingStrengthUpdate, tokenID) + return am.Store.MarkPATUsed(ctx, store.LockingStrengthUpdate, tokenID) } // GetAccount returns an account associated with this account ID. -func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID string) (*Account, error) { +func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID string) (*types.Account, error) { return am.Store.GetAccount(ctx, accountID) } // GetAccountInfoFromPAT retrieves user, personal access token, domain, and category details from a personal access token. 
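The personal-access-token handling in the hunk that follows splits a token into a fixed prefix, a secret, and a base62-encoded checksum, and resolves the token by the base64 encoding of its SHA-256 hash. Here is a minimal sketch of that layout; the constants stand in for types.PATPrefix, types.PATSecretLength and types.PATChecksumLength, and the 30/6 length split is an assumption (only the "nbp_" prefix and the total length of the sample token used in the tests further down are visible in this patch).

    package main

    import (
        "crypto/sha256"
        b64 "encoding/base64"
        "fmt"
    )

    // Illustrative stand-ins for the exported PAT constants; the lengths are assumed.
    const (
        patPrefix   = "nbp_"
        secretLen   = 30
        checksumLen = 6
    )

    func main() {
        token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" // sample token from the tests below

        secret := token[len(patPrefix) : len(patPrefix)+secretLen]
        checksum := token[len(patPrefix)+secretLen : len(patPrefix)+secretLen+checksumLen]

        // The store keys PATs by the base64-encoded SHA-256 of the whole token,
        // matching the GetPATByHashedToken lookup in the hunk below.
        hashed := sha256.Sum256([]byte(token))
        fmt.Println(secret, checksum, b64.StdEncoding.EncodeToString(hashed[:]))
    }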
-func (am *DefaultAccountManager) GetAccountInfoFromPAT(ctx context.Context, token string) (user *User, pat *PersonalAccessToken, domain string, category string, err error) {
+func (am *DefaultAccountManager) GetAccountInfoFromPAT(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) {
 	user, pat, err = am.extractPATFromToken(ctx, token)
 	if err != nil {
 		return nil, nil, "", "", err
 	}
-	domain, category, err = am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, user.AccountID)
+	domain, category, err = am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, user.AccountID)
 	if err != nil {
 		return nil, nil, "", "", err
 	}
@@ -1863,13 +1126,12 @@ func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token
 		return nil, nil, fmt.Errorf("token has incorrect length")
 	}
-	prefix := token[:len(PATPrefix)]
-	if prefix != PATPrefix {
-		return nil, nil, fmt.Errorf("token has incorrect prefix")
+	prefix := token[:len(types.PATPrefix)]
+	if prefix != types.PATPrefix {
+		return nil, nil, fmt.Errorf("token has wrong prefix")
 	}
-
-	secret := token[len(PATPrefix) : len(PATPrefix)+PATSecretLength]
-	encodedChecksum := token[len(PATPrefix)+PATSecretLength : len(PATPrefix)+PATSecretLength+PATChecksumLength]
+	secret := token[len(types.PATPrefix) : len(types.PATPrefix)+types.PATSecretLength]
+	encodedChecksum := token[len(types.PATPrefix)+types.PATSecretLength : len(types.PATPrefix)+types.PATSecretLength+types.PATChecksumLength]
 	verificationChecksum, err := base62.Decode(encodedChecksum)
 	if err != nil {
@@ -1888,12 +1150,12 @@ func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token
 	var pat *PersonalAccessToken
 	err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
-		pat, err = transaction.GetPATByHashedToken(ctx, LockingStrengthShare, encodedHashedToken)
+		pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken)
 		if err != nil {
 			return err
 		}
-		user, err = transaction.GetUserByPATID(ctx, LockingStrengthShare, pat.ID)
+		user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthShare, pat.ID)
 		return err
 	})
 	if err != nil {
@@ -1904,8 +1166,8 @@ func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token
 }
 // GetAccountByID returns an account associated with this account ID.
-func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) {
-	user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
+func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) {
+	user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
 	if err != nil {
 		return nil, err
 	}
@@ -1926,7 +1188,7 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai
 	// This section is mostly related to self-hosted installations.
 	// We override incoming domain claims to group users under a single account.
claims.Domain = am.singleAccountModeDomain - claims.DomainCategory = PrivateCategory + claims.DomainCategory = types.PrivateCategory log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled") } @@ -1935,7 +1197,7 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai return "", "", err } - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, claims.UserId) if err != nil { // this is not really possible because we got an account by user ID return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId) @@ -1962,7 +1224,13 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, // and propagates changes to peers if group propagation is enabled. func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error { - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if claim, exists := claims.Raw[jwtclaims.IsToken]; exists { + if isToken, ok := claim.(bool); ok && isToken { + return nil + } + } + + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return err } @@ -1988,14 +1256,14 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st var addNewGroups []string var removeOldGroups []string var hasChanges bool - var user *User - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - user, err = transaction.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) + var user *types.User + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, claims.UserId) if err != nil { return fmt.Errorf("error getting user: %w", err) } - groups, err := transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) + groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) if err != nil { return fmt.Errorf("error getting account groups: %w", err) } @@ -2011,31 +1279,31 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st return nil } - if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil { + if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, newGroupsToCreate); err != nil { return fmt.Errorf("error saving groups: %w", err) } - addNewGroups = difference(updatedAutoGroups, user.AutoGroups) - removeOldGroups = difference(user.AutoGroups, updatedAutoGroups) + addNewGroups = util.Difference(updatedAutoGroups, user.AutoGroups) + removeOldGroups = util.Difference(user.AutoGroups, updatedAutoGroups) user.AutoGroups = updatedAutoGroups - if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil { + if err = transaction.SaveUser(ctx, store.LockingStrengthUpdate, user); err != nil { return fmt.Errorf("error saving user: %w", err) } // Propagate changes to peers if group propagation is enabled if settings.GroupsPropagationEnabled { - groups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) + groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) if err != nil { return fmt.Errorf("error getting account groups: %w", 
err) } - groupsMap := make(map[string]*nbgroup.Group, len(groups)) + groupsMap := make(map[string]*types.Group, len(groups)) for _, group := range groups { groupsMap[group.ID] = group } - peers, err := transaction.GetUserPeers(ctx, LockingStrengthShare, accountID, claims.UserId) + peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, claims.UserId) if err != nil { return fmt.Errorf("error getting user peers: %w", err) } @@ -2045,11 +1313,11 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st return fmt.Errorf("error modifying user peers in groups: %w", err) } - if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, updatedGroups); err != nil { + if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, updatedGroups); err != nil { return fmt.Errorf("error saving groups: %w", err) } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return fmt.Errorf("error incrementing network serial: %w", err) } } @@ -2067,7 +1335,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } for _, g := range addNewGroups { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g) + group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, g) if err != nil { log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) } else { @@ -2080,7 +1348,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } for _, g := range removeOldGroups { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g) + group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, g) if err != nil { log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) } else { @@ -2105,7 +1373,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st if removedGroupAffectsPeers || newGroupsAffectsPeers { log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } } @@ -2138,7 +1406,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context return "", errors.New(emptyUserID) } - if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { + if claims.DomainCategory != types.PrivateCategory || !isDomainValid(claims.Domain) { return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain) } @@ -2155,7 +1423,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context return "", err } - userAccountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, claims.UserId) + userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, claims.UserId) if handleNotFound(err) != nil { log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) return "", err @@ -2176,7 +1444,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context return am.addNewPrivateAccount(ctx, domainAccountID, claims) } func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) { - domainAccountID, err := 
am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) + domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain) if handleNotFound(err) != nil { log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) @@ -2191,7 +1459,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont cancel := am.Store.AcquireGlobalLock(ctx) // check again if the domain has a primary account because of simultaneous requests - domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) + domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain) if handleNotFound(err) != nil { cancel() log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) @@ -2202,7 +1470,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont } func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) { - userAccountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, claims.UserId) + userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, claims.UserId) if err != nil { log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) return "", err @@ -2212,7 +1480,7 @@ func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) } - accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId) + accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, claims.AccountId) if handleNotFound(err) != nil { log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) return "", err @@ -2223,7 +1491,7 @@ func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context } // We checked if the domain has a primary account already - domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain) + domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, claims.Domain) if handleNotFound(err) != nil { log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) return "", err @@ -2250,10 +1518,10 @@ func handleNotFound(err error) error { } func domainIsUpToDate(domain string, domainCategory string, claims jwtclaims.AuthorizationClaims) bool { - return domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain + return domainCategory == types.PrivateCategory || claims.DomainCategory != types.PrivateCategory || domain != claims.Domain } -func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { +func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID) defer accountUnlock() peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) @@ -2280,7 +1548,7 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account err := am.MarkPeerConnected(ctx, peerPubKey, 
false, nil, accountID) if err != nil { - log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) + log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err) } return nil @@ -2296,6 +1564,9 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st unlock := am.Store.AcquireReadLockByUID(ctx, accountID) defer unlock() + unlockPeer := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) + defer unlockPeer() + _, _, _, err = am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID) if err != nil { return mapError(ctx, err) @@ -2332,7 +1603,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, return err } - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return err } @@ -2354,7 +1625,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) { log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID) - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) { @@ -2365,8 +1636,8 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee return am.Store.GetAccountIDByPeerPubKey(ctx, peerKey) } -func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction Store, peer *nbpeer.Peer, settings *Settings) (bool, error) { - user, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID) +func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction store.Store, peer *nbpeer.Peer, settings *types.Settings) (bool, error) { + user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, peer.UserID) if err != nil { return false, err } @@ -2387,14 +1658,14 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction return false, nil } -func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Store, accountID string, peerHostName string) (string, error) { - existingLabels, err := store.GetPeerLabelsInAccount(ctx, LockingStrengthShare, accountID) +func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, s store.Store, accountID string, peerHostName string) (string, error) { + existingLabels, err := s.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID) if err != nil { return "", fmt.Errorf("failed to get peer dns labels: %w", err) } labelMap := ConvertSliceToMap(existingLabels) - newLabel, err := getPeerHostLabel(peerHostName, labelMap) + newLabel, err := types.GetPeerHostLabel(peerHostName, labelMap) if err != nil { return "", fmt.Errorf("failed to get new host label: %w", err) } @@ -2406,8 +1677,8 @@ func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Stor return newLabel, nil } -func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) +func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) { + user, 
err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -2416,70 +1687,70 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data") } - return am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + return am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) } // addAllGroup to account object if it doesn't exist -func addAllGroup(account *Account) error { +func addAllGroup(account *types.Account) error { if len(account.Groups) == 0 { - allGroup := &nbgroup.Group{ + allGroup := &types.Group{ ID: xid.New().String(), Name: "All", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, } for _, peer := range account.Peers { allGroup.Peers = append(allGroup.Peers, peer.ID) } - account.Groups = map[string]*nbgroup.Group{allGroup.ID: allGroup} + account.Groups = map[string]*types.Group{allGroup.ID: allGroup} id := xid.New().String() - defaultPolicy := &Policy{ + defaultPolicy := &types.Policy{ ID: id, - Name: DefaultRuleName, - Description: DefaultRuleDescription, + Name: types.DefaultRuleName, + Description: types.DefaultRuleDescription, Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: id, - Name: DefaultRuleName, - Description: DefaultRuleDescription, + Name: types.DefaultRuleName, + Description: types.DefaultRuleDescription, Enabled: true, Sources: []string{allGroup.ID}, Destinations: []string{allGroup.ID}, Bidirectional: true, - Protocol: PolicyRuleProtocolALL, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, }, }, } - account.Policies = []*Policy{defaultPolicy} + account.Policies = []*types.Policy{defaultPolicy} } return nil } // newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id -func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Account { +func newAccountWithId(ctx context.Context, accountID, userID, domain string) *types.Account { log.WithContext(ctx).Debugf("creating new account") - network := NewNetwork() + network := types.NewNetwork() peers := make(map[string]*nbpeer.Peer) - users := make(map[string]*User) + users := make(map[string]*types.User) routes := make(map[route.ID]*route.Route) - setupKeys := map[string]*SetupKey{} + setupKeys := map[string]*types.SetupKey{} nameServersGroups := make(map[string]*nbdns.NameServerGroup) - owner := NewOwnerUser(userID) + owner := types.NewOwnerUser(userID) owner.AccountID = accountID users[userID] = owner - dnsSettings := DNSSettings{ + dnsSettings := types.DNSSettings{ DisabledManagementGroups: make([]string, 0), } log.WithContext(ctx).Debugf("created new account %s", accountID) - acc := &Account{ + acc := &types.Account{ Id: accountID, CreatedAt: time.Now().UTC(), SetupKeys: setupKeys, @@ -2491,14 +1762,15 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac Routes: routes, NameServerGroups: nameServersGroups, DNSSettings: dnsSettings, - Settings: &Settings{ + Settings: &types.Settings{ PeerLoginExpirationEnabled: true, - PeerLoginExpiration: DefaultPeerLoginExpiration, + PeerLoginExpiration: types.DefaultPeerLoginExpiration, GroupsPropagationEnabled: true, RegularUsersViewBlocked: true, PeerInactivityExpirationEnabled: false, - PeerInactivityExpiration: DefaultPeerInactivityExpiration, + PeerInactivityExpiration: 
types.DefaultPeerInactivityExpiration, + RoutingPeerDNSResolutionEnabled: true, }, } @@ -2544,18 +1816,18 @@ func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool { // separateGroups separates user's auto groups into non-JWT and JWT groups. // Returns the list of standard auto groups and a map of JWT auto groups, // where the keys are the group names and the values are the group IDs. -func separateGroups(autoGroups []string, allGroups []*nbgroup.Group) ([]string, map[string]string) { +func separateGroups(autoGroups []string, allGroups []*types.Group) ([]string, map[string]string) { newAutoGroups := make([]string, 0) jwtAutoGroups := make(map[string]string) // map of group name to group ID - allGroupsMap := make(map[string]*nbgroup.Group, len(allGroups)) + allGroupsMap := make(map[string]*types.Group, len(allGroups)) for _, group := range allGroups { allGroupsMap[group.ID] = group } for _, id := range autoGroups { if group, ok := allGroupsMap[id]; ok { - if group.Issued == nbgroup.GroupIssuedJWT { + if group.Issued == types.GroupIssuedJWT { jwtAutoGroups[group.Name] = id } else { newAutoGroups = append(newAutoGroups, id) diff --git a/management/server/account_request_buffer.go b/management/server/account_request_buffer.go index 5f4897e6a..fa6c45856 100644 --- a/management/server/account_request_buffer.go +++ b/management/server/account_request_buffer.go @@ -7,6 +7,9 @@ import ( "time" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" ) // AccountRequest holds the result channel to return the requested account. @@ -17,19 +20,19 @@ type AccountRequest struct { // AccountResult holds the account data or an error. type AccountResult struct { - Account *Account + Account *types.Account Err error } type AccountRequestBuffer struct { - store Store + store store.Store getAccountRequests map[string][]*AccountRequest mu sync.Mutex getAccountRequestCh chan *AccountRequest bufferInterval time.Duration } -func NewAccountRequestBuffer(ctx context.Context, store Store) *AccountRequestBuffer { +func NewAccountRequestBuffer(ctx context.Context, store store.Store) *AccountRequestBuffer { bufferIntervalStr := os.Getenv("NB_GET_ACCOUNT_BUFFER_INTERVAL") bufferInterval, err := time.ParseDuration(bufferIntervalStr) if err != nil { @@ -52,7 +55,7 @@ func NewAccountRequestBuffer(ctx context.Context, store Store) *AccountRequestBu return &ac } -func (ac *AccountRequestBuffer) GetAccountWithBackpressure(ctx context.Context, accountID string) (*Account, error) { +func (ac *AccountRequestBuffer) GetAccountWithBackpressure(ctx context.Context, accountID string) (*types.Account, error) { req := &AccountRequest{ AccountID: accountID, ResultChan: make(chan *AccountResult, 1), diff --git a/management/server/account_test.go b/management/server/account_test.go index 650e8de69..08bdc8821 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -6,70 +6,40 @@ import ( b64 "encoding/base64" "encoding/json" "fmt" + "io" "net" + "os" "reflect" + "strconv" "sync" "testing" "time" "github.com/golang-jwt/jwt" + + "github.com/netbirdio/netbird/management/server/util" + + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + + log "github.com/sirupsen/logrus" 
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) -type MocIntegratedValidator struct { - ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) -} - -func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { - return nil -} - -func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) { - if a.ValidatePeerFunc != nil { - return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings) - } - return update, false, nil -} -func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { - validatedPeers := make(map[string]struct{}) - for _, peer := range peers { - validatedPeers[peer.ID] = struct{}{} - } - return validatedPeers, nil -} - -func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer { - return peer -} - -func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) { - return false, false, nil -} - -func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error { - return nil -} - -func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) { - -} - -func (MocIntegratedValidator) Stop(_ context.Context) { -} - -func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Account, userID string) { +func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *types.Account, userID string) { t.Helper() peer := &nbpeer.Peer{ Key: "BhRPtynAAYRDy08+q4HTMsos8fs4plTP4NOSh7C1ry8=", @@ -97,7 +67,7 @@ func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Ac } } -func verifyNewAccountHasDefaultFields(t *testing.T, account *Account, createdBy string, domain string, expectedUsers []string) { +func verifyNewAccountHasDefaultFields(t *testing.T, account *types.Account, createdBy string, domain string, expectedUsers []string) { t.Helper() if len(account.Peers) != 0 { t.Errorf("expected account to have len(Peers) = %v, got %v", 0, len(account.Peers)) @@ -152,7 +122,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { // 
peerID3 := "peer-3" tt := []struct { name string - accountSettings Settings + accountSettings types.Settings peerID string expectedPeers []string expectedOfflinePeers []string @@ -160,7 +130,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { }{ { name: "Should return ALL peers when global peer login expiration disabled", - accountSettings: Settings{PeerLoginExpirationEnabled: false, PeerLoginExpiration: time.Hour}, + accountSettings: types.Settings{PeerLoginExpirationEnabled: false, PeerLoginExpiration: time.Hour}, peerID: peerID1, expectedPeers: []string{peerID2}, expectedOfflinePeers: []string{}, @@ -177,7 +147,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { LoginExpired: true, }, UserID: userID, - LastLogin: time.Now().UTC().Add(-time.Hour * 24 * 30 * 30), + LastLogin: util.ToPtr(time.Now().UTC().Add(-time.Hour * 24 * 30 * 30)), }, "peer-2": { ID: peerID2, @@ -191,14 +161,14 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { LoginExpired: false, }, UserID: userID, - LastLogin: time.Now().UTC(), + LastLogin: util.ToPtr(time.Now().UTC()), LoginExpirationEnabled: true, }, }, }, { name: "Should return no peers when global peer login expiration enabled and peers expired", - accountSettings: Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: time.Hour}, + accountSettings: types.Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: time.Hour}, peerID: peerID1, expectedPeers: []string{}, expectedOfflinePeers: []string{peerID2}, @@ -215,7 +185,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { LoginExpired: true, }, UserID: userID, - LastLogin: time.Now().UTC().Add(-time.Hour * 24 * 30 * 30), + LastLogin: util.ToPtr(time.Now().UTC().Add(-time.Hour * 24 * 30 * 30)), LoginExpirationEnabled: true, }, "peer-2": { @@ -230,7 +200,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { LoginExpired: true, }, UserID: userID, - LastLogin: time.Now().UTC().Add(-time.Hour * 24 * 30 * 30), + LastLogin: util.ToPtr(time.Now().UTC().Add(-time.Hour * 24 * 30 * 30)), LoginExpirationEnabled: true, }, }, @@ -392,12 +362,12 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { netIP := net.IP{100, 64, 0, 0} netMask := net.IPMask{255, 255, 0, 0} - network := &Network{ + network := &types.Network{ Identifier: "network", Net: net.IPNet{IP: netIP, Mask: netMask}, Dns: "netbird.selfhosted", Serial: 0, - mu: sync.Mutex{}, + Mu: sync.Mutex{}, } for _, testCase := range tt { @@ -416,7 +386,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { } customZone := account.GetPeersCustomZone(context.Background(), "netbird.io") - networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, nil) + networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) } @@ -481,12 +451,12 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { } initUnknown := defaultInitAccount - initUnknown.DomainCategory = UnknownCategory + initUnknown.DomainCategory = types.UnknownCategory initUnknown.Domain = unknownDomain privateInitAccount := defaultInitAccount privateInitAccount.Domain = privateDomain - privateInitAccount.DomainCategory = PrivateCategory + privateInitAccount.DomainCategory = types.PrivateCategory testCases := []struct { name string @@ -496,7 +466,7 @@ func 
TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { inputUpdateClaimAccount bool testingFunc require.ComparisonAssertionFunc expectedMSG string - expectedUserRole UserRole + expectedUserRole types.UserRole expectedDomainCategory string expectedDomain string expectedPrimaryDomainStatus bool @@ -508,12 +478,12 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { inputClaims: jwtclaims.AuthorizationClaims{ Domain: publicDomain, UserId: "pub-domain-user", - DomainCategory: PublicCategory, + DomainCategory: types.PublicCategory, }, inputInitUserParams: defaultInitAccount, testingFunc: require.NotEqual, expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, + expectedUserRole: types.UserRoleOwner, expectedDomainCategory: "", expectedDomain: publicDomain, expectedPrimaryDomainStatus: false, @@ -525,12 +495,12 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { inputClaims: jwtclaims.AuthorizationClaims{ Domain: unknownDomain, UserId: "unknown-domain-user", - DomainCategory: UnknownCategory, + DomainCategory: types.UnknownCategory, }, inputInitUserParams: initUnknown, testingFunc: require.NotEqual, expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, + expectedUserRole: types.UserRoleOwner, expectedDomain: unknownDomain, expectedDomainCategory: "", expectedPrimaryDomainStatus: false, @@ -542,14 +512,14 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { inputClaims: jwtclaims.AuthorizationClaims{ Domain: privateDomain, UserId: "pvt-domain-user", - DomainCategory: PrivateCategory, + DomainCategory: types.PrivateCategory, }, inputInitUserParams: defaultInitAccount, testingFunc: require.NotEqual, expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, + expectedUserRole: types.UserRoleOwner, expectedDomain: privateDomain, - expectedDomainCategory: PrivateCategory, + expectedDomainCategory: types.PrivateCategory, expectedPrimaryDomainStatus: true, expectedCreatedBy: "pvt-domain-user", expectedUsers: []string{"pvt-domain-user"}, @@ -559,15 +529,15 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { inputClaims: jwtclaims.AuthorizationClaims{ Domain: privateDomain, UserId: "new-pvt-domain-user", - DomainCategory: PrivateCategory, + DomainCategory: types.PrivateCategory, }, inputUpdateAttrs: true, inputInitUserParams: privateInitAccount, testingFunc: require.Equal, expectedMSG: "account IDs should match", - expectedUserRole: UserRoleUser, + expectedUserRole: types.UserRoleUser, expectedDomain: privateDomain, - expectedDomainCategory: PrivateCategory, + expectedDomainCategory: types.PrivateCategory, expectedPrimaryDomainStatus: true, expectedCreatedBy: defaultInitAccount.UserId, expectedUsers: []string{defaultInitAccount.UserId, "new-pvt-domain-user"}, @@ -577,14 +547,14 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { inputClaims: jwtclaims.AuthorizationClaims{ Domain: defaultInitAccount.Domain, UserId: defaultInitAccount.UserId, - DomainCategory: PrivateCategory, + DomainCategory: types.PrivateCategory, }, inputInitUserParams: defaultInitAccount, testingFunc: require.Equal, expectedMSG: "account IDs should match", - expectedUserRole: UserRoleOwner, + expectedUserRole: types.UserRoleOwner, expectedDomain: defaultInitAccount.Domain, - expectedDomainCategory: PrivateCategory, + expectedDomainCategory: types.PrivateCategory, expectedPrimaryDomainStatus: true, expectedCreatedBy: defaultInitAccount.UserId, expectedUsers: 
[]string{defaultInitAccount.UserId}, @@ -594,15 +564,15 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { inputClaims: jwtclaims.AuthorizationClaims{ Domain: defaultInitAccount.Domain, UserId: defaultInitAccount.UserId, - DomainCategory: PrivateCategory, + DomainCategory: types.PrivateCategory, }, inputUpdateClaimAccount: true, inputInitUserParams: defaultInitAccount, testingFunc: require.Equal, expectedMSG: "account IDs should match", - expectedUserRole: UserRoleOwner, + expectedUserRole: types.UserRoleOwner, expectedDomain: defaultInitAccount.Domain, - expectedDomainCategory: PrivateCategory, + expectedDomainCategory: types.PrivateCategory, expectedPrimaryDomainStatus: true, expectedCreatedBy: defaultInitAccount.UserId, expectedUsers: []string{defaultInitAccount.UserId}, @@ -612,12 +582,12 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { inputClaims: jwtclaims.AuthorizationClaims{ Domain: "", UserId: "pvt-domain-user", - DomainCategory: PrivateCategory, + DomainCategory: types.PrivateCategory, }, inputInitUserParams: defaultInitAccount, testingFunc: require.NotEqual, expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, + expectedUserRole: types.UserRoleOwner, expectedDomain: "", expectedDomainCategory: "", expectedPrimaryDomainStatus: false, @@ -729,7 +699,7 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { require.Len(t, account.Groups, 3, "groups should be added to the account") - groupsByNames := map[string]*group.Group{} + groupsByNames := map[string]*types.Group{} for _, g := range account.Groups { groupsByNames[g.Name] = g } @@ -737,25 +707,29 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { g1, ok := groupsByNames["group1"] require.True(t, ok, "group1 should be added to the account") require.Equal(t, g1.Name, "group1", "group1 name should match") - require.Equal(t, g1.Issued, group.GroupIssuedJWT, "group1 issued should match") + require.Equal(t, g1.Issued, types.GroupIssuedJWT, "group1 issued should match") g2, ok := groupsByNames["group2"] require.True(t, ok, "group2 should be added to the account") require.Equal(t, g2.Name, "group2", "group2 name should match") - require.Equal(t, g2.Issued, group.GroupIssuedJWT, "group2 issued should match") + require.Equal(t, g2.Issued, types.GroupIssuedJWT, "group2 issued should match") }) } func TestAccountManager_GetAccountFromPAT(t *testing.T) { - store := newStore(t) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) account := newAccountWithId(context.Background(), "account_id", "testuser", "") token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" hashedToken := sha256.Sum256([]byte(token)) encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) - account.Users["someUser"] = &User{ + account.Users["someUser"] = &types.User{ Id: "someUser", - PATs: map[string]*PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ "tokenId": { ID: "tokenId", UserID: "someUser", @@ -763,7 +737,7 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) { }, }, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -783,23 +757,27 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) { } func TestDefaultAccountManager_MarkPATUsed(t *testing.T) { - store := 
newStore(t) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), "account_id", "testuser", "") token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" hashedToken := sha256.Sum256([]byte(token)) encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) - account.Users["someUser"] = &User{ + account.Users["someUser"] = &types.User{ Id: "someUser", - PATs: map[string]*PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ "tokenId": { ID: "tokenId", HashedToken: encodedHashedToken, - LastUsed: time.Time{}, }, }, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -817,7 +795,7 @@ func TestDefaultAccountManager_MarkPATUsed(t *testing.T) { if err != nil { t.Fatalf("Error when getting account: %s", err) } - assert.True(t, !account.Users["someUser"].PATs["tokenId"].LastUsed.IsZero()) + assert.True(t, !account.Users["someUser"].PATs["tokenId"].GetLastUsed().IsZero()) } func TestAccountManager_PrivateAccount(t *testing.T) { @@ -901,7 +879,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) { return } - exists, err := manager.Store.AccountExists(context.Background(), LockingStrengthShare, accountID) + exists, err := manager.Store.AccountExists(context.Background(), store.LockingStrengthShare, accountID) assert.NoError(t, err) assert.True(t, exists, "expected to get existing account after creation using userid") @@ -911,7 +889,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) { } } -func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*Account, error) { +func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*types.Account, error) { account := newAccountWithId(context.Background(), accountID, userID, domain) err := am.Store.SaveAccount(context.Background(), account) if err != nil { @@ -987,13 +965,13 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) { claims := jwtclaims.AuthorizationClaims{ Domain: "example.com", UserId: "pvt-domain-user", - DomainCategory: PrivateCategory, + DomainCategory: types.PrivateCategory, } publicClaims := jwtclaims.AuthorizationClaims{ Domain: "test.com", UserId: "public-domain-user", - DomainCategory: PublicCategory, + DomainCategory: types.PublicCategory, } am, err := createManager(b) @@ -1039,7 +1017,7 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) { } b.Run("public without account ID", func(b *testing.B) { - //b.ResetTimer() + // b.ResetTimer() for i := 0; i < b.N; i++ { _, err := am.getAccountIDWithAuthorizationClaims(context.Background(), publicClaims) if err != nil { @@ -1049,7 +1027,7 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) { }) b.Run("private without account ID", func(b *testing.B) { - //b.ResetTimer() + // b.ResetTimer() for i := 0; i < b.N; i++ { _, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims) if err != nil { @@ -1060,7 +1038,7 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) { b.Run("private with account ID", func(b *testing.B) { claims.AccountId = id - //b.ResetTimer() + // b.ResetTimer() for i := 0; i < b.N; i++ { _, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims) if err != nil { @@ -1071,14 +1049,14 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) 
{ } -func genUsers(p string, n int) map[string]*User { - users := map[string]*User{} +func genUsers(p string, n int) map[string]*types.User { + users := map[string]*types.User{} now := time.Now() for i := 0; i < n; i++ { - users[fmt.Sprintf("%s-%d", p, i)] = &User{ + users[fmt.Sprintf("%s-%d", p, i)] = &types.User{ Id: fmt.Sprintf("%s-%d", p, i), - Role: UserRoleAdmin, - LastLogin: now, + Role: types.UserRoleAdmin, + LastLogin: util.ToPtr(now), CreatedAt: now, Issued: "api", AutoGroups: []string{"one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten"}, @@ -1102,7 +1080,7 @@ func TestAccountManager_AddPeer(t *testing.T) { serial := account.Network.CurrentSerial() // should be 0 - setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false) if err != nil { t.Fatal("error creating setup key") return @@ -1229,7 +1207,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - group := group.Group{ + group := types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{}, @@ -1239,15 +1217,15 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { return } - _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err := manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) @@ -1306,7 +1284,7 @@ func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { manager, account, peer1, peer2, _ := setupNetworkMapTest(t) - group := group.Group{ + group := types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID}, @@ -1331,15 +1309,15 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { } }() - _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err := manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) @@ -1354,7 +1332,7 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { manager, account, peer1, _, peer3 := setupNetworkMapTest(t) - group := group.Group{ + group := types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer3.ID}, @@ -1364,15 +1342,15 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { return } - _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err := manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + 
Action: types.PolicyTrafficActionAccept, }, }, }) @@ -1410,7 +1388,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) - err := manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{ + err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, @@ -1423,15 +1401,15 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { return } - policy, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + policy, err := manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) @@ -1478,7 +1456,7 @@ func TestAccountManager_DeletePeer(t *testing.T) { t.Fatal(err) } - setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false) if err != nil { t.Fatal("error creating setup key") return @@ -1553,7 +1531,7 @@ func TestGetUsersFromAccount(t *testing.T) { t.Fatal(err) } - users := map[string]*User{"1": {Id: "1", Role: UserRoleOwner}, "2": {Id: "2", Role: "user"}, "3": {Id: "3", Role: "user"}} + users := map[string]*types.User{"1": {Id: "1", Role: types.UserRoleOwner}, "2": {Id: "2", Role: "user"}, "3": {Id: "3", Role: "user"}} accountId := "test_account_id" account, err := createAccount(manager, accountId, users["1"].Id, "") @@ -1585,7 +1563,7 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) { if err != nil { t.Fatal(err) } - account := &Account{ + account := &types.Account{ Routes: map[route.ID]*route.Route{ "route-1": { ID: "route-1", @@ -1632,11 +1610,11 @@ func TestAccount_GetRoutesToSync(t *testing.T) { if err != nil { t.Fatal(err) } - account := &Account{ + account := &types.Account{ Peers: map[string]*nbpeer.Peer{ "peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, }, - Groups: map[string]*group.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}}, + Groups: map[string]*types.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}}, Routes: map[route.ID]*route.Route{ "route-1": { ID: "route-1", @@ -1677,7 +1655,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) { }, } - routes := account.getRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}) + routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}) assert.Len(t, routes, 2) routeIDs := make(map[route.ID]struct{}, 2) @@ -1687,26 +1665,26 @@ func TestAccount_GetRoutesToSync(t *testing.T) { assert.Contains(t, routeIDs, route.ID("route-2")) assert.Contains(t, routeIDs, route.ID("route-3")) - emptyRoutes := account.getRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}) + emptyRoutes := 
account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}) assert.Len(t, emptyRoutes, 0) } func TestAccount_Copy(t *testing.T) { - account := &Account{ + account := &types.Account{ Id: "account1", CreatedBy: "tester", CreatedAt: time.Now().UTC(), Domain: "test.com", DomainCategory: "public", IsDomainPrimaryAccount: true, - SetupKeys: map[string]*SetupKey{ + SetupKeys: map[string]*types.SetupKey{ "setup1": { Id: "setup1", AutoGroups: []string{"group1"}, }, }, - Network: &Network{ + Network: &types.Network{ Identifier: "net1", }, Peers: map[string]*nbpeer.Peer{ @@ -1719,35 +1697,36 @@ func TestAccount_Copy(t *testing.T) { }, }, }, - Users: map[string]*User{ + Users: map[string]*types.User{ "user1": { Id: "user1", - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, AutoGroups: []string{"group1"}, - PATs: map[string]*PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ "pat1": { ID: "pat1", Name: "First PAT", HashedToken: "SoMeHaShEdToKeN", - ExpirationDate: time.Now().UTC().AddDate(0, 0, 7), + ExpirationDate: util.ToPtr(time.Now().UTC().AddDate(0, 0, 7)), CreatedBy: "user1", CreatedAt: time.Now().UTC(), - LastUsed: time.Now().UTC(), + LastUsed: util.ToPtr(time.Now().UTC()), }, }, }, }, - Groups: map[string]*group.Group{ + Groups: map[string]*types.Group{ "group1": { - ID: "group1", - Peers: []string{"peer1"}, + ID: "group1", + Peers: []string{"peer1"}, + Resources: []types.Resource{}, }, }, - Policies: []*Policy{ + Policies: []*types.Policy{ { ID: "policy1", Enabled: true, - Rules: make([]*PolicyRule, 0), + Rules: make([]*types.PolicyRule, 0), SourcePostureChecks: make([]string, 0), }, }, @@ -1767,13 +1746,36 @@ func TestAccount_Copy(t *testing.T) { NameServers: []nbdns.NameServer{}, }, }, - DNSSettings: DNSSettings{DisabledManagementGroups: []string{}}, + DNSSettings: types.DNSSettings{DisabledManagementGroups: []string{}}, PostureChecks: []*posture.Checks{ { ID: "posture Checks1", }, }, - Settings: &Settings{}, + Settings: &types.Settings{}, + Networks: []*networkTypes.Network{ + { + ID: "network1", + }, + }, + NetworkRouters: []*routerTypes.NetworkRouter{ + { + ID: "router1", + NetworkID: "network1", + PeerGroups: []string{"group1"}, + Masquerade: false, + Metric: 0, + }, + }, + NetworkResources: []*resourceTypes.NetworkResource{ + { + ID: "resource1", + NetworkID: "network1", + Name: "resource", + Type: "Subnet", + Address: "172.12.6.1/24", + }, + }, } err := hasNilField(account) if err != nil { @@ -1826,7 +1828,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) { accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") - settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) + settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID) require.NoError(t, err, "unable to get account settings") assert.NotNil(t, settings) @@ -1856,7 +1858,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID) require.NoError(t, err, "unable to mark peer connected") - account, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ + account, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ PeerLoginExpiration: time.Hour, 
PeerLoginExpirationEnabled: true, }) @@ -1904,7 +1906,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. LoginExpirationEnabled: true, }) require.NoError(t, err, "unable to add peer") - _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: true, }) @@ -1970,7 +1972,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test }, } // enabling PeerLoginExpirationEnabled should trigger the expiration job - account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ + account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &types.Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: true, }) @@ -1983,7 +1985,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test wg.Add(1) // disabling PeerLoginExpirationEnabled should trigger cancel - _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &types.Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: false, }) @@ -2001,7 +2003,7 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") - updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ + updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: false, }) @@ -2009,19 +2011,19 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { assert.False(t, updated.Settings.PeerLoginExpirationEnabled) assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour) - settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) + settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID) require.NoError(t, err, "unable to get account settings") assert.False(t, settings.PeerLoginExpirationEnabled) assert.Equal(t, settings.PeerLoginExpiration, time.Hour) - _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ PeerLoginExpiration: time.Second, PeerLoginExpirationEnabled: false, }) require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour") - _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ PeerLoginExpiration: time.Hour * 24 * 181, PeerLoginExpirationEnabled: false, }) @@ -2058,7 +2060,7 @@ func TestAccount_GetExpiredPeers(t *testing.T) { Connected: true, LoginExpired: false, }, - LastLogin: time.Now().UTC().Add(-30 * time.Minute), + LastLogin: util.ToPtr(time.Now().UTC().Add(-30 * time.Minute)), UserID: userID, }, "peer-2": { @@ -2069,7 +2071,7 @@ func TestAccount_GetExpiredPeers(t *testing.T) { Connected: true, LoginExpired: false, }, - LastLogin: time.Now().UTC().Add(-2 * time.Hour), + 
LastLogin: util.ToPtr(time.Now().UTC().Add(-2 * time.Hour)), UserID: userID, }, @@ -2081,7 +2083,7 @@ func TestAccount_GetExpiredPeers(t *testing.T) { Connected: true, LoginExpired: false, }, - LastLogin: time.Now().UTC().Add(-1 * time.Hour), + LastLogin: util.ToPtr(time.Now().UTC().Add(-1 * time.Hour)), UserID: userID, }, }, @@ -2094,9 +2096,9 @@ func TestAccount_GetExpiredPeers(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: testCase.peers, - Settings: &Settings{ + Settings: &types.Settings{ PeerLoginExpirationEnabled: true, PeerLoginExpiration: time.Hour, }, @@ -2143,7 +2145,7 @@ func TestAccount_GetInactivePeers(t *testing.T) { Connected: false, LoginExpired: false, }, - LastLogin: time.Now().UTC().Add(-30 * time.Minute), + LastLogin: util.ToPtr(time.Now().UTC().Add(-30 * time.Minute)), UserID: userID, }, "peer-2": { @@ -2154,7 +2156,7 @@ func TestAccount_GetInactivePeers(t *testing.T) { Connected: false, LoginExpired: false, }, - LastLogin: time.Now().UTC().Add(-2 * time.Hour), + LastLogin: util.ToPtr(time.Now().UTC().Add(-2 * time.Hour)), UserID: userID, }, "peer-3": { @@ -2165,7 +2167,7 @@ func TestAccount_GetInactivePeers(t *testing.T) { Connected: true, LoginExpired: false, }, - LastLogin: time.Now().UTC().Add(-1 * time.Hour), + LastLogin: util.ToPtr(time.Now().UTC().Add(-1 * time.Hour)), UserID: userID, }, }, @@ -2178,9 +2180,9 @@ func TestAccount_GetInactivePeers(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: testCase.peers, - Settings: &Settings{ + Settings: &types.Settings{ PeerInactivityExpirationEnabled: true, PeerInactivityExpiration: time.Second, }, @@ -2245,7 +2247,7 @@ func TestAccount_GetPeersWithExpiration(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: testCase.peers, } @@ -2314,7 +2316,7 @@ func TestAccount_GetPeersWithInactivity(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: testCase.peers, } @@ -2435,7 +2437,7 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) { LoginExpired: false, }, LoginExpirationEnabled: true, - LastLogin: time.Now().UTC(), + LastLogin: util.ToPtr(time.Now().UTC()), UserID: userID, }, "peer-2": { @@ -2478,9 +2480,9 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) { } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: testCase.peers, - Settings: &Settings{PeerLoginExpiration: testCase.expiration, PeerLoginExpirationEnabled: testCase.expirationEnabled}, + Settings: &types.Settings{PeerLoginExpiration: testCase.expiration, PeerLoginExpirationEnabled: testCase.expirationEnabled}, } expiration, ok := account.GetNextPeerExpiration() @@ -2595,7 +2597,7 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) { LastSeen: time.Now().Add(-1 * time.Second), }, InactivityExpirationEnabled: true, - LastLogin: time.Now().UTC(), + LastLogin: util.ToPtr(time.Now().UTC()), UserID: userID, }, "peer-2": { @@ -2638,9 +2640,9 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) { } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: 
testCase.peers, - Settings: &Settings{PeerInactivityExpiration: testCase.expiration, PeerInactivityExpirationEnabled: testCase.expirationEnabled}, + Settings: &types.Settings{PeerInactivityExpiration: testCase.expiration, PeerInactivityExpirationEnabled: testCase.expirationEnabled}, } expiration, ok := account.GetNextInactivePeerExpiration() @@ -2659,7 +2661,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { require.NoError(t, err, "unable to create account manager") // create a new account - account := &Account{ + account := &types.Account{ Id: "accountID", Peers: map[string]*nbpeer.Peer{ "peer1": {ID: "peer1", Key: "key1", UserID: "user1"}, @@ -2668,18 +2670,31 @@ func TestAccount_SetJWTGroups(t *testing.T) { "peer4": {ID: "peer4", Key: "key4", UserID: "user2"}, "peer5": {ID: "peer5", Key: "key5", UserID: "user2"}, }, - Groups: map[string]*group.Group{ - "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}}, + Groups: map[string]*types.Group{ + "group1": {ID: "group1", Name: "group1", Issued: types.GroupIssuedAPI, Peers: []string{}}, }, - Settings: &Settings{GroupsPropagationEnabled: true, JWTGroupsEnabled: true, JWTGroupsClaimName: "groups"}, - Users: map[string]*User{ - "user1": {Id: "user1", AccountID: "accountID"}, - "user2": {Id: "user2", AccountID: "accountID"}, + Settings: &types.Settings{GroupsPropagationEnabled: true, JWTGroupsEnabled: true, JWTGroupsClaimName: "groups"}, + Users: map[string]*types.User{ + "user1": {Id: "user1", AccountID: "accountID", CreatedAt: time.Now()}, + "user2": {Id: "user2", AccountID: "accountID", CreatedAt: time.Now()}, }, } assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account") + t.Run("skip sync for token auth type", func(t *testing.T) { + claims := jwtclaims.AuthorizationClaims{ + UserId: "user1", + Raw: jwt.MapClaims{"groups": []interface{}{"group3"}, "is_token": true}, + } + err = manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 0, "JWT groups should not be synced") + }) + t.Run("empty jwt groups", func(t *testing.T) { claims := jwtclaims.AuthorizationClaims{ UserId: "user1", @@ -2688,7 +2703,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { err := manager.syncJWTGroups(context.Background(), "accountID", claims) assert.NoError(t, err, "unable to sync jwt groups") - user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") assert.NoError(t, err, "unable to get user") assert.Empty(t, user.AutoGroups, "auto groups must be empty") }) @@ -2701,18 +2716,18 @@ func TestAccount_SetJWTGroups(t *testing.T) { err := manager.syncJWTGroups(context.Background(), "accountID", claims) assert.NoError(t, err, "unable to sync jwt groups") - user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 0) - group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1") + group1, err := manager.Store.GetGroupByID(context.Background(), 
store.LockingStrengthShare, "accountID", "group1") assert.NoError(t, err, "unable to get group") - assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") + assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued") }) t.Run("jwt match existing api group in user auto groups", func(t *testing.T) { account.Users["user1"].AutoGroups = []string{"group1"} - assert.NoError(t, manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, account.Users["user1"])) + assert.NoError(t, manager.Store.SaveUser(context.Background(), store.LockingStrengthUpdate, account.Users["user1"])) claims := jwtclaims.AuthorizationClaims{ UserId: "user1", @@ -2721,13 +2736,13 @@ func TestAccount_SetJWTGroups(t *testing.T) { err = manager.syncJWTGroups(context.Background(), "accountID", claims) assert.NoError(t, err, "unable to sync jwt groups") - user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 1) - group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1") + group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1") assert.NoError(t, err, "unable to get group") - assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") + assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued") }) t.Run("add jwt group", func(t *testing.T) { @@ -2738,7 +2753,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { err = manager.syncJWTGroups(context.Background(), "accountID", claims) assert.NoError(t, err, "unable to sync jwt groups") - user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 2, "groups count should not be change") }) @@ -2751,7 +2766,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { err = manager.syncJWTGroups(context.Background(), "accountID", claims) assert.NoError(t, err, "unable to sync jwt groups") - user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 2, "groups count should not be change") }) @@ -2764,16 +2779,16 @@ func TestAccount_SetJWTGroups(t *testing.T) { err = manager.syncJWTGroups(context.Background(), "accountID", claims) assert.NoError(t, err, "unable to sync jwt groups") - groups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, "accountID") + groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, "accountID") assert.NoError(t, err) assert.Len(t, groups, 3, "new group3 should be added") - user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user2") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 1, "new group should be added") }) - t.Run("remove all JWT groups", func(t *testing.T) { + t.Run("remove 
all JWT groups when list is empty", func(t *testing.T) { claims := jwtclaims.AuthorizationClaims{ UserId: "user1", Raw: jwt.MapClaims{"groups": []interface{}{}}, @@ -2781,15 +2796,28 @@ func TestAccount_SetJWTGroups(t *testing.T) { err = manager.syncJWTGroups(context.Background(), "accountID", claims) assert.NoError(t, err, "unable to sync jwt groups") - user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain") - assert.Contains(t, user.AutoGroups, "group1", " group1 should still be present") + assert.Contains(t, user.AutoGroups, "group1", "group1 should still be present") + }) + + t.Run("remove all JWT groups when claim does not exist", func(t *testing.T) { + claims := jwtclaims.AuthorizationClaims{ + UserId: "user2", + Raw: jwt.MapClaims{}, + } + err = manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 0, "all JWT groups should be removed") }) } func TestAccount_UserGroupsAddToPeers(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: map[string]*nbpeer.Peer{ "peer1": {ID: "peer1", Key: "key1", UserID: "user1"}, "peer2": {ID: "peer2", Key: "key2", UserID: "user1"}, @@ -2797,12 +2825,12 @@ func TestAccount_UserGroupsAddToPeers(t *testing.T) { "peer4": {ID: "peer4", Key: "key4", UserID: "user2"}, "peer5": {ID: "peer5", Key: "key5", UserID: "user2"}, }, - Groups: map[string]*group.Group{ - "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}}, - "group2": {ID: "group2", Name: "group2", Issued: group.GroupIssuedAPI, Peers: []string{}}, - "group3": {ID: "group3", Name: "group3", Issued: group.GroupIssuedAPI, Peers: []string{}}, + Groups: map[string]*types.Group{ + "group1": {ID: "group1", Name: "group1", Issued: types.GroupIssuedAPI, Peers: []string{}}, + "group2": {ID: "group2", Name: "group2", Issued: types.GroupIssuedAPI, Peers: []string{}}, + "group3": {ID: "group3", Name: "group3", Issued: types.GroupIssuedAPI, Peers: []string{}}, }, - Users: map[string]*User{"user1": {Id: "user1"}, "user2": {Id: "user2"}}, + Users: map[string]*types.User{"user1": {Id: "user1"}, "user2": {Id: "user2"}}, } t.Run("add groups", func(t *testing.T) { @@ -2825,7 +2853,7 @@ func TestAccount_UserGroupsAddToPeers(t *testing.T) { } func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: map[string]*nbpeer.Peer{ "peer1": {ID: "peer1", Key: "key1", UserID: "user1"}, "peer2": {ID: "peer2", Key: "key2", UserID: "user1"}, @@ -2833,12 +2861,12 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) { "peer4": {ID: "peer4", Key: "key4", UserID: "user2"}, "peer5": {ID: "peer5", Key: "key5", UserID: "user2"}, }, - Groups: map[string]*group.Group{ - "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3"}}, - "group2": {ID: "group2", Name: "group2", Issued: group.GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3", "peer4", "peer5"}}, - "group3": {ID: "group3", Name: "group3", Issued: group.GroupIssuedAPI, Peers: []string{"peer4", 
"peer5"}}, + Groups: map[string]*types.Group{ + "group1": {ID: "group1", Name: "group1", Issued: types.GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3"}}, + "group2": {ID: "group2", Name: "group2", Issued: types.GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3", "peer4", "peer5"}}, + "group3": {ID: "group3", Name: "group3", Issued: types.GroupIssuedAPI, Peers: []string{"peer4", "peer5"}}, }, - Users: map[string]*User{"user1": {Id: "user1"}, "user2": {Id: "user2"}}, + Users: map[string]*types.User{"user1": {Id: "user1"}, "user2": {Id: "user2"}}, } t.Run("remove groups", func(t *testing.T) { @@ -2881,10 +2909,10 @@ func createManager(t TB) (*DefaultAccountManager, error) { return manager, nil } -func createStore(t TB) (Store, error) { +func createStore(t TB) (store.Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", dataDir) if err != nil { return nil, err } @@ -2907,7 +2935,7 @@ func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool { } } -func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) { +func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *types.Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) { t.Helper() manager, err := createManager(t) @@ -2920,12 +2948,12 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *Account, *nbpee t.Fatal(err) } - setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false) if err != nil { t.Fatal("error creating setup key") } - getPeer := func(manager *DefaultAccountManager, setupKey *SetupKey) *nbpeer.Peer { + getPeer := func(manager *DefaultAccountManager, setupKey *types.SetupKey) *nbpeer.Peer { key, err := wgtypes.GeneratePrivateKey() if err != nil { t.Fatal(err) @@ -2976,3 +3004,218 @@ func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) t.Error("Timed out waiting for update message") } } + +func BenchmarkSyncAndMarkPeer(b *testing.B) { + benchCases := []struct { + name string + peers int + groups int + // We need different expectations for CI/CD and local runs because of the different performance characteristics + minMsPerOpLocal float64 + maxMsPerOpLocal float64 + minMsPerOpCICD float64 + maxMsPerOpCICD float64 + }{ + {"Small", 50, 5, 1, 3, 3, 19}, + {"Medium", 500, 100, 7, 13, 10, 90}, + {"Large", 5000, 200, 65, 80, 60, 240}, + {"Small single", 50, 10, 1, 3, 3, 80}, + {"Medium single", 500, 10, 7, 13, 10, 37}, + {"Large 5", 5000, 15, 65, 80, 60, 220}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + for _, bc := range benchCases { + b.Run(bc.name, func(b *testing.B) { + manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) + if err != nil { + b.Fatalf("Failed to setup test account manager: %v", err) + } + ctx := context.Background() + account, err := manager.Store.GetAccount(ctx, accountID) + if err != nil { + b.Fatalf("Failed to get account: %v", err) + } + peerChannels := make(map[string]chan *UpdateMessage) + for peerID := range account.Peers { + peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) + } + manager.peersUpdateManager.peerChannels = peerChannels + + 
b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}) + assert.NoError(b, err) + } + + duration := time.Since(start) + msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 + b.ReportMetric(msPerOp, "ms/op") + + minExpected := bc.minMsPerOpLocal + maxExpected := bc.maxMsPerOpLocal + if os.Getenv("CI") == "true" { + minExpected = bc.minMsPerOpCICD + maxExpected = bc.maxMsPerOpCICD + } + + if msPerOp < minExpected { + b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) + } + + if msPerOp > (maxExpected * 1.1) { + b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected) + } + }) + } +} + +func BenchmarkLoginPeer_ExistingPeer(b *testing.B) { + benchCases := []struct { + name string + peers int + groups int + // We need different expectations for CI/CD and local runs because of the different performance characteristics + minMsPerOpLocal float64 + maxMsPerOpLocal float64 + minMsPerOpCICD float64 + maxMsPerOpCICD float64 + }{ + {"Small", 50, 5, 102, 110, 3, 20}, + {"Medium", 500, 100, 105, 140, 20, 110}, + {"Large", 5000, 200, 160, 200, 120, 260}, + {"Small single", 50, 10, 102, 110, 5, 40}, + {"Medium single", 500, 10, 105, 140, 10, 60}, + {"Large 5", 5000, 15, 160, 200, 60, 180}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + for _, bc := range benchCases { + b.Run(bc.name, func(b *testing.B) { + manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) + if err != nil { + b.Fatalf("Failed to setup test account manager: %v", err) + } + ctx := context.Background() + account, err := manager.Store.GetAccount(ctx, accountID) + if err != nil { + b.Fatalf("Failed to get account: %v", err) + } + peerChannels := make(map[string]chan *UpdateMessage) + for peerID := range account.Peers { + peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) + } + manager.peersUpdateManager.peerChannels = peerChannels + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + _, _, _, err := manager.LoginPeer(context.Background(), PeerLogin{ + WireGuardPubKey: account.Peers["peer-1"].Key, + SSHKey: "someKey", + Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, + UserID: "regular_user", + SetupKey: "", + ConnectionIP: net.IP{1, 1, 1, 1}, + }) + assert.NoError(b, err) + } + + duration := time.Since(start) + msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 + b.ReportMetric(msPerOp, "ms/op") + + minExpected := bc.minMsPerOpLocal + maxExpected := bc.maxMsPerOpLocal + if os.Getenv("CI") == "true" { + minExpected = bc.minMsPerOpCICD + maxExpected = bc.maxMsPerOpCICD + } + + if msPerOp < minExpected { + b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) + } + + if msPerOp > (maxExpected * 1.1) { + b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected) + } + }) + } +} + +func BenchmarkLoginPeer_NewPeer(b *testing.B) { + benchCases := []struct { + name string + peers int + groups int + // We need different expectations for CI/CD and local runs because of the different performance characteristics + minMsPerOpLocal float64 + maxMsPerOpLocal float64 + minMsPerOpCICD float64 + maxMsPerOpCICD float64 + }{ + {"Small", 50, 5, 107, 120, 10, 80}, + 
{"Medium", 500, 100, 105, 140, 30, 140}, + {"Large", 5000, 200, 180, 220, 140, 300}, + {"Small single", 50, 10, 107, 120, 10, 80}, + {"Medium single", 500, 10, 105, 140, 20, 60}, + {"Large 5", 5000, 15, 180, 220, 80, 200}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + for _, bc := range benchCases { + b.Run(bc.name, func(b *testing.B) { + manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) + if err != nil { + b.Fatalf("Failed to setup test account manager: %v", err) + } + ctx := context.Background() + account, err := manager.Store.GetAccount(ctx, accountID) + if err != nil { + b.Fatalf("Failed to get account: %v", err) + } + peerChannels := make(map[string]chan *UpdateMessage) + for peerID := range account.Peers { + peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) + } + manager.peersUpdateManager.peerChannels = peerChannels + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + _, _, _, err := manager.LoginPeer(context.Background(), PeerLogin{ + WireGuardPubKey: "some-new-key" + strconv.Itoa(i), + SSHKey: "someKey", + Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, + UserID: "regular_user", + SetupKey: "", + ConnectionIP: net.IP{1, 1, 1, 1}, + }) + assert.NoError(b, err) + } + + duration := time.Since(start) + msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 + b.ReportMetric(msPerOp, "ms/op") + + minExpected := bc.minMsPerOpLocal + maxExpected := bc.maxMsPerOpLocal + if os.Getenv("CI") == "true" { + minExpected = bc.minMsPerOpCICD + maxExpected = bc.maxMsPerOpCICD + } + + if msPerOp < minExpected { + b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) + } + + if msPerOp > (maxExpected * 1.1) { + b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected) + } + }) + } +} diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 603260dbc..5379a8dd8 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -148,6 +148,27 @@ const ( AccountPeerInactivityExpirationDurationUpdated Activity = 67 SetupKeyDeleted Activity = 68 + + UserGroupPropagationEnabled Activity = 69 + UserGroupPropagationDisabled Activity = 70 + + AccountRoutingPeerDNSResolutionEnabled Activity = 71 + AccountRoutingPeerDNSResolutionDisabled Activity = 72 + + NetworkCreated Activity = 73 + NetworkUpdated Activity = 74 + NetworkDeleted Activity = 75 + + NetworkResourceCreated Activity = 76 + NetworkResourceUpdated Activity = 77 + NetworkResourceDeleted Activity = 78 + + NetworkRouterCreated Activity = 79 + NetworkRouterUpdated Activity = 80 + NetworkRouterDeleted Activity = 81 + + ResourceAddedToGroup Activity = 82 + ResourceRemovedFromGroup Activity = 83 ) var activityMap = map[Activity]Code{ @@ -222,6 +243,27 @@ var activityMap = map[Activity]Code{ AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"}, AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"}, SetupKeyDeleted: {"Setup key deleted", "setupkey.delete"}, + + UserGroupPropagationEnabled: {"User group propagation enabled", "account.setting.group.propagation.enable"}, + UserGroupPropagationDisabled: {"User group propagation disabled", "account.setting.group.propagation.disable"}, + + 
AccountRoutingPeerDNSResolutionEnabled: {"Account routing peer DNS resolution enabled", "account.setting.routing.peer.dns.resolution.enable"}, + AccountRoutingPeerDNSResolutionDisabled: {"Account routing peer DNS resolution disabled", "account.setting.routing.peer.dns.resolution.disable"}, + + NetworkCreated: {"Network created", "network.create"}, + NetworkUpdated: {"Network updated", "network.update"}, + NetworkDeleted: {"Network deleted", "network.delete"}, + + NetworkResourceCreated: {"Network resource created", "network.resource.create"}, + NetworkResourceUpdated: {"Network resource updated", "network.resource.update"}, + NetworkResourceDeleted: {"Network resource deleted", "network.resource.delete"}, + + NetworkRouterCreated: {"Network router created", "network.router.create"}, + NetworkRouterUpdated: {"Network router updated", "network.router.update"}, + NetworkRouterDeleted: {"Network router deleted", "network.router.delete"}, + + ResourceAddedToGroup: {"Resource added to group", "resource.group.add"}, + ResourceRemovedFromGroup: {"Resource removed from group", "resource.group.delete"}, } // StringCode returns a string code of the activity diff --git a/management/server/config.go b/management/server/config.go index 2f7e49766..f3555b92b 100644 --- a/management/server/config.go +++ b/management/server/config.go @@ -5,6 +5,7 @@ import ( "net/url" "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/util" ) @@ -156,7 +157,7 @@ type ProviderConfig struct { // StoreConfig contains Store configuration type StoreConfig struct { - Engine StoreEngine + Engine store.Engine } // ReverseProxy contains reverse proxy configuration in front of management. diff --git a/management/server/dns.go b/management/server/dns.go index be7caea4e..39dc11eb2 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -2,9 +2,7 @@ package server import ( "context" - "fmt" "slices" - "strconv" "sync" log "github.com/sirupsen/logrus" @@ -12,12 +10,12 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" - nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" ) -const defaultTTL = 300 - // DNSConfigCache is a thread-safe cache for DNS configuration components type DNSConfigCache struct { CustomZones sync.Map @@ -62,26 +60,9 @@ func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerG c.NameServerGroups.Store(key, value) } -type lookupMap map[string]struct{} - -// DNSSettings defines dns settings at the account level -type DNSSettings struct { - // DisabledManagementGroups groups whose DNS management is disabled - DisabledManagementGroups []string `gorm:"serializer:json"` -} - -// Copy returns a copy of the DNS settings -func (d DNSSettings) Copy() DNSSettings { - settings := DNSSettings{ - DisabledManagementGroups: make([]string, len(d.DisabledManagementGroups)), - } - copy(settings.DisabledManagementGroups, d.DisabledManagementGroups) - return settings -} - // GetDNSSettings validates a user role and returns the DNS settings for the provided account ID -func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) 
(*DNSSettings, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) +func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -94,16 +75,16 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s return nil, status.NewAdminPermissionError() } - return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) + return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID) } // SaveDNSSettings validates a user role and updates the account's DNS settings -func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error { +func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error { if dnsSettingsToSave == nil { return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") } - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -119,18 +100,18 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID var updateAccountPeers bool var eventsToStore []func() - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil { return err } - oldSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthUpdate, accountID) + oldSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthUpdate, accountID) if err != nil { return err } - addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups) - removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups) + addedGroups := util.Difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups) + removedGroups := util.Difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups) updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups) if err != nil { @@ -140,11 +121,11 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups) eventsToStore = append(eventsToStore, events...) - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.SaveDNSSettings(ctx, LockingStrengthUpdate, accountID, dnsSettingsToSave) + return transaction.SaveDNSSettings(ctx, store.LockingStrengthUpdate, accountID, dnsSettingsToSave) }) if err != nil { return err @@ -155,18 +136,18 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID } if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil } -// prepareGroupEvents prepares a list of event functions to be stored. 
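
For orientation: the addedGroups/removedGroups computation in the hunk above relies on util.Difference, which replaces the former package-local difference helper. A minimal sketch of the assumed semantics (elements of a that are not in b, order preserved); the actual helper in management/server/util may differ in detail:

// Difference returns the elements of a that do not appear in b.
func Difference[T comparable](a, b []T) []T {
	seen := make(map[T]struct{}, len(b))
	for _, v := range b {
		seen[v] = struct{}{}
	}
	var out []T
	for _, v := range a {
		if _, ok := seen[v]; !ok {
			out = append(out, v)
		}
	}
	return out
}

Used as in the hunk above: Difference(new, old) yields the groups newly added to the disabled list, Difference(old, new) the groups removed from it.
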
-func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string) []func() { +// prepareDNSSettingsEvents prepares a list of event functions to be stored. +func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, transaction store.Store, accountID, userID string, addedGroups, removedGroups []string) []func() { var eventsToStore []func() modifiedGroups := slices.Concat(addedGroups, removedGroups) - groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups) + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, modifiedGroups) if err != nil { log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err) return nil @@ -203,8 +184,8 @@ func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, t } // areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers. -func areDNSSettingChangesAffectPeers(ctx context.Context, transaction Store, accountID string, addedGroups, removedGroups []string) (bool, error) { - hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, addedGroups) +func areDNSSettingChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, addedGroups, removedGroups []string) (bool, error) { + hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, accountID, addedGroups) if err != nil { return false, err } @@ -213,16 +194,16 @@ func areDNSSettingChangesAffectPeers(ctx context.Context, transaction Store, acc return true, nil } - return anyGroupHasPeers(ctx, transaction, accountID, removedGroups) + return anyGroupHasPeersOrResources(ctx, transaction, accountID, removedGroups) } // validateDNSSettings validates the DNS settings. 
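
A recurring pattern in this file (and in the store refactor generally): reads use store.LockingStrengthShare, while mutations run inside Store.ExecuteInTransaction with store.LockingStrengthUpdate. A condensed fragment of the SaveDNSSettings write path shown above, as a sketch only; the receiver am, ctx, accountID, userID and dnsSettingsToSave come from the enclosing method and error handling is trimmed:

// Read outside the transaction with a shared lock.
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
	return err
}
_ = user

// Mutate inside a transaction with the stronger lock, incrementing the
// network serial before saving, as the diff does.
return am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
	if err := transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
		return err
	}
	return transaction.SaveDNSSettings(ctx, store.LockingStrengthUpdate, accountID, dnsSettingsToSave)
})
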
-func validateDNSSettings(ctx context.Context, transaction Store, accountID string, settings *DNSSettings) error { +func validateDNSSettings(ctx context.Context, transaction store.Store, accountID string, settings *types.DNSSettings) error { if len(settings.DisabledManagementGroups) == 0 { return nil } - groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, settings.DisabledManagementGroups) + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, settings.DisabledManagementGroups) if err != nil { return err } @@ -298,81 +279,3 @@ func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameSe } return protoGroup } - -func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup { - groupList := account.getPeerGroups(peerID) - - var peerNSGroups []*nbdns.NameServerGroup - - for _, nsGroup := range account.NameServerGroups { - if !nsGroup.Enabled { - continue - } - for _, gID := range nsGroup.Groups { - _, found := groupList[gID] - if found { - if !peerIsNameserver(account.GetPeer(peerID), nsGroup) { - peerNSGroups = append(peerNSGroups, nsGroup.Copy()) - break - } - } - } - } - - return peerNSGroups -} - -// peerIsNameserver returns true if the peer is a nameserver for a nsGroup -func peerIsNameserver(peer *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool { - for _, ns := range nsGroup.NameServers { - if peer.IP.Equal(ns.IP.AsSlice()) { - return true - } - } - return false -} - -func addPeerLabelsToAccount(ctx context.Context, account *Account, peerLabels lookupMap) { - for _, peer := range account.Peers { - label, err := getPeerHostLabel(peer.Name, peerLabels) - if err != nil { - log.WithContext(ctx).Errorf("got an error while generating a peer host label. Peer name %s, error: %v. Trying with the peer's meta hostname", peer.Name, err) - label, err = getPeerHostLabel(peer.Meta.Hostname, peerLabels) - if err != nil { - log.WithContext(ctx).Errorf("got another error while generating a peer host label with hostname. Peer hostname %s, error: %v. 
Skipping", peer.Meta.Hostname, err) - continue - } - } - peer.DNSLabel = label - peerLabels[label] = struct{}{} - } -} - -func getPeerHostLabel(name string, peerLabels lookupMap) (string, error) { - label, err := nbdns.GetParsedDomainLabel(name) - if err != nil { - return "", err - } - - uniqueLabel := getUniqueHostLabel(label, peerLabels) - if uniqueLabel == "" { - return "", fmt.Errorf("couldn't find a unique valid label for %s, parsed label %s", name, label) - } - return uniqueLabel, nil -} - -// getUniqueHostLabel look for a unique host label, and if doesn't find add a suffix up to 999 -func getUniqueHostLabel(name string, peerLabels lookupMap) string { - _, found := peerLabels[name] - if !found { - return name - } - for i := 1; i < 1000; i++ { - nameWithSuffix := name + "-" + strconv.Itoa(i) - _, found = peerLabels[nameWithSuffix] - if !found { - return nameWithSuffix - } - } - return "" -} diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 8a66da96c..6fb9f6a29 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -11,13 +11,14 @@ import ( "github.com/stretchr/testify/assert" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" ) @@ -53,7 +54,7 @@ func TestGetDNSSettings(t *testing.T) { t.Fatal("DNS settings for new accounts shouldn't return nil") } - account.DNSSettings = DNSSettings{ + account.DNSSettings = types.DNSSettings{ DisabledManagementGroups: []string{group1ID}, } @@ -86,20 +87,20 @@ func TestSaveDNSSettings(t *testing.T) { testCases := []struct { name string userID string - inputSettings *DNSSettings + inputSettings *types.DNSSettings shouldFail bool }{ { name: "Saving As Admin Should Be OK", userID: dnsAdminUserID, - inputSettings: &DNSSettings{ + inputSettings: &types.DNSSettings{ DisabledManagementGroups: []string{dnsGroup1ID}, }, }, { name: "Should Not Update Settings As Regular User", userID: dnsRegularUserID, - inputSettings: &DNSSettings{ + inputSettings: &types.DNSSettings{ DisabledManagementGroups: []string{dnsGroup1ID}, }, shouldFail: true, @@ -113,7 +114,7 @@ func TestSaveDNSSettings(t *testing.T) { { name: "Should Not Update Settings If Group Is Invalid", userID: dnsAdminUserID, - inputSettings: &DNSSettings{ + inputSettings: &types.DNSSettings{ DisabledManagementGroups: []string{"non-existing-group"}, }, shouldFail: true, @@ -210,10 +211,10 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics) } -func createDNSStore(t *testing.T) (Store, error) { +func createDNSStore(t *testing.T) (store.Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", dataDir) if err != nil { return nil, err } @@ -222,7 +223,7 @@ func createDNSStore(t *testing.T) (Store, error) { return store, nil } -func initTestDNSAccount(t *testing.T, am 
*DefaultAccountManager) (*Account, error) { +func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account, error) { t.Helper() peer1 := &nbpeer.Peer{ Key: dnsPeer1Key, @@ -259,9 +260,9 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain) - account.Users[dnsRegularUserID] = &User{ + account.Users[dnsRegularUserID] = &types.User{ Id: dnsRegularUserID, - Role: UserRoleUser, + Role: types.UserRoleUser, } err := am.Store.SaveAccount(context.Background(), account) @@ -293,13 +294,13 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro return nil, err } - newGroup1 := &group.Group{ + newGroup1 := &types.Group{ ID: dnsGroup1ID, Peers: []string{peer1.ID}, Name: dnsGroup1ID, } - newGroup2 := &group.Group{ + newGroup2 := &types.Group{ ID: dnsGroup2ID, Name: dnsGroup2ID, } @@ -483,7 +484,7 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) { func TestDNSAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroups(context.Background(), account.Id, userID, []*group.Group{ + err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -510,7 +511,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &types.DNSSettings{ DisabledManagementGroups: []string{"groupA"}, }) assert.NoError(t, err) @@ -550,7 +551,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { // Creating DNS settings with groups that have peers should update account peers and send peer update t.Run("creating dns setting with used groups", func(t *testing.T) { - err = manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, @@ -589,7 +590,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &types.DNSSettings{ DisabledManagementGroups: []string{"groupA", "groupB"}, }) assert.NoError(t, err) @@ -609,7 +610,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &types.DNSSettings{ DisabledManagementGroups: []string{"groupA"}, }) assert.NoError(t, err) @@ -629,7 +630,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &types.DNSSettings{ DisabledManagementGroups: []string{}, }) assert.NoError(t, err) diff --git a/management/server/ephemeral.go b/management/server/ephemeral.go index 6e245ec5a..3d6d01434 100644 --- a/management/server/ephemeral.go +++ b/management/server/ephemeral.go @@ -9,6 +9,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" ) 
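The createDNSStore helper above builds its backing store with store.NewTestStoreFromSQL, which hands back the store together with a cleanup callback. A minimal sketch of that pattern, assuming the cleanup is registered through t.Cleanup (that registration sits outside the hunk shown) and using an illustrative helper name:

package server

import (
	"context"
	"testing"

	"github.com/netbirdio/netbird/management/server/store"
)

// newTestStore creates a throwaway SQL-backed store for a single test.
// The empty DSN plus a temporary directory select the test database.
func newTestStore(t *testing.T) store.Store {
	t.Helper()

	s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
	if err != nil {
		t.Fatalf("failed to create test store: %v", err)
	}
	t.Cleanup(cleanUp) // assumed wiring: release the store when the test finishes

	return s
}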
const ( @@ -32,7 +33,7 @@ type ephemeralPeer struct { // EphemeralManager keep a list of ephemeral peers. After ephemeralLifeTime inactivity the peer will be deleted // automatically. Inactivity means the peer disconnected from the Management server. type EphemeralManager struct { - store Store + store store.Store accountManager AccountManager headPeer *ephemeralPeer @@ -42,7 +43,7 @@ type EphemeralManager struct { } // NewEphemeralManager instantiate new EphemeralManager -func NewEphemeralManager(store Store, accountManager AccountManager) *EphemeralManager { +func NewEphemeralManager(store store.Store, accountManager AccountManager) *EphemeralManager { return &EphemeralManager{ store: store, accountManager: accountManager, @@ -120,22 +121,18 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer. } func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) { - peers, err := e.store.GetAllEphemeralPeers(ctx, LockingStrengthShare) + peers, err := e.store.GetAllEphemeralPeers(ctx, store.LockingStrengthShare) if err != nil { log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err) return } t := newDeadLine() - count := 0 for _, p := range peers { - if p.Ephemeral { - count++ - e.addPeer(p.AccountID, p.ID, t) - } + e.addPeer(p.AccountID, p.ID, t) } - log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", count) + log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", len(peers)) } func (e *EphemeralManager) cleanup(ctx context.Context) { diff --git a/management/server/ephemeral_test.go b/management/server/ephemeral_test.go index 00e5d777a..df8fe98c3 100644 --- a/management/server/ephemeral_test.go +++ b/management/server/ephemeral_test.go @@ -7,14 +7,16 @@ import ( "time" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" ) type MockStore struct { - Store - account *Account + store.Store + account *types.Account } -func (s *MockStore) GetAllEphemeralPeers(_ context.Context, _ LockingStrength) ([]*nbpeer.Peer, error) { +func (s *MockStore) GetAllEphemeralPeers(_ context.Context, _ store.LockingStrength) ([]*nbpeer.Peer, error) { var peers []*nbpeer.Peer for _, v := range s.account.Peers { if v.Ephemeral { diff --git a/management/server/event.go b/management/server/event.go index 93b809226..788d1b51c 100644 --- a/management/server/event.go +++ b/management/server/event.go @@ -3,6 +3,7 @@ package server import ( "context" "fmt" + "os" "time" log "github.com/sirupsen/logrus" @@ -11,6 +12,11 @@ import ( "github.com/netbirdio/netbird/management/server/status" ) +func isEnabled() bool { + response := os.Getenv("NB_EVENT_ACTIVITY_LOG_ENABLED") + return response == "" || response == "true" +} + // GetEvents returns a list of activity events of an account func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) @@ -56,20 +62,20 @@ func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userI } func (am *DefaultAccountManager) StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { - - go func() { - _, err := am.eventStore.Save(ctx, &activity.Event{ - Timestamp: time.Now().UTC(), - Activity: activityID, - InitiatorID: initiatorID, - TargetID: targetID, - AccountID: accountID, - Meta: meta, - }) - if err != nil 
{ - // todo add metric - log.WithContext(ctx).Errorf("received an error while storing an activity event, error: %s", err) - } - }() - + if isEnabled() { + go func() { + _, err := am.eventStore.Save(ctx, &activity.Event{ + Timestamp: time.Now().UTC(), + Activity: activityID, + InitiatorID: initiatorID, + TargetID: targetID, + AccountID: accountID, + Meta: meta, + }) + if err != nil { + // todo add metric + log.WithContext(ctx).Errorf("received an error while storing an activity event, error: %s", err) + } + }() + } } diff --git a/management/server/geolocation/geolocation.go b/management/server/geolocation/geolocation.go index 553a31581..c0179a1c4 100644 --- a/management/server/geolocation/geolocation.go +++ b/management/server/geolocation/geolocation.go @@ -14,7 +14,14 @@ import ( log "github.com/sirupsen/logrus" ) -type Geolocation struct { +type Geolocation interface { + Lookup(ip net.IP) (*Record, error) + GetAllCountries() ([]Country, error) + GetCitiesByCountry(countryISOCode string) ([]City, error) + Stop() error +} + +type geolocationImpl struct { mmdbPath string mux sync.RWMutex db *maxminddb.Reader @@ -54,7 +61,7 @@ const ( geonamesdbPattern = "geonames_*.db" ) -func NewGeolocation(ctx context.Context, dataDir string, autoUpdate bool) (*Geolocation, error) { +func NewGeolocation(ctx context.Context, dataDir string, autoUpdate bool) (Geolocation, error) { mmdbGlobPattern := filepath.Join(dataDir, mmdbPattern) mmdbFile, err := getDatabaseFilename(ctx, geoLiteCityTarGZURL, mmdbGlobPattern, autoUpdate) if err != nil { @@ -86,7 +93,7 @@ func NewGeolocation(ctx context.Context, dataDir string, autoUpdate bool) (*Geol return nil, err } - geo := &Geolocation{ + geo := &geolocationImpl{ mmdbPath: mmdbPath, mux: sync.RWMutex{}, db: db, @@ -113,7 +120,7 @@ func openDB(mmdbPath string) (*maxminddb.Reader, error) { return db, nil } -func (gl *Geolocation) Lookup(ip net.IP) (*Record, error) { +func (gl *geolocationImpl) Lookup(ip net.IP) (*Record, error) { gl.mux.RLock() defer gl.mux.RUnlock() @@ -127,7 +134,7 @@ func (gl *Geolocation) Lookup(ip net.IP) (*Record, error) { } // GetAllCountries retrieves a list of all countries. -func (gl *Geolocation) GetAllCountries() ([]Country, error) { +func (gl *geolocationImpl) GetAllCountries() ([]Country, error) { allCountries, err := gl.locationDB.GetAllCountries() if err != nil { return nil, err @@ -143,7 +150,7 @@ func (gl *Geolocation) GetAllCountries() ([]Country, error) { } // GetCitiesByCountry retrieves a list of cities in a specific country based on the country's ISO code. 
-func (gl *Geolocation) GetCitiesByCountry(countryISOCode string) ([]City, error) { +func (gl *geolocationImpl) GetCitiesByCountry(countryISOCode string) ([]City, error) { allCities, err := gl.locationDB.GetCitiesByCountry(countryISOCode) if err != nil { return nil, err @@ -158,7 +165,7 @@ func (gl *Geolocation) GetCitiesByCountry(countryISOCode string) ([]City, error) return cities, nil } -func (gl *Geolocation) Stop() error { +func (gl *geolocationImpl) Stop() error { close(gl.stopCh) if gl.db != nil { if err := gl.db.Close(); err != nil { @@ -259,3 +266,21 @@ func cleanupMaxMindDatabases(ctx context.Context, dataDir string, mmdbFile strin } return nil } + +type Mock struct{} + +func (g *Mock) Lookup(ip net.IP) (*Record, error) { + return &Record{}, nil +} + +func (g *Mock) GetAllCountries() ([]Country, error) { + return []Country{}, nil +} + +func (g *Mock) GetCitiesByCountry(countryISOCode string) ([]City, error) { + return []City{}, nil +} + +func (g *Mock) Stop() error { + return nil +} diff --git a/management/server/geolocation/geolocation_test.go b/management/server/geolocation/geolocation_test.go index 9bdefd268..fecd715be 100644 --- a/management/server/geolocation/geolocation_test.go +++ b/management/server/geolocation/geolocation_test.go @@ -24,7 +24,7 @@ func TestGeoLite_Lookup(t *testing.T) { db, err := openDB(filename) assert.NoError(t, err) - geo := &Geolocation{ + geo := &geolocationImpl{ mux: sync.RWMutex{}, db: db, stopCh: make(chan struct{}), diff --git a/management/server/group.go b/management/server/group.go index 758b28b76..f1057dda6 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -6,13 +6,16 @@ import ( "fmt" "slices" - nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/route" "github.com/rs/xid" log "github.com/sirupsen/logrus" + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" ) @@ -27,12 +30,7 @@ func (e *GroupLinkError) Error() string { // CheckGroupPermissions validates if a user has the necessary permissions to view groups func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error { - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) - if err != nil { - return err - } - - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -41,7 +39,7 @@ func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, acco return status.NewUserNotPartOfAccountError() } - if user.IsRegularUser() && settings.RegularUsersViewBlocked { + if user.IsRegularUser() { return status.NewAdminPermissionError() } @@ -49,38 +47,38 @@ func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, acco } // GetGroup returns a specific group by groupID in an account -func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) { +func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*types.Group, error) { if err := 
am.CheckGroupPermissions(ctx, accountID, userID); err != nil { return nil, err } - return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) + return am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID) } // GetAllGroups returns all groups in an account -func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) { +func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) { if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { return nil, err } - return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) + return am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) } // GetGroupByName filters all groups in an account by name and returns the one with the most peers -func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) { - return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName) +func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) { + return am.Store.GetGroupByName(ctx, store.LockingStrengthShare, accountID, groupName) } // SaveGroup object of the peers -func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error { +func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - return am.SaveGroups(ctx, accountID, userID, []*nbgroup.Group{newGroup}) + return am.SaveGroups(ctx, accountID, userID, []*types.Group{newGroup}) } // SaveGroups adds new groups to the account. // Note: This function does not acquire the global lock. // It is the caller's responsibility to ensure proper locking is in place before invoking this method. 
-func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*nbgroup.Group) error { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) +func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -94,10 +92,10 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user } var eventsToStore []func() - var groupsToSave []*nbgroup.Group + var groupsToSave []*types.Group var updateAccountPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { groupIDs := make([]string, 0, len(groups)) for _, newGroup := range groups { if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { @@ -117,11 +115,11 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave) + return transaction.SaveGroups(ctx, store.LockingStrengthUpdate, groupsToSave) }) if err != nil { return err @@ -132,23 +130,23 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user } if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil } // prepareGroupEvents prepares a list of event functions to be stored. -func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction Store, accountID, userID string, newGroup *nbgroup.Group) []func() { +func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction store.Store, accountID, userID string, newGroup *types.Group) []func() { var eventsToStore []func() addedPeers := make([]string, 0) removedPeers := make([]string, 0) - oldGroup, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID) + oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, newGroup.ID) if err == nil && oldGroup != nil { - addedPeers = difference(newGroup.Peers, oldGroup.Peers) - removedPeers = difference(oldGroup.Peers, newGroup.Peers) + addedPeers = util.Difference(newGroup.Peers, oldGroup.Peers) + removedPeers = util.Difference(oldGroup.Peers, newGroup.Peers) } else { addedPeers = append(addedPeers, newGroup.Peers...) eventsToStore = append(eventsToStore, func() { @@ -157,7 +155,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac } modifiedPeers := slices.Concat(addedPeers, removedPeers) - peers, err := transaction.GetPeersByIDs(ctx, LockingStrengthShare, accountID, modifiedPeers) + peers, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, modifiedPeers) if err != nil { log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err) return nil @@ -198,65 +196,11 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac return eventsToStore } -// difference returns the elements in `a` that aren't in `b`. 
-func difference(a, b []string) []string { - mb := make(map[string]struct{}, len(b)) - for _, x := range b { - mb[x] = struct{}{} - } - var diff []string - for _, x := range a { - if _, found := mb[x]; !found { - diff = append(diff, x) - } - } - return diff -} - // DeleteGroup object of the peers. func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) - if err != nil { - return err - } - - if user.AccountID != accountID { - return status.NewUserNotPartOfAccountError() - } - - if user.IsRegularUser() { - return status.NewAdminPermissionError() - } - - var group *nbgroup.Group - - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - group, err = transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) - if err != nil { - return err - } - - if group.IsGroupAll() { - return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") - } - - if err = validateDeleteGroup(ctx, transaction, group, userID); err != nil { - return err - } - - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { - return err - } - - return transaction.DeleteGroup(ctx, LockingStrengthUpdate, accountID, groupID) - }) - if err != nil { - return err - } - - am.StoreEvent(ctx, userID, groupID, accountID, activity.GroupDeleted, group.EventMeta()) - - return nil + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + return am.DeleteGroups(ctx, accountID, userID, []string{groupID}) } // DeleteGroups deletes groups from an account. @@ -266,7 +210,7 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use // If an error occurs while deleting a group, the function skips it and continues deleting other groups. // Errors are collected and returned at the end. 
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -281,17 +225,18 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us var allErrors error var groupIDsToDelete []string - var deletedGroups []*nbgroup.Group + var deletedGroups []*types.Group - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { for _, groupID := range groupIDs { - group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) + group, err := transaction.GetGroupByID(ctx, store.LockingStrengthUpdate, accountID, groupID) if err != nil { + allErrors = errors.Join(allErrors, err) continue } if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil { - allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err)) + allErrors = errors.Join(allErrors, err) continue } @@ -299,11 +244,11 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us deletedGroups = append(deletedGroups, group) } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete) + return transaction.DeleteGroups(ctx, store.LockingStrengthUpdate, accountID, groupIDsToDelete) }) if err != nil { return err @@ -318,12 +263,15 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us // GroupAddPeer appends peer to the group func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { - var group *nbgroup.Group + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + var group *types.Group var updateAccountPeers bool var err error - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - group, err = transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID) if err != nil { return err } @@ -337,18 +285,59 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.SaveGroup(ctx, LockingStrengthUpdate, group) + return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group) }) if err != nil { return err } if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) + } + + return nil +} + +// GroupAddResource appends resource to the group +func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID, groupID string, resource types.Resource) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + var group *types.Group + var updateAccountPeers 
bool + var err error + + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID) + if err != nil { + return err + } + + if updated := group.AddResource(resource); !updated { + return nil + } + + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID}) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group) + }) + if err != nil { + return err + } + + if updateAccountPeers { + am.UpdateAccountPeers(ctx, accountID) } return nil @@ -356,12 +345,15 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr // GroupDeletePeer removes peer from the group func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error { - var group *nbgroup.Group + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + var group *types.Group var updateAccountPeers bool var err error - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - group, err = transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID) if err != nil { return err } @@ -375,31 +367,72 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.SaveGroup(ctx, LockingStrengthUpdate, group) + return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group) }) if err != nil { return err } if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) + } + + return nil +} + +// GroupDeleteResource removes resource from the group +func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accountID, groupID string, resource types.Resource) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + var group *types.Group + var updateAccountPeers bool + var err error + + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID) + if err != nil { + return err + } + + if updated := group.RemoveResource(resource); !updated { + return nil + } + + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID}) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group) + }) + if err != nil { + return err + } + + if updateAccountPeers { + am.UpdateAccountPeers(ctx, accountID) } return nil } // validateNewGroup validates the new group for existence and required fields. 
-func validateNewGroup(ctx context.Context, transaction Store, accountID string, newGroup *nbgroup.Group) error { - if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { +func validateNewGroup(ctx context.Context, transaction store.Store, accountID string, newGroup *types.Group) error { + if newGroup.ID == "" && newGroup.Issued != types.GroupIssuedAPI { return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued) } - if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI { - existingGroup, err := transaction.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name) + if newGroup.ID == "" && newGroup.Issued == types.GroupIssuedAPI { + existingGroup, err := transaction.GetGroupByName(ctx, store.LockingStrengthShare, accountID, newGroup.Name) if err != nil { if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound { return err @@ -416,7 +449,7 @@ func validateNewGroup(ctx context.Context, transaction Store, accountID string, } for _, peerID := range newGroup.Peers { - _, err := transaction.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID) + _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID) if err != nil { return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) } @@ -425,18 +458,26 @@ func validateNewGroup(ctx context.Context, transaction Store, accountID string, return nil } -func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup.Group, userID string) error { +func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string) error { // disable a deleting integration group if the initiator is not an admin service user - if group.Issued == nbgroup.GroupIssuedIntegration { - executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID) + if group.Issued == types.GroupIssuedIntegration { + executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { - return status.Errorf(status.NotFound, "user not found") + return err } - if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser { + if executingUser.Role != types.UserRoleAdmin || !executingUser.IsServiceUser { return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group") } } + if group.IsGroupAll() { + return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") + } + + if len(group.Resources) > 0 { + return &GroupLinkError{"network resource", group.Resources[0].ID} + } + if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"route", string(linkedRoute.NetID)} } @@ -461,8 +502,8 @@ func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup. } // checkGroupLinkedToSettings verifies if a group is linked to any settings in the account. 
-func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *nbgroup.Group) error { - dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID) +func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, group *types.Group) error { + dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, group.AccountID) if err != nil { return err } @@ -471,7 +512,7 @@ func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *n return &GroupLinkError{"disabled DNS management groups", group.Name} } - settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID) + settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, group.AccountID) if err != nil { return err } @@ -484,15 +525,18 @@ func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *n } // isGroupLinkedToRoute checks if a group is linked to any route in the account. -func isGroupLinkedToRoute(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *route.Route) { - routes, err := transaction.GetAccountRoutes(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *route.Route) { + routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err) return false, nil } for _, r := range routes { - if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) { + isLinked := slices.Contains(r.Groups, groupID) || + slices.Contains(r.PeerGroups, groupID) || + slices.Contains(r.AccessControlGroups, groupID) + if isLinked { return true, r } } @@ -501,8 +545,8 @@ func isGroupLinkedToRoute(ctx context.Context, transaction Store, accountID stri } // isGroupLinkedToPolicy checks if a group is linked to any policy in the account. -func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *Policy) { - policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.Policy) { + policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err) return false, nil @@ -519,8 +563,8 @@ func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID str } // isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. 
-func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) { - nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) { + nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err) return false, nil @@ -538,8 +582,8 @@ func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string } // isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account. -func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *SetupKey) { - setupKeys, err := transaction.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.SetupKey) { + setupKeys, err := transaction.GetAccountSetupKeys(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err) return false, nil @@ -554,8 +598,8 @@ func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID s } // isGroupLinkedToUser checks if a group is linked to any user in the account. -func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *User) { - users, err := transaction.GetAccountUsers(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.User) { + users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err) return false, nil @@ -570,12 +614,12 @@ func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID strin } // areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers. -func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) { +func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) { if len(groupIDs) == 0 { return false, nil } - dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) + dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return false, err } @@ -598,7 +642,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountI return false, nil } -func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []string) bool { +func (am *DefaultAccountManager) anyGroupHasPeers(account *types.Account, groupIDs []string) bool { for _, groupID := range groupIDs { if group, exists := account.Groups[groupID]; exists && group.HasPeers() { return true @@ -607,15 +651,15 @@ func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []s return false } -// anyGroupHasPeers checks if any of the given groups in the account have peers. 
-func anyGroupHasPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) { - groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, groupIDs) +// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources. +func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) { + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupIDs) if err != nil { return false, err } for _, group := range groups { - if group.HasPeers() { + if group.HasPeers() || group.HasResources() { return true, nil } } diff --git a/management/server/group_test.go b/management/server/group_test.go index 0515b9698..cc90f187b 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -12,8 +12,8 @@ import ( "github.com/stretchr/testify/require" nbdns "github.com/netbirdio/netbird/dns" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) @@ -32,22 +32,22 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) { t.Error("failed to init testing account") } for _, group := range account.Groups { - group.Issued = nbgroup.GroupIssuedIntegration + group.Issued = types.GroupIssuedIntegration err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group) if err != nil { - t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedIntegration) + t.Errorf("should allow to create %s groups", types.GroupIssuedIntegration) } } for _, group := range account.Groups { - group.Issued = nbgroup.GroupIssuedJWT + group.Issued = types.GroupIssuedJWT err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group) if err != nil { - t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedJWT) + t.Errorf("should allow to create %s groups", types.GroupIssuedJWT) } } for _, group := range account.Groups { - group.Issued = nbgroup.GroupIssuedAPI + group.Issued = types.GroupIssuedAPI group.ID = "" err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group) if err == nil { @@ -145,13 +145,13 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) { manager, account, err := initTestGroupAccount(am) assert.NoError(t, err, "Failed to init testing account") - groups := make([]*nbgroup.Group, 10) + groups := make([]*types.Group, 10) for i := 0; i < 10; i++ { - groups[i] = &nbgroup.Group{ + groups[i] = &types.Group{ ID: fmt.Sprintf("group-%d", i+1), AccountID: account.Id, Name: fmt.Sprintf("group-%d", i+1), - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, } } @@ -208,7 +208,7 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) { { name: "delete non-existent group", groupIDs: []string{"non-existent-group"}, - expectedDeleted: []string{"non-existent-group"}, + expectedReasons: []string{"group: non-existent-group not found"}, }, { name: "delete multiple groups with mixed results", @@ -267,63 +267,63 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) { } } -func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *Account, error) { +func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *types.Account, error) { accountID := "testingAcc" domain := "example.com" - groupForRoute := &nbgroup.Group{ + groupForRoute := 
&types.Group{ ID: "grp-for-route", AccountID: "account-id", Name: "Group for route", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, Peers: make([]string, 0), } - groupForRoute2 := &nbgroup.Group{ + groupForRoute2 := &types.Group{ ID: "grp-for-route2", AccountID: "account-id", Name: "Group for route", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, Peers: make([]string, 0), } - groupForNameServerGroups := &nbgroup.Group{ + groupForNameServerGroups := &types.Group{ ID: "grp-for-name-server-grp", AccountID: "account-id", Name: "Group for name server groups", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, Peers: make([]string, 0), } - groupForPolicies := &nbgroup.Group{ + groupForPolicies := &types.Group{ ID: "grp-for-policies", AccountID: "account-id", Name: "Group for policies", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, Peers: make([]string, 0), } - groupForSetupKeys := &nbgroup.Group{ + groupForSetupKeys := &types.Group{ ID: "grp-for-keys", AccountID: "account-id", Name: "Group for setup keys", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, Peers: make([]string, 0), } - groupForUsers := &nbgroup.Group{ + groupForUsers := &types.Group{ ID: "grp-for-users", AccountID: "account-id", Name: "Group for users", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, Peers: make([]string, 0), } - groupForIntegration := &nbgroup.Group{ + groupForIntegration := &types.Group{ ID: "grp-for-integration", AccountID: "account-id", Name: "Group for users integration", - Issued: nbgroup.GroupIssuedIntegration, + Issued: types.GroupIssuedIntegration, Peers: make([]string, 0), } @@ -342,9 +342,9 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A Groups: []string{groupForNameServerGroups.ID}, } - policy := &Policy{ + policy := &types.Policy{ ID: "example policy", - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "example policy rule", Destinations: []string{groupForPolicies.ID}, @@ -352,12 +352,13 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A }, } - setupKey := &SetupKey{ + setupKey := &types.SetupKey{ Id: "example setup key", AutoGroups: []string{groupForSetupKeys.ID}, + UpdatedAt: time.Now(), } - user := &User{ + user := &types.User{ Id: "example user", AutoGroups: []string{groupForUsers.ID}, } @@ -392,7 +393,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A func TestGroupAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -429,7 +430,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupB", Name: "GroupB", Peers: []string{peer1.ID, peer2.ID}, @@ -500,15 +501,15 @@ func TestGroupAccountPeersUpdate(t *testing.T) { }) // adding a group to policy - _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: 
[]string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) @@ -522,7 +523,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID}, @@ -591,7 +592,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupC", Name: "GroupC", Peers: []string{peer1.ID, peer3.ID}, @@ -632,7 +633,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, @@ -648,7 +649,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { // Saving a group linked to dns settings should update account peers and send peer update t.Run("saving group linked to dns settings", func(t *testing.T) { - err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &types.DNSSettings{ DisabledManagementGroups: []string{"groupD"}, }) assert.NoError(t, err) @@ -659,7 +660,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupD", Name: "GroupD", Peers: []string{peer1.ID}, diff --git a/management/server/groups/manager.go b/management/server/groups/manager.go new file mode 100644 index 000000000..02b669e41 --- /dev/null +++ b/management/server/groups/manager.go @@ -0,0 +1,210 @@ +package groups + +import ( + "context" + "fmt" + + s "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" +) + +type Manager interface { + GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) + GetAllGroupsMap(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) + GetResourceGroupsInTransaction(ctx context.Context, transaction store.Store, lockingStrength store.LockingStrength, accountID, resourceID string) ([]*types.Group, error) + AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resourceID *types.Resource) error + AddResourceToGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID string, resourceID *types.Resource) (func(), error) + RemoveResourceFromGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID, resourceID string) (func(), error) +} + +type managerImpl struct { + store store.Store + permissionsManager permissions.Manager + accountManager s.AccountManager +} + +type mockManager struct { +} + +func NewManager(store store.Store, permissionsManager permissions.Manager, 
accountManager s.AccountManager) Manager { + return &managerImpl{ + store: store, + permissionsManager: permissionsManager, + accountManager: accountManager, + } +} + +func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Groups, permissions.Read) + if err != nil { + return nil, err + } + if !ok { + return nil, err + } + + groups, err := m.store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, fmt.Errorf("error getting account groups: %w", err) + } + + return groups, nil +} + +func (m *managerImpl) GetAllGroupsMap(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) { + groups, err := m.GetAllGroups(ctx, accountID, userID) + if err != nil { + return nil, err + } + + groupsMap := make(map[string]*types.Group) + for _, group := range groups { + groupsMap[group.ID] = group + } + + return groupsMap, nil +} + +func (m *managerImpl) AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resource *types.Resource) error { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Groups, permissions.Write) + if err != nil { + return err + } + if !ok { + return err + } + + event, err := m.AddResourceToGroupInTransaction(ctx, m.store, accountID, userID, groupID, resource) + if err != nil { + return fmt.Errorf("error adding resource to group: %w", err) + } + + event() + + return nil +} + +func (m *managerImpl) AddResourceToGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID string, resource *types.Resource) (func(), error) { + err := transaction.AddResourceToGroup(ctx, accountID, groupID, resource) + if err != nil { + return nil, fmt.Errorf("error adding resource to group: %w", err) + } + + group, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID) + if err != nil { + return nil, fmt.Errorf("error getting group: %w", err) + } + + // TODO: at some point, this will need to become a switch statement + networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resource.ID) + if err != nil { + return nil, fmt.Errorf("error getting network resource: %w", err) + } + + event := func() { + m.accountManager.StoreEvent(ctx, userID, groupID, accountID, activity.ResourceAddedToGroup, group.EventMetaResource(networkResource)) + } + + return event, nil +} + +func (m *managerImpl) RemoveResourceFromGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID, resourceID string) (func(), error) { + err := transaction.RemoveResourceFromGroup(ctx, accountID, groupID, resourceID) + if err != nil { + return nil, fmt.Errorf("error removing resource from group: %w", err) + } + + group, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID) + if err != nil { + return nil, fmt.Errorf("error getting group: %w", err) + } + + // TODO: at some point, this will need to become a switch statement + networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resourceID) + if err != nil { + return nil, fmt.Errorf("error getting network resource: %w", err) + } + + event := func() { + m.accountManager.StoreEvent(ctx, userID, groupID, accountID, activity.ResourceRemovedFromGroup, group.EventMetaResource(networkResource)) + } + + return event, nil +} + 
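AddResourceToGroupInTransaction and RemoveResourceFromGroupInTransaction return the activity event as a closure rather than recording it immediately, so a caller can run the store mutation inside ExecuteInTransaction and emit events only after the transaction commits — the same shape SaveGroups uses with its eventsToStore slice earlier in this diff. A rough sketch of such a caller, assumed to live in the same groups package as the manager above; the addResources name and the batching are illustrative, not part of this change:

// addResources attaches several resources to a group in one transaction and
// defers the activity events until the transaction has committed.
func (m *managerImpl) addResources(ctx context.Context, accountID, userID, groupID string, resources []*types.Resource) error {
	var events []func()

	err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
		for _, res := range resources {
			event, err := m.AddResourceToGroupInTransaction(ctx, transaction, accountID, userID, groupID, res)
			if err != nil {
				return err // rolls back the transaction; no events are emitted
			}
			events = append(events, event)
		}
		return nil
	})
	if err != nil {
		return err
	}

	// The transaction committed, so it is now safe to record the events.
	for _, event := range events {
		event()
	}
	return nil
}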
+func (m *managerImpl) GetResourceGroupsInTransaction(ctx context.Context, transaction store.Store, lockingStrength store.LockingStrength, accountID, resourceID string) ([]*types.Group, error) { + return transaction.GetResourceGroups(ctx, lockingStrength, accountID, resourceID) +} + +func ToGroupsInfo(groups []*types.Group, id string) []api.GroupMinimum { + groupsInfo := []api.GroupMinimum{} + groupsChecked := make(map[string]struct{}) + for _, group := range groups { + _, ok := groupsChecked[group.ID] + if ok { + continue + } + groupsChecked[group.ID] = struct{}{} + for _, pk := range group.Peers { + if pk == id { + info := api.GroupMinimum{ + Id: group.ID, + Name: group.Name, + PeersCount: len(group.Peers), + ResourcesCount: len(group.Resources), + } + groupsInfo = append(groupsInfo, info) + break + } + } + for _, rk := range group.Resources { + if rk.ID == id { + info := api.GroupMinimum{ + Id: group.ID, + Name: group.Name, + PeersCount: len(group.Peers), + ResourcesCount: len(group.Resources), + } + groupsInfo = append(groupsInfo, info) + break + } + } + } + return groupsInfo +} + +func (m *mockManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) { + return []*types.Group{}, nil +} + +func (m *mockManager) GetAllGroupsMap(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) { + return map[string]*types.Group{}, nil +} + +func (m *mockManager) GetResourceGroupsInTransaction(ctx context.Context, transaction store.Store, lockingStrength store.LockingStrength, accountID, resourceID string) ([]*types.Group, error) { + return []*types.Group{}, nil +} + +func (m *mockManager) AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resourceID *types.Resource) error { + return nil +} + +func (m *mockManager) AddResourceToGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID string, resourceID *types.Resource) (func(), error) { + return func() { + // noop + }, nil +} + +func (m *mockManager) RemoveResourceFromGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID, resourceID string) (func(), error) { + return func() { + // noop + }, nil +} + +func NewManagerMock() Manager { + return &mockManager{} +} diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 9c12336f8..daa23d2ab 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -23,19 +23,22 @@ import ( "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/settings" internalStatus "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" ) // GRPCServer an instance of a Management gRPC API server type GRPCServer struct { - accountManager AccountManager - wgKey wgtypes.Key + accountManager AccountManager + settingsManager settings.Manager + wgKey wgtypes.Key proto.UnimplementedManagementServiceServer peersUpdateManager *PeersUpdateManager config *Config secretsManager SecretsManager - jwtValidator *jwtclaims.JWTValidator + jwtValidator jwtclaims.JWTValidator jwtClaimsExtractor *jwtclaims.ClaimsExtractor appMetrics telemetry.AppMetrics ephemeralManager *EphemeralManager @@ -47,6 +50,7 @@ func NewServer( ctx context.Context, config *Config, 
accountManager AccountManager, + settingsManager settings.Manager, peersUpdateManager *PeersUpdateManager, secretsManager SecretsManager, appMetrics telemetry.AppMetrics, @@ -57,7 +61,7 @@ func NewServer( return nil, err } - var jwtValidator *jwtclaims.JWTValidator + var jwtValidator jwtclaims.JWTValidator if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) { jwtValidator, err = jwtclaims.NewJWTValidator( @@ -99,6 +103,7 @@ func NewServer( // peerKey -> event channel peersUpdateManager: peersUpdateManager, accountManager: accountManager, + settingsManager: settingsManager, config: config, secretsManager: secretsManager, jwtValidator: jwtValidator, @@ -483,7 +488,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p // if peer has reached this point then it has logged in loginResp := &proto.LoginResponse{ WiretrusteeConfig: toWiretrusteeConfig(s.config, nil, relayToken), - PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain()), + PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain(), false), Checks: toProtocolChecks(ctx, postureChecks), } encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp) @@ -599,20 +604,21 @@ func toWiretrusteeConfig(config *Config, turnCredentials *Token, relayToken *Tok } } -func toPeerConfig(peer *nbpeer.Peer, network *Network, dnsName string) *proto.PeerConfig { +func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, dnsResolutionOnRoutingPeerEnabled bool) *proto.PeerConfig { netmask, _ := network.Net.Mask.Size() fqdn := peer.FQDN(dnsName) return &proto.PeerConfig{ - Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network - SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled}, - Fqdn: fqdn, + Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network + SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled}, + Fqdn: fqdn, + RoutingPeerDnsResolutionEnabled: dnsResolutionOnRoutingPeerEnabled, } } -func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache) *proto.SyncResponse { +func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, dnsResolutionOnRoutingPeerEnbled bool) *proto.SyncResponse { response := &proto.SyncResponse{ WiretrusteeConfig: toWiretrusteeConfig(config, turnCredentials, relayCredentials), - PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName), + PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, dnsResolutionOnRoutingPeerEnbled), NetworkMap: &proto.NetworkMap{ Serial: networkMap.Network.CurrentSerial(), Routes: toProtocolRoutes(networkMap.Routes), @@ -661,7 +667,7 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em } // sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization -func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error { +func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer 
*nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error { var err error var turnToken *Token @@ -680,7 +686,12 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p } } - plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(), postureChecks, nil) + settings, err := s.settingsManager.GetSettings(ctx, peer.AccountID, peer.UserID) + if err != nil { + return status.Errorf(codes.Internal, "error handling request") + } + + plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(), postureChecks, nil, settings.RoutingPeerDNSResolutionEnabled) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) if err != nil { diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 9b4592ccf..f53092415 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -84,6 +84,10 @@ components: items: type: string example: Administrators + routing_peer_dns_resolution_enabled: + description: Enables or disables DNS resolution on the routing peers + type: boolean + example: true extra: $ref: '#/components/schemas/AccountExtraSettings' required: @@ -439,17 +443,13 @@ components: example: 5 required: - accessible_peers_count - SetupKey: + SetupKeyBase: type: object properties: id: description: Setup Key ID type: string example: 2531583362 - key: - description: Setup Key value - type: string - example: A616097E-FCF0-48FA-9354-CA4A61142761 name: description: Setup key name identifier type: string @@ -518,22 +518,31 @@ components: - updated_at - usage_limit - ephemeral + SetupKeyClear: + allOf: + - $ref: '#/components/schemas/SetupKeyBase' + - type: object + properties: + key: + description: Setup Key as plain text + type: string + example: A616097E-FCF0-48FA-9354-CA4A61142761 + required: + - key + SetupKey: + allOf: + - $ref: '#/components/schemas/SetupKeyBase' + - type: object + properties: + key: + description: Setup Key as secret + type: string + example: A6160**** + required: + - key SetupKeyRequest: type: object properties: - name: - description: Setup Key name - type: string - example: Default key - type: - description: Setup key type, one-off for single time usage and reusable - type: string - example: reusable - expires_in: - description: Expiration time in seconds, 0 will mean the key never expires - type: integer - minimum: 0 - example: 86400 revoked: description: Setup key revocation status type: boolean @@ -544,21 +553,9 @@ components: items: type: string example: "ch8i4ug6lnn4g9hqv7m0" - usage_limit: - description: A number of times this key can be used. The value of 0 indicates the unlimited usage. 
- type: integer - example: 0 - ephemeral: - description: Indicate that the peer will be ephemeral or not - type: boolean - example: true required: - - name - - type - - expires_in - revoked - auto_groups - - usage_limit CreateSetupKeyRequest: type: object properties: @@ -675,6 +672,10 @@ components: description: Count of peers associated to the group type: integer example: 2 + resources_count: + description: Count of resources associated to the group + type: integer + example: 5 issued: description: How the group was issued (api, integration, jwt) type: string @@ -684,6 +685,7 @@ components: - id - name - peers_count + - resources_count GroupRequest: type: object properties: @@ -697,6 +699,10 @@ components: items: type: string example: "ch8i4ug6lnn4g9hqv7m1" + resources: + type: array + items: + $ref: '#/components/schemas/Resource' required: - name Group: @@ -709,15 +715,16 @@ components: type: array items: $ref: '#/components/schemas/PeerMinimum' + resources: + type: array + items: + $ref: '#/components/schemas/Resource' required: - peers + - resources PolicyRuleMinimum: type: object properties: - id: - description: Policy rule ID - type: string - example: ch8i4ug6lnn4g9hqv7mg name: description: Policy rule name identifier type: string @@ -783,46 +790,80 @@ components: - $ref: '#/components/schemas/PolicyRuleMinimum' - type: object properties: + id: + description: Policy rule ID + type: string + example: ch8i4ug6lnn4g9hqv7mg sources: description: Policy rule source group IDs type: array items: type: string example: "ch8i4ug6lnn4g9hqv797" + sourceResource: + description: Policy rule source resource that the rule is applied to + $ref: '#/components/schemas/Resource' destinations: description: Policy rule destination group IDs type: array items: type: string example: "ch8i4ug6lnn4g9h7v7m0" - required: - - sources - - destinations - PolicyRule: + destinationResource: + description: Policy rule destination resource that the rule is applied to + $ref: '#/components/schemas/Resource' + + PolicyRuleCreate: allOf: - $ref: '#/components/schemas/PolicyRuleMinimum' - type: object properties: + sources: + description: Policy rule source group IDs + type: array + items: + type: string + example: "ch8i4ug6lnn4g9hqv797" + sourceResource: + description: Policy rule source resource that the rule is applied to + $ref: '#/components/schemas/Resource' + destinations: + description: Policy rule destination group IDs + type: array + items: + type: string + example: "ch8i4ug6lnn4g9h7v7m0" + destinationResource: + description: Policy rule destination resource that the rule is applied to + $ref: '#/components/schemas/Resource' + PolicyRule: + allOf: + - $ref: '#/components/schemas/PolicyRuleMinimum' + - type: object + properties: + id: + description: Policy rule ID + type: string + example: ch8i4ug6lnn4g9hqv7mg sources: description: Policy rule source group IDs type: array items: $ref: '#/components/schemas/GroupMinimum' + sourceResource: + description: Policy rule source resource that the rule is applied to + $ref: '#/components/schemas/Resource' destinations: description: Policy rule destination group IDs type: array items: $ref: '#/components/schemas/GroupMinimum' - required: - - sources - - destinations + destinationResource: + description: Policy rule destination resource that the rule is applied to + $ref: '#/components/schemas/Resource' PolicyMinimum: type: object properties: - id: - description: Policy ID - type: string - example: ch8i4ug6lnn4g9hqv7mg name: description: Policy name identifier type: string 
@@ -837,7 +878,6 @@ components: example: true required: - name - - description - enabled PolicyUpdate: allOf: @@ -857,11 +897,33 @@ components: $ref: '#/components/schemas/PolicyRuleUpdate' required: - rules + PolicyCreate: + allOf: + - $ref: '#/components/schemas/PolicyMinimum' + - type: object + properties: + source_posture_checks: + description: Posture checks ID's applied to policy source groups + type: array + items: + type: string + example: "chacdk86lnnboviihd70" + rules: + description: Policy rule object for policy UI editor + type: array + items: + $ref: '#/components/schemas/PolicyRuleUpdate' + required: + - rules Policy: allOf: - $ref: '#/components/schemas/PolicyMinimum' - type: object properties: + id: + description: Policy ID + type: string + example: ch8i4ug6lnn4g9hqv7mg source_posture_checks: description: Posture checks ID's applied to policy source groups type: array @@ -1183,6 +1245,181 @@ components: - id - network_type - $ref: '#/components/schemas/RouteRequest' + Resource: + type: object + properties: + id: + description: ID of the resource + type: string + example: chacdk86lnnboviihd7g + type: + description: Type of the resource + $ref: '#/components/schemas/ResourceType' + required: + - id + - type + ResourceType: + allOf: + - $ref: '#/components/schemas/NetworkResourceType' + - type: string + example: host + NetworkRequest: + type: object + properties: + name: + description: Network name + type: string + example: Remote Network 1 + description: + description: Network description + type: string + example: A remote network that needs to be accessed + required: + - name + Network: + allOf: + - type: object + properties: + id: + description: Network ID + type: string + example: chacdk86lnnboviihd7g + routers: + description: List of router IDs associated with the network + type: array + items: + type: string + example: ch8i4ug6lnn4g9hqv7m0 + routing_peers_count: + description: Count of routing peers associated with the network + type: integer + example: 2 + resources: + description: List of network resource IDs associated with the network + type: array + items: + type: string + example: ch8i4ug6lnn4g9hqv7m1 + policies: + description: List of policy IDs associated with the network + type: array + items: + type: string + example: ch8i4ug6lnn4g9hqv7m2 + required: + - id + - routers + - resources + - routing_peers_count + - policies + - $ref: '#/components/schemas/NetworkRequest' + NetworkResourceMinimum: + type: object + properties: + name: + description: Network resource name + type: string + example: Remote Resource 1 + description: + description: Network resource description + type: string + example: A remote resource inside network 1 + address: + description: Network resource address (either a direct host like 1.1.1.1 or 1.1.1.1/32, or a subnet like 192.168.178.0/24, or domains like example.com and *.example.com) + type: string + example: "1.1.1.1" + enabled: + description: Network resource status + type: boolean + example: true + required: + - name + - address + - enabled + NetworkResourceRequest: + allOf: + - $ref: '#/components/schemas/NetworkResourceMinimum' + - type: object + properties: + groups: + description: Group IDs containing the resource + type: array + items: + type: string + example: "chacdk86lnnboviihd70" + required: + - groups + - address + NetworkResource: + allOf: + - type: object + properties: + id: + description: Network Resource ID + type: string + example: chacdk86lnnboviihd7g + type: + $ref: '#/components/schemas/NetworkResourceType' + groups: + 
description: Groups that the resource belongs to + type: array + items: + $ref: '#/components/schemas/GroupMinimum' + required: + - id + - type + - groups + - $ref: '#/components/schemas/NetworkResourceMinimum' + NetworkResourceType: + description: Network resource type based on the address + type: string + enum: [ "host", "subnet", "domain" ] + example: host + NetworkRouterRequest: + type: object + properties: + peer: + description: Peer Identifier associated with route. This property can not be set together with `peer_groups` + type: string + example: chacbco6lnnbn6cg5s91 + peer_groups: + description: Peers Group Identifier associated with route. This property can not be set together with `peer` + type: array + items: + type: string + example: chacbco6lnnbn6cg5s91 + metric: + description: Route metric number. Lowest number has higher priority + type: integer + maximum: 9999 + minimum: 1 + example: 9999 + masquerade: + description: Indicate if peer should masquerade traffic to this route's prefix + type: boolean + example: true + enabled: + description: Network router status + type: boolean + example: true + required: + # Only one property has to be set + #- peer + #- peer_groups + - metric + - masquerade + - enabled + NetworkRouter: + allOf: + - type: object + properties: + id: + description: Network Router Id + type: string + example: chacdk86lnnboviihd7g + required: + - id + - $ref: '#/components/schemas/NetworkRouterRequest' Nameserver: type: object properties: @@ -1943,7 +2180,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/SetupKey' + $ref: '#/components/schemas/SetupKeyClear' '400': "$ref": "#/components/responses/bad_request" '401': @@ -2281,7 +2518,7 @@ paths: content: 'application/json': schema: - $ref: '#/components/schemas/PolicyUpdate' + $ref: '#/components/schemas/PolicyCreate' responses: '200': description: A Policy object @@ -2467,6 +2704,502 @@ paths: "$ref": "#/components/responses/forbidden" '500': "$ref": "#/components/responses/internal_error" + /api/networks: + get: + summary: List all Networks + description: Returns a list of all networks + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + responses: + '200': + description: A JSON Array of Networks + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/Network' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + post: + summary: Create a Network + description: Creates a Network + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + requestBody: + description: New Network request + content: + 'application/json': + schema: + $ref: '#/components/schemas/NetworkRequest' + responses: + '200': + description: A Network Object + content: + application/json: + schema: + $ref: '#/components/schemas/Network' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/networks/{networkId}: + get: + summary: Retrieve a Network + description: Get information about a Network + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique
identifier of a network + responses: + '200': + description: A Network object + content: + application/json: + schema: + $ref: '#/components/schemas/Network' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + put: + summary: Update a Network + description: Update/Replace a Network + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + requestBody: + description: Update Network request + content: + application/json: + schema: + $ref: '#/components/schemas/NetworkRequest' + responses: + '200': + description: A Network object + content: + application/json: + schema: + $ref: '#/components/schemas/Network' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + delete: + summary: Delete a Network + description: Delete a network + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + responses: + '200': + description: Delete status code + content: { } + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/networks/{networkId}/resources: + get: + summary: List all Network Resources + description: Returns a list of all resources in a network + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + responses: + '200': + description: A JSON Array of Resources + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/NetworkResource' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + post: + summary: Create a Network Resource + description: Creates a Network Resource + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + requestBody: + description: New Network Resource request + content: + 'application/json': + schema: + $ref: '#/components/schemas/NetworkResourceRequest' + responses: + '200': + description: A Network Resource Object + content: + application/json: + schema: + $ref: '#/components/schemas/NetworkResource' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/networks/{networkId}/resources/{resourceId}: + get: + summary: Retrieve a Network Resource + description: Get information about a Network Resource + tags: [ Networks 
] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + - in: path + name: resourceId + required: true + schema: + type: string + description: The unique identifier of a network resource + responses: + '200': + description: A Network Resource object + content: + application/json: + schema: + $ref: '#/components/schemas/NetworkResource' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + put: + summary: Update a Network Resource + description: Update a Network Resource + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + - in: path + name: resourceId + required: true + schema: + type: string + description: The unique identifier of a resource + requestBody: + description: Update Network Resource request + content: + 'application/json': + schema: + $ref: '#/components/schemas/NetworkResourceRequest' + responses: + '200': + description: A Network Resource object + content: + application/json: + schema: + $ref: '#/components/schemas/NetworkResource' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + delete: + summary: Delete a Network Resource + description: Delete a network resource + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + - in: path + name: resourceId + required: true + schema: + type: string + description: The unique identifier of a network resource + responses: + '200': + description: Delete status code + content: { } + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/networks/{networkId}/routers: + get: + summary: List all Network Routers + description: Returns a list of all routers in a network + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + responses: + '200': + description: A JSON Array of Routers + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/NetworkRouter' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + post: + summary: Create a Network Router + description: Creates a Network Router + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + requestBody: + description: New Network Router request + content: + 'application/json': + 
schema: + $ref: '#/components/schemas/NetworkRouterRequest' + responses: + '200': + description: A Router Object + content: + application/json: + schema: + $ref: '#/components/schemas/NetworkRouter' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/networks/{networkId}/routers/{routerId}: + get: + summary: Retrieve a Network Router + description: Get information about a Network Router + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + - in: path + name: routerId + required: true + schema: + type: string + description: The unique identifier of a router + responses: + '200': + description: A Router object + content: + application/json: + schema: + $ref: '#/components/schemas/NetworkRouter' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + put: + summary: Update a Network Router + description: Update a Network Router + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + - in: path + name: routerId + required: true + schema: + type: string + description: The unique identifier of a router + requestBody: + description: Update Network Router request + content: + 'application/json': + schema: + $ref: '#/components/schemas/NetworkRouterRequest' + responses: + '200': + description: A Router object + content: + application/json: + schema: + $ref: '#/components/schemas/NetworkRouter' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + delete: + summary: Delete a Network Router + description: Delete a network router + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + - in: path + name: routerId + required: true + schema: + type: string + description: The unique identifier of a router + responses: + '200': + description: Delete status code + content: { } + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" /api/dns/nameservers: get: summary: List all Nameserver Groups diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index c1ef1ba21..943d1b327 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -88,6 +88,13 @@ const ( NameserverNsTypeUdp NameserverNsType = "udp" ) +// Defines values for NetworkResourceType. 
+const ( + NetworkResourceTypeDomain NetworkResourceType = "domain" + NetworkResourceTypeHost NetworkResourceType = "host" + NetworkResourceTypeSubnet NetworkResourceType = "subnet" +) + // Defines values for PeerNetworkRangeCheckAction. const ( PeerNetworkRangeCheckActionAllow PeerNetworkRangeCheckAction = "allow" @@ -136,6 +143,13 @@ const ( PolicyRuleUpdateProtocolUdp PolicyRuleUpdateProtocol = "udp" ) +// Defines values for ResourceType. +const ( + ResourceTypeDomain ResourceType = "domain" + ResourceTypeHost ResourceType = "host" + ResourceTypeSubnet ResourceType = "subnet" +) + // Defines values for UserStatus. const ( UserStatusActive UserStatus = "active" @@ -234,6 +248,9 @@ type AccountSettings struct { // RegularUsersViewBlocked Allows blocking regular users from viewing parts of the system. RegularUsersViewBlocked bool `json:"regular_users_view_blocked"` + + // RoutingPeerDnsResolutionEnabled Enables or disables DNS resolution on the routing peers + RoutingPeerDnsResolutionEnabled *bool `json:"routing_peer_dns_resolution_enabled,omitempty"` } // Checks List of objects that perform the actual checks @@ -365,7 +382,11 @@ type Group struct { Peers []PeerMinimum `json:"peers"` // PeersCount Count of peers associated to the group - PeersCount int `json:"peers_count"` + PeersCount int `json:"peers_count"` + Resources []Resource `json:"resources"` + + // ResourcesCount Count of resources associated to the group + ResourcesCount int `json:"resources_count"` } // GroupIssued How the group was issued (api, integration, jwt) @@ -384,6 +405,9 @@ type GroupMinimum struct { // PeersCount Count of peers associated to the group PeersCount int `json:"peers_count"` + + // ResourcesCount Count of resources associated to the group + ResourcesCount int `json:"resources_count"` } // GroupMinimumIssued How the group was issued (api, integration, jwt) @@ -395,7 +419,8 @@ type GroupRequest struct { Name string `json:"name"` // Peers List of peers ids - Peers *[]string `json:"peers,omitempty"` + Peers *[]string `json:"peers,omitempty"` + Resources *[]Resource `json:"resources,omitempty"` } // Location Describe geographical location information @@ -494,6 +519,138 @@ type NameserverGroupRequest struct { SearchDomainsEnabled bool `json:"search_domains_enabled"` } +// Network defines model for Network. +type Network struct { + // Description Network description + Description *string `json:"description,omitempty"` + + // Id Network ID + Id string `json:"id"` + + // Name Network name + Name string `json:"name"` + + // Policies List of policy IDs associated with the network + Policies []string `json:"policies"` + + // Resources List of network resource IDs associated with the network + Resources []string `json:"resources"` + + // Routers List of router IDs associated with the network + Routers []string `json:"routers"` + + // RoutingPeersCount Count of routing peers associated with the network + RoutingPeersCount int `json:"routing_peers_count"` +} + +// NetworkRequest defines model for NetworkRequest. +type NetworkRequest struct { + // Description Network description + Description *string `json:"description,omitempty"` + + // Name Network name + Name string `json:"name"` +} + +// NetworkResource defines model for NetworkResource. 
+type NetworkResource struct { + // Address Network resource address (either a direct host like 1.1.1.1 or 1.1.1.1/32, or a subnet like 192.168.178.0/24, or domains like example.com and *.example.com) + Address string `json:"address"` + + // Description Network resource description + Description *string `json:"description,omitempty"` + + // Enabled Network resource status + Enabled bool `json:"enabled"` + + // Groups Groups that the resource belongs to + Groups []GroupMinimum `json:"groups"` + + // Id Network Resource ID + Id string `json:"id"` + + // Name Network resource name + Name string `json:"name"` + + // Type Network resource type based on the address + Type NetworkResourceType `json:"type"` +} + +// NetworkResourceMinimum defines model for NetworkResourceMinimum. +type NetworkResourceMinimum struct { + // Address Network resource address (either a direct host like 1.1.1.1 or 1.1.1.1/32, or a subnet like 192.168.178.0/24, or domains like example.com and *.example.com) + Address string `json:"address"` + + // Description Network resource description + Description *string `json:"description,omitempty"` + + // Enabled Network resource status + Enabled bool `json:"enabled"` + + // Name Network resource name + Name string `json:"name"` +} + +// NetworkResourceRequest defines model for NetworkResourceRequest. +type NetworkResourceRequest struct { + // Address Network resource address (either a direct host like 1.1.1.1 or 1.1.1.1/32, or a subnet like 192.168.178.0/24, or domains like example.com and *.example.com) + Address string `json:"address"` + + // Description Network resource description + Description *string `json:"description,omitempty"` + + // Enabled Network resource status + Enabled bool `json:"enabled"` + + // Groups Group IDs containing the resource + Groups []string `json:"groups"` + + // Name Network resource name + Name string `json:"name"` +} + +// NetworkResourceType Network resource type based on the address +type NetworkResourceType string + +// NetworkRouter defines model for NetworkRouter. +type NetworkRouter struct { + // Enabled Network router status + Enabled bool `json:"enabled"` + + // Id Network Router Id + Id string `json:"id"` + + // Masquerade Indicate if peer should masquerade traffic to this route's prefix + Masquerade bool `json:"masquerade"` + + // Metric Route metric number. Lowest number has higher priority + Metric int `json:"metric"` + + // Peer Peer Identifier associated with route. This property can not be set together with `peer_groups` + Peer *string `json:"peer,omitempty"` + + // PeerGroups Peers Group Identifier associated with route. This property can not be set together with `peer` + PeerGroups *[]string `json:"peer_groups,omitempty"` +} + +// NetworkRouterRequest defines model for NetworkRouterRequest. +type NetworkRouterRequest struct { + // Enabled Network router status + Enabled bool `json:"enabled"` + + // Masquerade Indicate if peer should masquerade traffic to this route's prefix + Masquerade bool `json:"masquerade"` + + // Metric Route metric number. Lowest number has higher priority + Metric int `json:"metric"` + + // Peer Peer Identifier associated with route. This property can not be set together with `peer_groups` + Peer *string `json:"peer,omitempty"` + + // PeerGroups Peers Group Identifier associated with route.
This property can not be set together with `peer` + PeerGroups *[]string `json:"peer_groups,omitempty"` +} + // OSVersionCheck Posture check for the version of operating system type OSVersionCheck struct { // Android Posture check for the version of operating system @@ -737,7 +894,7 @@ type PersonalAccessTokenRequest struct { // Policy defines model for Policy. type Policy struct { // Description Policy friendly description - Description string `json:"description"` + Description *string `json:"description,omitempty"` // Enabled Policy status Enabled bool `json:"enabled"` @@ -755,16 +912,31 @@ type Policy struct { SourcePostureChecks []string `json:"source_posture_checks"` } -// PolicyMinimum defines model for PolicyMinimum. -type PolicyMinimum struct { +// PolicyCreate defines model for PolicyCreate. +type PolicyCreate struct { // Description Policy friendly description - Description string `json:"description"` + Description *string `json:"description,omitempty"` // Enabled Policy status Enabled bool `json:"enabled"` - // Id Policy ID - Id *string `json:"id,omitempty"` + // Name Policy name identifier + Name string `json:"name"` + + // Rules Policy rule object for policy UI editor + Rules []PolicyRuleUpdate `json:"rules"` + + // SourcePostureChecks Posture checks ID's applied to policy source groups + SourcePostureChecks *[]string `json:"source_posture_checks,omitempty"` +} + +// PolicyMinimum defines model for PolicyMinimum. +type PolicyMinimum struct { + // Description Policy friendly description + Description *string `json:"description,omitempty"` + + // Enabled Policy status + Enabled bool `json:"enabled"` // Name Policy name identifier Name string `json:"name"` @@ -779,10 +951,11 @@ type PolicyRule struct { Bidirectional bool `json:"bidirectional"` // Description Policy rule friendly description - Description *string `json:"description,omitempty"` + Description *string `json:"description,omitempty"` + DestinationResource *Resource `json:"destinationResource,omitempty"` // Destinations Policy rule destination group IDs - Destinations []GroupMinimum `json:"destinations"` + Destinations *[]GroupMinimum `json:"destinations,omitempty"` // Enabled Policy rule status Enabled bool `json:"enabled"` @@ -800,10 +973,11 @@ type PolicyRule struct { Ports *[]string `json:"ports,omitempty"` // Protocol Policy rule type of the traffic - Protocol PolicyRuleProtocol `json:"protocol"` + Protocol PolicyRuleProtocol `json:"protocol"` + SourceResource *Resource `json:"sourceResource,omitempty"` // Sources Policy rule source group IDs - Sources []GroupMinimum `json:"sources"` + Sources *[]GroupMinimum `json:"sources,omitempty"` } // PolicyRuleAction Policy rule accept or drops packets @@ -826,9 +1000,6 @@ type PolicyRuleMinimum struct { // Enabled Policy rule status Enabled bool `json:"enabled"` - // Id Policy rule ID - Id *string `json:"id,omitempty"` - // Name Policy rule name identifier Name string `json:"name"` @@ -857,10 +1028,11 @@ type PolicyRuleUpdate struct { Bidirectional bool `json:"bidirectional"` // Description Policy rule friendly description - Description *string `json:"description,omitempty"` + Description *string `json:"description,omitempty"` + DestinationResource *Resource `json:"destinationResource,omitempty"` // Destinations Policy rule destination group IDs - Destinations []string `json:"destinations"` + Destinations *[]string `json:"destinations,omitempty"` // Enabled Policy rule status Enabled bool `json:"enabled"` @@ -878,10 +1050,11 @@ type PolicyRuleUpdate struct { Ports *[]string 
`json:"ports,omitempty"` // Protocol Policy rule type of the traffic - Protocol PolicyRuleUpdateProtocol `json:"protocol"` + Protocol PolicyRuleUpdateProtocol `json:"protocol"` + SourceResource *Resource `json:"sourceResource,omitempty"` // Sources Policy rule source group IDs - Sources []string `json:"sources"` + Sources *[]string `json:"sources,omitempty"` } // PolicyRuleUpdateAction Policy rule accept or drops packets @@ -893,14 +1066,11 @@ type PolicyRuleUpdateProtocol string // PolicyUpdate defines model for PolicyUpdate. type PolicyUpdate struct { // Description Policy friendly description - Description string `json:"description"` + Description *string `json:"description,omitempty"` // Enabled Policy status Enabled bool `json:"enabled"` - // Id Policy ID - Id *string `json:"id,omitempty"` - // Name Policy name identifier Name string `json:"name"` @@ -955,6 +1125,16 @@ type ProcessCheck struct { Processes []Process `json:"processes"` } +// Resource defines model for Resource. +type Resource struct { + // Id ID of the resource + Id string `json:"id"` + Type ResourceType `json:"type"` +} + +// ResourceType defines model for ResourceType. +type ResourceType string + // Route defines model for Route. type Route struct { // AccessControlGroups Access control group identifier associated with route. @@ -1062,7 +1242,94 @@ type SetupKey struct { // Id Setup Key ID Id string `json:"id"` - // Key Setup Key value + // Key Setup Key as secret + Key string `json:"key"` + + // LastUsed Setup key last usage date + LastUsed time.Time `json:"last_used"` + + // Name Setup key name identifier + Name string `json:"name"` + + // Revoked Setup key revocation status + Revoked bool `json:"revoked"` + + // State Setup key status, "valid", "overused","expired" or "revoked" + State string `json:"state"` + + // Type Setup key type, one-off for single time usage and reusable + Type string `json:"type"` + + // UpdatedAt Setup key last update date + UpdatedAt time.Time `json:"updated_at"` + + // UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage. + UsageLimit int `json:"usage_limit"` + + // UsedTimes Usage count of setup key + UsedTimes int `json:"used_times"` + + // Valid Setup key validity status + Valid bool `json:"valid"` +} + +// SetupKeyBase defines model for SetupKeyBase. +type SetupKeyBase struct { + // AutoGroups List of group IDs to auto-assign to peers registered with this key + AutoGroups []string `json:"auto_groups"` + + // Ephemeral Indicate that the peer will be ephemeral or not + Ephemeral bool `json:"ephemeral"` + + // Expires Setup Key expiration date + Expires time.Time `json:"expires"` + + // Id Setup Key ID + Id string `json:"id"` + + // LastUsed Setup key last usage date + LastUsed time.Time `json:"last_used"` + + // Name Setup key name identifier + Name string `json:"name"` + + // Revoked Setup key revocation status + Revoked bool `json:"revoked"` + + // State Setup key status, "valid", "overused","expired" or "revoked" + State string `json:"state"` + + // Type Setup key type, one-off for single time usage and reusable + Type string `json:"type"` + + // UpdatedAt Setup key last update date + UpdatedAt time.Time `json:"updated_at"` + + // UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage. 
+ UsageLimit int `json:"usage_limit"` + + // UsedTimes Usage count of setup key + UsedTimes int `json:"used_times"` + + // Valid Setup key validity status + Valid bool `json:"valid"` +} + +// SetupKeyClear defines model for SetupKeyClear. +type SetupKeyClear struct { + // AutoGroups List of group IDs to auto-assign to peers registered with this key + AutoGroups []string `json:"auto_groups"` + + // Ephemeral Indicate that the peer will be ephemeral or not + Ephemeral bool `json:"ephemeral"` + + // Expires Setup Key expiration date + Expires time.Time `json:"expires"` + + // Id Setup Key ID + Id string `json:"id"` + + // Key Setup Key as plain text Key string `json:"key"` // LastUsed Setup key last usage date @@ -1098,23 +1365,8 @@ type SetupKeyRequest struct { // AutoGroups List of group IDs to auto-assign to peers registered with this key AutoGroups []string `json:"auto_groups"` - // Ephemeral Indicate that the peer will be ephemeral or not - Ephemeral *bool `json:"ephemeral,omitempty"` - - // ExpiresIn Expiration time in seconds, 0 will mean the key never expires - ExpiresIn int `json:"expires_in"` - - // Name Setup Key name - Name string `json:"name"` - // Revoked Setup key revocation status Revoked bool `json:"revoked"` - - // Type Setup key type, one-off for single time usage and reusable - Type string `json:"type"` - - // UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage. - UsageLimit int `json:"usage_limit"` } // User defines model for User. @@ -1220,6 +1472,24 @@ type PostApiGroupsJSONRequestBody = GroupRequest // PutApiGroupsGroupIdJSONRequestBody defines body for PutApiGroupsGroupId for application/json ContentType. type PutApiGroupsGroupIdJSONRequestBody = GroupRequest +// PostApiNetworksJSONRequestBody defines body for PostApiNetworks for application/json ContentType. +type PostApiNetworksJSONRequestBody = NetworkRequest + +// PutApiNetworksNetworkIdJSONRequestBody defines body for PutApiNetworksNetworkId for application/json ContentType. +type PutApiNetworksNetworkIdJSONRequestBody = NetworkRequest + +// PostApiNetworksNetworkIdResourcesJSONRequestBody defines body for PostApiNetworksNetworkIdResources for application/json ContentType. +type PostApiNetworksNetworkIdResourcesJSONRequestBody = NetworkResourceRequest + +// PutApiNetworksNetworkIdResourcesResourceIdJSONRequestBody defines body for PutApiNetworksNetworkIdResourcesResourceId for application/json ContentType. +type PutApiNetworksNetworkIdResourcesResourceIdJSONRequestBody = NetworkResourceRequest + +// PostApiNetworksNetworkIdRoutersJSONRequestBody defines body for PostApiNetworksNetworkIdRouters for application/json ContentType. +type PostApiNetworksNetworkIdRoutersJSONRequestBody = NetworkRouterRequest + +// PutApiNetworksNetworkIdRoutersRouterIdJSONRequestBody defines body for PutApiNetworksNetworkIdRoutersRouterId for application/json ContentType. +type PutApiNetworksNetworkIdRoutersRouterIdJSONRequestBody = NetworkRouterRequest + // PutApiPeersPeerIdJSONRequestBody defines body for PutApiPeersPeerId for application/json ContentType. type PutApiPeersPeerIdJSONRequestBody = PeerRequest @@ -1227,7 +1497,7 @@ type PutApiPeersPeerIdJSONRequestBody = PeerRequest type PostApiPoliciesJSONRequestBody = PolicyUpdate // PutApiPoliciesPolicyIdJSONRequestBody defines body for PutApiPoliciesPolicyId for application/json ContentType. 
-type PutApiPoliciesPolicyIdJSONRequestBody = PolicyUpdate +type PutApiPoliciesPolicyIdJSONRequestBody = PolicyCreate // PostApiPostureChecksJSONRequestBody defines body for PostApiPostureChecks for application/json ContentType. type PostApiPostureChecksJSONRequestBody = PostureCheckUpdate diff --git a/management/server/http/configs/auth.go b/management/server/http/configs/auth.go new file mode 100644 index 000000000..aa91fa55b --- /dev/null +++ b/management/server/http/configs/auth.go @@ -0,0 +1,9 @@ +package configs + +// AuthCfg contains parameters for authentication middleware +type AuthCfg struct { + Issuer string + Audience string + UserIDClaim string + KeysLocation string +} diff --git a/management/server/http/handler.go b/management/server/http/handler.go index bb6d00209..1ddf10a6c 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -12,35 +12,31 @@ import ( s "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" + nbgroups "github.com/netbirdio/netbird/management/server/groups" + "github.com/netbirdio/netbird/management/server/http/configs" + "github.com/netbirdio/netbird/management/server/http/handlers/accounts" + "github.com/netbirdio/netbird/management/server/http/handlers/dns" + "github.com/netbirdio/netbird/management/server/http/handlers/events" + "github.com/netbirdio/netbird/management/server/http/handlers/groups" + "github.com/netbirdio/netbird/management/server/http/handlers/networks" + "github.com/netbirdio/netbird/management/server/http/handlers/peers" + "github.com/netbirdio/netbird/management/server/http/handlers/policies" + "github.com/netbirdio/netbird/management/server/http/handlers/routes" + "github.com/netbirdio/netbird/management/server/http/handlers/setup_keys" + "github.com/netbirdio/netbird/management/server/http/handlers/users" "github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/integrated_validator" "github.com/netbirdio/netbird/management/server/jwtclaims" + nbnetworks "github.com/netbirdio/netbird/management/server/networks" + "github.com/netbirdio/netbird/management/server/networks/resources" + "github.com/netbirdio/netbird/management/server/networks/routers" "github.com/netbirdio/netbird/management/server/telemetry" ) const apiPrefix = "/api" -// AuthCfg contains parameters for authentication middleware -type AuthCfg struct { - Issuer string - Audience string - UserIDClaim string - KeysLocation string -} - -type apiHandler struct { - Router *mux.Router - AccountManager s.AccountManager - geolocationManager *geolocation.Geolocation - AuthCfg AuthCfg -} - -// EmptyObject is an empty struct used to return empty JSON object -type emptyObject struct { -} - -// APIHandler creates the Management service HTTP API handler registering all the available endpoints. -func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) { +// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints. 
+func NewAPIHandler(ctx context.Context, accountManager s.AccountManager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg configs.AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) { claimsExtractor := jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), @@ -75,133 +71,20 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa router := rootRouter.PathPrefix(prefix).Subrouter() router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler, acMiddleware.Handler) - api := apiHandler{ - Router: router, - AccountManager: accountManager, - geolocationManager: LocationManager, - AuthCfg: authCfg, - } - - if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator, appMetrics.GetMeter()); err != nil { + if _, err := integrations.RegisterHandlers(ctx, prefix, router, accountManager, claimsExtractor, integratedValidator, appMetrics.GetMeter()); err != nil { return nil, fmt.Errorf("register integrations endpoints: %w", err) } - api.addAccountsEndpoint() - api.addPeersEndpoint() - api.addUsersEndpoint() - api.addUsersTokensEndpoint() - api.addSetupKeysEndpoint() - api.addPoliciesEndpoint() - api.addGroupsEndpoint() - api.addRoutesEndpoint() - api.addDNSNameserversEndpoint() - api.addDNSSettingEndpoint() - api.addEventsEndpoint() - api.addPostureCheckEndpoint() - api.addLocationsEndpoint() + accounts.AddEndpoints(accountManager, authCfg, router) + peers.AddEndpoints(accountManager, authCfg, router) + users.AddEndpoints(accountManager, authCfg, router) + setup_keys.AddEndpoints(accountManager, authCfg, router) + policies.AddEndpoints(accountManager, LocationManager, authCfg, router) + groups.AddEndpoints(accountManager, authCfg, router) + routes.AddEndpoints(accountManager, authCfg, router) + dns.AddEndpoints(accountManager, authCfg, router) + events.AddEndpoints(accountManager, authCfg, router) + networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, accountManager.GetAccountIDFromToken, authCfg, router) return rootRouter, nil } - -func (apiHandler *apiHandler) addAccountsEndpoint() { - accountsHandler := NewAccountsHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/accounts/{accountId}", accountsHandler.UpdateAccount).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/accounts/{accountId}", accountsHandler.DeleteAccount).Methods("DELETE", "OPTIONS") - apiHandler.Router.HandleFunc("/accounts", accountsHandler.GetAllAccounts).Methods("GET", "OPTIONS") -} - -func (apiHandler *apiHandler) addPeersEndpoint() { - peersHandler := NewPeersHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer). 
- Methods("GET", "PUT", "DELETE", "OPTIONS") - apiHandler.Router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS") -} - -func (apiHandler *apiHandler) addUsersEndpoint() { - userHandler := NewUsersHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/users", userHandler.GetAllUsers).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/users/{userId}", userHandler.UpdateUser).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/users/{userId}", userHandler.DeleteUser).Methods("DELETE", "OPTIONS") - apiHandler.Router.HandleFunc("/users", userHandler.CreateUser).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/users/{userId}/invite", userHandler.InviteUser).Methods("POST", "OPTIONS") -} - -func (apiHandler *apiHandler) addUsersTokensEndpoint() { - tokenHandler := NewPATsHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/users/{userId}/tokens", tokenHandler.GetAllTokens).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/users/{userId}/tokens", tokenHandler.CreateToken).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.GetToken).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.DeleteToken).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addSetupKeysEndpoint() { - keysHandler := NewSetupKeysHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/setup-keys", keysHandler.GetAllSetupKeys).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/setup-keys", keysHandler.CreateSetupKey).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.GetSetupKey).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.UpdateSetupKey).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.DeleteSetupKey).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addPoliciesEndpoint() { - policiesHandler := NewPoliciesHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/policies", policiesHandler.GetAllPolicies).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/policies", policiesHandler.CreatePolicy).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/policies/{policyId}", policiesHandler.UpdatePolicy).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/policies/{policyId}", policiesHandler.GetPolicy).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/policies/{policyId}", policiesHandler.DeletePolicy).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addGroupsEndpoint() { - groupsHandler := NewGroupsHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/groups", groupsHandler.GetAllGroups).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/groups", groupsHandler.CreateGroup).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/groups/{groupId}", groupsHandler.UpdateGroup).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/groups/{groupId}", groupsHandler.GetGroup).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/groups/{groupId}", groupsHandler.DeleteGroup).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addRoutesEndpoint() { - routesHandler := NewRoutesHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - 
apiHandler.Router.HandleFunc("/routes", routesHandler.GetAllRoutes).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/routes", routesHandler.CreateRoute).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/routes/{routeId}", routesHandler.UpdateRoute).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/routes/{routeId}", routesHandler.GetRoute).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/routes/{routeId}", routesHandler.DeleteRoute).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addDNSNameserversEndpoint() { - nameserversHandler := NewNameserversHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/dns/nameservers", nameserversHandler.GetAllNameservers).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/dns/nameservers", nameserversHandler.CreateNameserverGroup).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.UpdateNameserverGroup).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.GetNameserverGroup).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.DeleteNameserverGroup).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addDNSSettingEndpoint() { - dnsSettingsHandler := NewDNSSettingsHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/dns/settings", dnsSettingsHandler.GetDNSSettings).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/dns/settings", dnsSettingsHandler.UpdateDNSSettings).Methods("PUT", "OPTIONS") -} - -func (apiHandler *apiHandler) addEventsEndpoint() { - eventsHandler := NewEventsHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/events", eventsHandler.GetAllEvents).Methods("GET", "OPTIONS") -} - -func (apiHandler *apiHandler) addPostureCheckEndpoint() { - postureCheckHandler := NewPostureChecksHandler(apiHandler.AccountManager, apiHandler.geolocationManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/posture-checks", postureCheckHandler.GetAllPostureChecks).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/posture-checks", postureCheckHandler.CreatePostureCheck).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.UpdatePostureCheck).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.GetPostureCheck).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.DeletePostureCheck).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addLocationsEndpoint() { - locationHandler := NewGeolocationsHandlerHandler(apiHandler.AccountManager, apiHandler.geolocationManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/locations/countries", locationHandler.GetAllCountries).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/locations/countries/{country}/cities", locationHandler.GetCitiesByCountry).Methods("GET", "OPTIONS") -} diff --git a/management/server/http/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go similarity index 73% rename from management/server/http/accounts_handler.go rename to management/server/http/handlers/accounts/accounts_handler.go index 4baf9c692..a23628cdc 100644 --- a/management/server/http/accounts_handler.go +++ 
b/management/server/http/handlers/accounts/accounts_handler.go @@ -1,4 +1,4 @@ -package http +package accounts import ( "encoding/json" @@ -10,20 +10,29 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) -// AccountsHandler is a handler that handles the server.Account HTTP endpoints -type AccountsHandler struct { +// handler is a handler that handles the server.Account HTTP endpoints +type handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewAccountsHandler creates a new AccountsHandler HTTP handler -func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) *AccountsHandler { - return &AccountsHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + accountsHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS") + router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS") + router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS") +} + +// newHandler creates a new handler HTTP handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -32,8 +41,8 @@ func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) * } } -// GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account. -func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) { +// getAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account. +func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -51,8 +60,8 @@ func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) util.WriteJSONObject(r.Context(), w, []*api.Account{resp}) } -// UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings) -func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) { +// updateAccount is HTTP PUT handler that updates the provided account. 
Updates only account settings (server.Settings) +func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) _, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -74,7 +83,7 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) return } - settings := &server.Settings{ + settings := &types.Settings{ PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled, PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)), RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked, @@ -99,6 +108,9 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) if req.Settings.JwtAllowGroups != nil { settings.JWTAllowGroups = *req.Settings.JwtAllowGroups } + if req.Settings.RoutingPeerDnsResolutionEnabled != nil { + settings.RoutingPeerDNSResolutionEnabled = *req.Settings.RoutingPeerDnsResolutionEnabled + } updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings) if err != nil { @@ -111,8 +123,8 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) util.WriteJSONObject(r.Context(), w, &resp) } -// DeleteAccount is a HTTP DELETE handler to delete an account -func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) { +// deleteAccount is a HTTP DELETE handler to delete an account +func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) vars := mux.Vars(r) targetAccountID := vars["accountId"] @@ -127,10 +139,10 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -func toAccountResponse(accountID string, settings *server.Settings) *api.Account { +func toAccountResponse(accountID string, settings *types.Settings) *api.Account { jwtAllowGroups := settings.JWTAllowGroups if jwtAllowGroups == nil { jwtAllowGroups = []string{} @@ -146,6 +158,7 @@ func toAccountResponse(accountID string, settings *server.Settings) *api.Account JwtGroupsClaimName: &settings.JWTGroupsClaimName, JwtAllowGroups: &jwtAllowGroups, RegularUsersViewBlocked: settings.RegularUsersViewBlocked, + RoutingPeerDnsResolutionEnabled: &settings.RoutingPeerDNSResolutionEnabled, } if settings.Extra != nil { diff --git a/management/server/http/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go similarity index 73% rename from management/server/http/accounts_handler_test.go rename to management/server/http/handlers/accounts/accounts_handler_test.go index cacb3d430..e8a599863 100644 --- a/management/server/http/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -1,4 +1,4 @@ -package http +package accounts import ( "bytes" @@ -13,23 +13,23 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) -func initAccountsTestData(account *server.Account, admin *server.User) 
*AccountsHandler { - return &AccountsHandler{ +func initAccountsTestData(account *types.Account, admin *types.User) *handler { + return &handler{ accountManager: &mock_server.MockAccountManager{ GetAccountIDFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return account.Id, admin.Id, nil }, - GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.Settings, error) { + GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) { return account.Settings, nil }, - UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) { + UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) { halfYearLimit := 180 * 24 * time.Hour if newSettings.PeerLoginExpiration > halfYearLimit { return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days") @@ -58,19 +58,19 @@ func initAccountsTestData(account *server.Account, admin *server.User) *Accounts func TestAccounts_AccountsHandler(t *testing.T) { accountID := "test_account" - adminUser := server.NewAdminUser("test_user") + adminUser := types.NewAdminUser("test_user") sr := func(v string) *string { return &v } br := func(v bool) *bool { return &v } - handler := initAccountsTestData(&server.Account{ + handler := initAccountsTestData(&types.Account{ Id: accountID, Domain: "hotmail.com", - Network: server.NewNetwork(), - Users: map[string]*server.User{ + Network: types.NewNetwork(), + Users: map[string]*types.User{ adminUser.Id: adminUser, }, - Settings: &server.Settings{ + Settings: &types.Settings{ PeerLoginExpirationEnabled: false, PeerLoginExpiration: time.Hour, RegularUsersViewBlocked: true, @@ -89,19 +89,20 @@ func TestAccounts_AccountsHandler(t *testing.T) { requestBody io.Reader }{ { - name: "GetAllAccounts OK", + name: "getAllAccounts OK", expectedBody: true, requestType: http.MethodGet, requestPath: "/api/accounts", expectedStatus: http.StatusOK, expectedSettings: api.AccountSettings{ - PeerLoginExpiration: int(time.Hour.Seconds()), - PeerLoginExpirationEnabled: false, - GroupsPropagationEnabled: br(false), - JwtGroupsClaimName: sr(""), - JwtGroupsEnabled: br(false), - JwtAllowGroups: &[]string{}, - RegularUsersViewBlocked: true, + PeerLoginExpiration: int(time.Hour.Seconds()), + PeerLoginExpirationEnabled: false, + GroupsPropagationEnabled: br(false), + JwtGroupsClaimName: sr(""), + JwtGroupsEnabled: br(false), + JwtAllowGroups: &[]string{}, + RegularUsersViewBlocked: true, + RoutingPeerDnsResolutionEnabled: br(false), }, expectedArray: true, expectedID: accountID, @@ -114,13 +115,14 @@ func TestAccounts_AccountsHandler(t *testing.T) { requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": true}}"), expectedStatus: http.StatusOK, expectedSettings: api.AccountSettings{ - PeerLoginExpiration: 15552000, - PeerLoginExpirationEnabled: true, - GroupsPropagationEnabled: br(false), - JwtGroupsClaimName: sr(""), - JwtGroupsEnabled: br(false), - JwtAllowGroups: &[]string{}, - RegularUsersViewBlocked: false, + PeerLoginExpiration: 15552000, + PeerLoginExpirationEnabled: true, + GroupsPropagationEnabled: br(false), + JwtGroupsClaimName: sr(""), + JwtGroupsEnabled: br(false), + JwtAllowGroups: &[]string{}, + RegularUsersViewBlocked: false, + 
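The accounts endpoints are now registered by the handlers/accounts package itself through AddEndpoints, instead of by the central apiHandler methods removed above. The root router wiring is not part of this section; a minimal sketch of how such per-package registration could be invoked from a parent mux router (buildRouter is a placeholder name, not taken from the patch) might look like:

package http // assumption: the parent package that assembles the API router

import (
	"github.com/gorilla/mux"

	"github.com/netbirdio/netbird/management/server"
	"github.com/netbirdio/netbird/management/server/http/configs"
	"github.com/netbirdio/netbird/management/server/http/handlers/accounts"
)

// buildRouter is a hypothetical helper showing the registration pattern:
// each handler package receives the shared dependencies and attaches its
// own routes, so the root no longer enumerates every path and method.
func buildRouter(am server.AccountManager, authCfg configs.AuthCfg) *mux.Router {
	router := mux.NewRouter().PathPrefix("/api").Subrouter()

	// The package decides its own paths, methods and handler functions.
	accounts.AddEndpoints(am, authCfg, router)

	return router
}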
RoutingPeerDnsResolutionEnabled: br(false), }, expectedArray: false, expectedID: accountID, @@ -133,13 +135,14 @@ func TestAccounts_AccountsHandler(t *testing.T) { requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"],\"regular_users_view_blocked\":true}}"), expectedStatus: http.StatusOK, expectedSettings: api.AccountSettings{ - PeerLoginExpiration: 15552000, - PeerLoginExpirationEnabled: false, - GroupsPropagationEnabled: br(false), - JwtGroupsClaimName: sr("roles"), - JwtGroupsEnabled: br(true), - JwtAllowGroups: &[]string{"test"}, - RegularUsersViewBlocked: true, + PeerLoginExpiration: 15552000, + PeerLoginExpirationEnabled: false, + GroupsPropagationEnabled: br(false), + JwtGroupsClaimName: sr("roles"), + JwtGroupsEnabled: br(true), + JwtAllowGroups: &[]string{"test"}, + RegularUsersViewBlocked: true, + RoutingPeerDnsResolutionEnabled: br(false), }, expectedArray: false, expectedID: accountID, @@ -152,13 +155,14 @@ func TestAccounts_AccountsHandler(t *testing.T) { requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 554400,\"peer_login_expiration_enabled\": true,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"groups\",\"groups_propagation_enabled\":true,\"regular_users_view_blocked\":true}}"), expectedStatus: http.StatusOK, expectedSettings: api.AccountSettings{ - PeerLoginExpiration: 554400, - PeerLoginExpirationEnabled: true, - GroupsPropagationEnabled: br(true), - JwtGroupsClaimName: sr("groups"), - JwtGroupsEnabled: br(true), - JwtAllowGroups: &[]string{}, - RegularUsersViewBlocked: true, + PeerLoginExpiration: 554400, + PeerLoginExpirationEnabled: true, + GroupsPropagationEnabled: br(true), + JwtGroupsClaimName: sr("groups"), + JwtGroupsEnabled: br(true), + JwtAllowGroups: &[]string{}, + RegularUsersViewBlocked: true, + RoutingPeerDnsResolutionEnabled: br(false), }, expectedArray: false, expectedID: accountID, @@ -189,8 +193,8 @@ func TestAccounts_AccountsHandler(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/accounts", handler.GetAllAccounts).Methods("GET") - router.HandleFunc("/api/accounts/{accountId}", handler.UpdateAccount).Methods("PUT") + router.HandleFunc("/api/accounts", handler.getAllAccounts).Methods("GET") + router.HandleFunc("/api/accounts/{accountId}", handler.updateAccount).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/dns_settings_handler.go b/management/server/http/handlers/dns/dns_settings_handler.go similarity index 59% rename from management/server/http/dns_settings_handler.go rename to management/server/http/handlers/dns/dns_settings_handler.go index 13c2101a7..112eee179 100644 --- a/management/server/http/dns_settings_handler.go +++ b/management/server/http/handlers/dns/dns_settings_handler.go @@ -1,26 +1,40 @@ -package http +package dns import ( "encoding/json" "net/http" + "github.com/gorilla/mux" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/types" ) -// DNSSettingsHandler is a 
handler that returns the DNS settings of the account -type DNSSettingsHandler struct { +// dnsSettingsHandler is a handler that returns the DNS settings of the account +type dnsSettingsHandler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewDNSSettingsHandler returns a new instance of DNSSettingsHandler handler -func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg) *DNSSettingsHandler { - return &DNSSettingsHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + addDNSSettingEndpoint(accountManager, authCfg, router) + addDNSNameserversEndpoint(accountManager, authCfg, router) +} + +func addDNSSettingEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + dnsSettingsHandler := newDNSSettingsHandler(accountManager, authCfg) + router.HandleFunc("/dns/settings", dnsSettingsHandler.getDNSSettings).Methods("GET", "OPTIONS") + router.HandleFunc("/dns/settings", dnsSettingsHandler.updateDNSSettings).Methods("PUT", "OPTIONS") +} + +// newDNSSettingsHandler returns a new instance of dnsSettingsHandler handler +func newDNSSettingsHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *dnsSettingsHandler { + return &dnsSettingsHandler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -29,8 +43,8 @@ func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg } } -// GetDNSSettings returns the DNS settings for the account -func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) { +// getDNSSettings returns the DNS settings for the account +func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -52,8 +66,8 @@ func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Reque util.WriteJSONObject(r.Context(), w, apiDNSSettings) } -// UpdateDNSSettings handles update to DNS settings of an account -func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) { +// updateDNSSettings handles update to DNS settings of an account +func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -68,7 +82,7 @@ func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Re return } - updateDNSSettings := &server.DNSSettings{ + updateDNSSettings := &types.DNSSettings{ DisabledManagementGroups: req.DisabledManagementGroups, } diff --git a/management/server/http/dns_settings_handler_test.go b/management/server/http/handlers/dns/dns_settings_handler_test.go similarity index 86% rename from management/server/http/dns_settings_handler_test.go rename to management/server/http/handlers/dns/dns_settings_handler_test.go index 8baea7b15..9ca1dc032 100644 --- a/management/server/http/dns_settings_handler_test.go +++ b/management/server/http/handlers/dns/dns_settings_handler_test.go @@ -1,4 +1,4 @@ -package http +package dns import ( "bytes" @@ -13,10 +13,10 @@ import ( "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/status" + 
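Every handler in these packages opens with the same preamble: read the JWT claims from the request context, resolve them to an account ID and user ID via the account manager, and return a util.WriteError response on failure. A condensed, standalone sketch of that shared preamble, using a hypothetical requireAccount helper that is not part of this change:

package dns // assumption: sketched in the dns package for concreteness

import (
	"net/http"

	"github.com/netbirdio/netbird/management/server"
	"github.com/netbirdio/netbird/management/server/http/util"
	"github.com/netbirdio/netbird/management/server/jwtclaims"
)

// requireAccount condenses the preamble repeated by getDNSSettings,
// updateDNSSettings and the other handlers: claims come from the request
// context, then the account manager maps them to account and user IDs.
func requireAccount(w http.ResponseWriter, r *http.Request, extractor *jwtclaims.ClaimsExtractor, am server.AccountManager) (string, string, bool) {
	claims := extractor.FromRequestContext(r)
	accountID, userID, err := am.GetAccountIDFromToken(r.Context(), claims)
	if err != nil {
		util.WriteError(r.Context(), err, w)
		return "", "", false
	}
	return accountID, userID, true
}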
"github.com/netbirdio/netbird/management/server/types" "github.com/gorilla/mux" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" ) @@ -27,26 +27,26 @@ const ( testDNSSettingsUserID = "test_user" ) -var baseExistingDNSSettings = server.DNSSettings{ +var baseExistingDNSSettings = types.DNSSettings{ DisabledManagementGroups: []string{testDNSSettingsExistingGroup}, } -var testingDNSSettingsAccount = &server.Account{ +var testingDNSSettingsAccount = &types.Account{ Id: testDNSSettingsAccountID, Domain: "hotmail.com", - Users: map[string]*server.User{ - testDNSSettingsUserID: server.NewAdminUser("test_user"), + Users: map[string]*types.User{ + testDNSSettingsUserID: types.NewAdminUser("test_user"), }, DNSSettings: baseExistingDNSSettings, } -func initDNSSettingsTestData() *DNSSettingsHandler { - return &DNSSettingsHandler{ +func initDNSSettingsTestData() *dnsSettingsHandler { + return &dnsSettingsHandler{ accountManager: &mock_server.MockAccountManager{ - GetDNSSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.DNSSettings, error) { + GetDNSSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) { return &testingDNSSettingsAccount.DNSSettings, nil }, - SaveDNSSettingsFunc: func(ctx context.Context, accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error { + SaveDNSSettingsFunc: func(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error { if dnsSettingsToSave != nil { return nil } @@ -120,8 +120,8 @@ func TestDNSSettingsHandlers(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/dns/settings", p.GetDNSSettings).Methods("GET") - router.HandleFunc("/api/dns/settings", p.UpdateDNSSettings).Methods("PUT") + router.HandleFunc("/api/dns/settings", p.getDNSSettings).Methods("GET") + router.HandleFunc("/api/dns/settings", p.updateDNSSettings).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/nameservers_handler.go b/management/server/http/handlers/dns/nameservers_handler.go similarity index 77% rename from management/server/http/nameservers_handler.go rename to management/server/http/handlers/dns/nameservers_handler.go index e7a2bc2ae..09047e231 100644 --- a/management/server/http/nameservers_handler.go +++ b/management/server/http/handlers/dns/nameservers_handler.go @@ -1,4 +1,4 @@ -package http +package dns import ( "encoding/json" @@ -11,20 +11,30 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" ) -// NameserversHandler is the nameserver group handler of the account -type NameserversHandler struct { +// nameserversHandler is the nameserver group handler of the account +type nameserversHandler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewNameserversHandler returns a new instance of NameserversHandler handler -func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg) 
*NameserversHandler { - return &NameserversHandler{ +func addDNSNameserversEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + nameserversHandler := newNameserversHandler(accountManager, authCfg) + router.HandleFunc("/dns/nameservers", nameserversHandler.getAllNameservers).Methods("GET", "OPTIONS") + router.HandleFunc("/dns/nameservers", nameserversHandler.createNameserverGroup).Methods("POST", "OPTIONS") + router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.updateNameserverGroup).Methods("PUT", "OPTIONS") + router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.getNameserverGroup).Methods("GET", "OPTIONS") + router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.deleteNameserverGroup).Methods("DELETE", "OPTIONS") +} + +// newNameserversHandler returns a new instance of nameserversHandler handler +func newNameserversHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *nameserversHandler { + return &nameserversHandler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -33,8 +43,8 @@ func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg } } -// GetAllNameservers returns the list of nameserver groups for the account -func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) { +// getAllNameservers returns the list of nameserver groups for the account +func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -57,8 +67,8 @@ func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Re util.WriteJSONObject(r.Context(), w, apiNameservers) } -// CreateNameserverGroup handles nameserver group creation request -func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) { +// createNameserverGroup handles nameserver group creation request +func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -90,8 +100,8 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt util.WriteJSONObject(r.Context(), w, &resp) } -// UpdateNameserverGroup handles update to a nameserver group identified by a given ID -func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) { +// updateNameserverGroup handles update to a nameserver group identified by a given ID +func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -141,8 +151,8 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt util.WriteJSONObject(r.Context(), w, &resp) } -// DeleteNameserverGroup handles nameserver group deletion request -func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) { +// deleteNameserverGroup handles nameserver group deletion request +func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := 
h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -162,11 +172,11 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -// GetNameserverGroup handles a nameserver group Get request identified by ID -func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) { +// getNameserverGroup handles a nameserver group Get request identified by ID +func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { diff --git a/management/server/http/nameservers_handler_test.go b/management/server/http/handlers/dns/nameservers_handler_test.go similarity index 95% rename from management/server/http/nameservers_handler_test.go rename to management/server/http/handlers/dns/nameservers_handler_test.go index 98c2e402d..c6561e4d8 100644 --- a/management/server/http/nameservers_handler_test.go +++ b/management/server/http/handlers/dns/nameservers_handler_test.go @@ -1,4 +1,4 @@ -package http +package dns import ( "bytes" @@ -50,8 +50,8 @@ var baseExistingNSGroup = &nbdns.NameServerGroup{ Enabled: true, } -func initNameserversTestData() *NameserversHandler { - return &NameserversHandler{ +func initNameserversTestData() *nameserversHandler { + return &nameserversHandler{ accountManager: &mock_server.MockAccountManager{ GetNameServerGroupFunc: func(_ context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { if nsGroupID == existingNSGroupID { @@ -206,10 +206,10 @@ func TestNameserversHandlers(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.GetNameserverGroup).Methods("GET") - router.HandleFunc("/api/dns/nameservers", p.CreateNameserverGroup).Methods("POST") - router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.DeleteNameserverGroup).Methods("DELETE") - router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.UpdateNameserverGroup).Methods("PUT") + router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.getNameserverGroup).Methods("GET") + router.HandleFunc("/api/dns/nameservers", p.createNameserverGroup).Methods("POST") + router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.deleteNameserverGroup).Methods("DELETE") + router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.updateNameserverGroup).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/events_handler.go b/management/server/http/handlers/events/events_handler.go similarity index 79% rename from management/server/http/events_handler.go rename to management/server/http/handlers/events/events_handler.go index ee0c63f28..62da59535 100644 --- a/management/server/http/events_handler.go +++ b/management/server/http/handlers/events/events_handler.go @@ -1,28 +1,35 @@ -package http +package events import ( "context" "fmt" "net/http" + "github.com/gorilla/mux" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/http/api" + 
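Because the handler methods are now unexported, the renamed tests keep exercising them through a mux router and httptest, exactly as before the move. A minimal, self-contained sketch of that routing pattern (the anonymous handler stands in for p.getNameserverGroup and friends; it is not code from the repository):

package dns_test // hypothetical test file illustrating the pattern

import (
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/gorilla/mux"
)

func TestHandlerRoutingPattern(t *testing.T) {
	recorder := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/api/dns/nameservers/ns1", nil)

	// Routing through mux populates path variables such as {nsgroupId};
	// the real tests register the unexported handler methods here instead.
	router := mux.NewRouter()
	router.HandleFunc("/api/dns/nameservers/{nsgroupId}", func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}).Methods("GET")

	router.ServeHTTP(recorder, req)

	if recorder.Result().StatusCode != http.StatusOK {
		t.Fatalf("unexpected status: %d", recorder.Result().StatusCode)
	}
}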
"github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" ) -// EventsHandler HTTP handler -type EventsHandler struct { +// handler HTTP handler +type handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewEventsHandler creates a new EventsHandler HTTP handler -func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *EventsHandler { - return &EventsHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + eventsHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/events", eventsHandler.getAllEvents).Methods("GET", "OPTIONS") +} + +// newHandler creates a new events handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -31,8 +38,8 @@ func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ev } } -// GetAllEvents list of the given account -func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) { +// getAllEvents list of the given account +func (h *handler) getAllEvents(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -60,7 +67,7 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, events) } -func (h *EventsHandler) fillEventsWithUserInfo(ctx context.Context, events []*api.Event, accountId, userId string) error { +func (h *handler) fillEventsWithUserInfo(ctx context.Context, events []*api.Event, accountId, userId string) error { // build email, name maps based on users userInfos, err := h.accountManager.GetUsersFromAccount(ctx, accountId, userId) if err != nil { diff --git a/management/server/http/events_handler_test.go b/management/server/http/handlers/events/events_handler_test.go similarity index 95% rename from management/server/http/events_handler_test.go rename to management/server/http/handlers/events/events_handler_test.go index e525cf2ee..17478aba3 100644 --- a/management/server/http/events_handler_test.go +++ b/management/server/http/handlers/events/events_handler_test.go @@ -1,4 +1,4 @@ -package http +package events import ( "context" @@ -13,15 +13,15 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/types" ) -func initEventsTestData(account string, events ...*activity.Event) *EventsHandler { - return &EventsHandler{ +func initEventsTestData(account string, events ...*activity.Event) *handler { + return &handler{ accountManager: &mock_server.MockAccountManager{ GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) { if accountID == account { @@ -32,8 +32,8 @@ func initEventsTestData(account string, events ...*activity.Event) *EventsHandle GetAccountIDFromTokenFunc: func(_ context.Context, 
claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil }, - GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) { - return make([]*server.UserInfo, 0), nil + GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*types.UserInfo, error) { + return make([]*types.UserInfo, 0), nil }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( @@ -183,7 +183,7 @@ func TestEvents_GetEvents(t *testing.T) { requestBody io.Reader }{ { - name: "GetAllEvents OK", + name: "getAllEvents OK", expectedBody: true, requestType: http.MethodGet, requestPath: "/api/events/", @@ -191,7 +191,7 @@ func TestEvents_GetEvents(t *testing.T) { }, } accountID := "test_account" - adminUser := server.NewAdminUser("test_user") + adminUser := types.NewAdminUser("test_user") events := generateEvents(accountID, adminUser.Id) handler := initEventsTestData(accountID, events...) @@ -201,7 +201,7 @@ func TestEvents_GetEvents(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/events/", handler.GetAllEvents).Methods("GET") + router.HandleFunc("/api/events/", handler.getAllEvents).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/groups_handler.go b/management/server/http/handlers/groups/groups_handler.go similarity index 71% rename from management/server/http/groups_handler.go rename to management/server/http/handlers/groups/groups_handler.go index f369d1a00..0ecea7ec2 100644 --- a/management/server/http/groups_handler.go +++ b/management/server/http/handlers/groups/groups_handler.go @@ -1,30 +1,41 @@ -package http +package groups import ( "encoding/json" "net/http" "github.com/gorilla/mux" - nbpeer "github.com/netbirdio/netbird/management/server/peer" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/http/configs" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" ) -// GroupsHandler is a handler that returns groups of the account -type GroupsHandler struct { +// handler is a handler that returns groups of the account +type handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewGroupsHandler creates a new GroupsHandler HTTP handler -func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *GroupsHandler { - return &GroupsHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + groupsHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/groups", groupsHandler.getAllGroups).Methods("GET", "OPTIONS") + router.HandleFunc("/groups", groupsHandler.createGroup).Methods("POST", "OPTIONS") + router.HandleFunc("/groups/{groupId}", groupsHandler.updateGroup).Methods("PUT", "OPTIONS") + router.HandleFunc("/groups/{groupId}", groupsHandler.getGroup).Methods("GET", "OPTIONS") + router.HandleFunc("/groups/{groupId}", groupsHandler.deleteGroup).Methods("DELETE", "OPTIONS") +} + +// newHandler 
creates a new groups handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -33,8 +44,8 @@ func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Gr } } -// GetAllGroups list for the account -func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { +// getAllGroups list for the account +func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -63,8 +74,8 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, groupsResponse) } -// UpdateGroup handles update to a group identified by a given ID -func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { +// updateGroup handles update to a group identified by a given ID +func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -118,10 +129,21 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { } else { peers = *req.Peers } - group := nbgroup.Group{ + + resources := make([]types.Resource, 0) + if req.Resources != nil { + for _, res := range *req.Resources { + resource := types.Resource{} + resource.FromAPIRequest(&res) + resources = append(resources, resource) + } + } + + group := types.Group{ ID: groupID, Name: req.Name, Peers: peers, + Resources: resources, Issued: existingGroup.Issued, IntegrationReference: existingGroup.IntegrationReference, } @@ -141,8 +163,8 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group)) } -// CreateGroup handles group creation request -func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { +// createGroup handles group creation request +func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -168,10 +190,21 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { } else { peers = *req.Peers } - group := nbgroup.Group{ - Name: req.Name, - Peers: peers, - Issued: nbgroup.GroupIssuedAPI, + + resources := make([]types.Resource, 0) + if req.Resources != nil { + for _, res := range *req.Resources { + resource := types.Resource{} + resource.FromAPIRequest(&res) + resources = append(resources, resource) + } + } + + group := types.Group{ + Name: req.Name, + Peers: peers, + Resources: resources, + Issued: types.GroupIssuedAPI, } err = h.accountManager.SaveGroup(r.Context(), accountID, userID, &group) @@ -189,8 +222,8 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group)) } -// DeleteGroup handles group deletion request -func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { +// deleteGroup handles group deletion request +func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request) { claims := 
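createGroup and updateGroup treat req.Peers and req.Resources as optional: a nil pointer means the field was omitted and is mapped to an empty slice, so the stored group never carries nil collections. A small standalone sketch of that nil-pointer handling (the request struct and string slices are stand-ins, not the generated api types, which also go through FromAPIRequest):

package main

import "fmt"

// groupRequest is a stand-in for the generated request type, where optional
// JSON fields arrive as pointers to slices.
type groupRequest struct {
	Peers     *[]string
	Resources *[]string
}

// normalize mirrors the handler logic: nil means "field omitted" and becomes
// an empty slice instead of a nil field on the stored group.
func normalize(req groupRequest) (peers, resources []string) {
	peers = []string{}
	if req.Peers != nil {
		peers = *req.Peers
	}
	resources = []string{}
	if req.Resources != nil {
		resources = *req.Resources
	}
	return peers, resources
}

func main() {
	p, r := normalize(groupRequest{Peers: &[]string{"peer-A-ID"}})
	fmt.Println(p, r) // [peer-A-ID] []
}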
h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -215,11 +248,11 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -// GetGroup returns a group -func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) { +// getGroup returns a group +func (h *handler) getGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -248,13 +281,13 @@ func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) { } -func toGroupResponse(peers []*nbpeer.Peer, group *nbgroup.Group) *api.Group { +func toGroupResponse(peers []*nbpeer.Peer, group *types.Group) *api.Group { peersMap := make(map[string]*nbpeer.Peer, len(peers)) for _, peer := range peers { peersMap[peer.ID] = peer } - cache := make(map[string]api.PeerMinimum) + peerCache := make(map[string]api.PeerMinimum) gr := api.Group{ Id: group.ID, Name: group.Name, @@ -262,7 +295,7 @@ func toGroupResponse(peers []*nbpeer.Peer, group *nbgroup.Group) *api.Group { } for _, pid := range group.Peers { - _, ok := cache[pid] + _, ok := peerCache[pid] if !ok { peer, ok := peersMap[pid] if !ok { @@ -272,12 +305,19 @@ func toGroupResponse(peers []*nbpeer.Peer, group *nbgroup.Group) *api.Group { Id: peer.ID, Name: peer.Name, } - cache[pid] = peerResp + peerCache[pid] = peerResp gr.Peers = append(gr.Peers, peerResp) } } gr.PeersCount = len(gr.Peers) + for _, res := range group.Resources { + resResp := res.ToAPIResponse() + gr.Resources = append(gr.Resources, *resResp) + } + + gr.ResourcesCount = len(gr.Resources) + return &gr } diff --git a/management/server/http/groups_handler_test.go b/management/server/http/handlers/groups/groups_handler_test.go similarity index 90% rename from management/server/http/groups_handler_test.go rename to management/server/http/handlers/groups/groups_handler_test.go index 7f3c81f18..49805ca9b 100644 --- a/management/server/http/groups_handler_test.go +++ b/management/server/http/handlers/groups/groups_handler_test.go @@ -1,4 +1,4 @@ -package http +package groups import ( "bytes" @@ -17,13 +17,13 @@ import ( "golang.org/x/exp/maps" "github.com/netbirdio/netbird/management/server" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) var TestPeers = map[string]*nbpeer.Peer{ @@ -31,20 +31,20 @@ var TestPeers = map[string]*nbpeer.Peer{ "B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")}, } -func initGroupTestData(initGroups ...*nbgroup.Group) *GroupsHandler { - return &GroupsHandler{ +func initGroupTestData(initGroups ...*types.Group) *handler { + return &handler{ accountManager: &mock_server.MockAccountManager{ - SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error { + SaveGroupFunc: func(_ context.Context, accountID, userID string, group *types.Group) error 
{ if !strings.HasPrefix(group.ID, "id-") { group.ID = "id-was-set" } return nil }, - GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*nbgroup.Group, error) { - groups := map[string]*nbgroup.Group{ - "id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT}, - "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI}, - "id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, + GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*types.Group, error) { + groups := map[string]*types.Group{ + "id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: types.GroupIssuedJWT}, + "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: types.GroupIssuedAPI}, + "id-all": {ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, } for _, group := range initGroups { @@ -61,9 +61,9 @@ func initGroupTestData(initGroups ...*nbgroup.Group) *GroupsHandler { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil }, - GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*nbgroup.Group, error) { + GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*types.Group, error) { if groupName == "All" { - return &nbgroup.Group{ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, nil + return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil } return nil, fmt.Errorf("unknown group name") @@ -106,21 +106,21 @@ func TestGetGroup(t *testing.T) { requestBody io.Reader }{ { - name: "GetGroup OK", + name: "getGroup OK", expectedBody: true, requestType: http.MethodGet, requestPath: "/api/groups/idofthegroup", expectedStatus: http.StatusOK, }, { - name: "GetGroup not found", + name: "getGroup not found", requestType: http.MethodGet, requestPath: "/api/groups/notexists", expectedStatus: http.StatusNotFound, }, } - group := &nbgroup.Group{ + group := &types.Group{ ID: "idofthegroup", Name: "Group", } @@ -133,7 +133,7 @@ func TestGetGroup(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/groups/{groupId}", p.GetGroup).Methods("GET") + router.HandleFunc("/api/groups/{groupId}", p.getGroup).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -154,7 +154,7 @@ func TestGetGroup(t *testing.T) { t.Fatalf("I don't know what I expected; %v", err) } - got := &nbgroup.Group{} + got := &types.Group{} if err = json.Unmarshal(content, &got); err != nil { t.Fatalf("Sent content is not in correct json format; %v", err) } @@ -254,8 +254,8 @@ func TestWriteGroup(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/groups", p.CreateGroup).Methods("POST") - router.HandleFunc("/api/groups/{groupId}", p.UpdateGroup).Methods("PUT") + router.HandleFunc("/api/groups", p.createGroup).Methods("POST") + router.HandleFunc("/api/groups/{groupId}", p.updateGroup).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -331,7 +331,7 @@ func TestDeleteGroup(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) router := mux.NewRouter() - router.HandleFunc("/api/groups/{groupId}", p.DeleteGroup).Methods("DELETE") + router.HandleFunc("/api/groups/{groupId}", p.deleteGroup).Methods("DELETE") router.ServeHTTP(recorder, req) res := 
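toGroupResponse deduplicates peer IDs through a small cache map (renamed peerCache in this change), skips peers that no longer exist in the account, and now also appends each resource via ToAPIResponse and sets ResourcesCount alongside PeersCount. A reduced sketch of the dedup-and-count idea with plain strings (peerMin is a placeholder for api.PeerMinimum):

package main

import "fmt"

// peerMin stands in for api.PeerMinimum.
type peerMin struct{ ID, Name string }

// summarize mirrors the shape of toGroupResponse: skip IDs that are unknown
// or already added, then report how many entries made it into the response.
func summarize(groupPeerIDs []string, known map[string]string) ([]peerMin, int) {
	seen := make(map[string]peerMin)
	var out []peerMin
	for _, id := range groupPeerIDs {
		if _, ok := seen[id]; ok {
			continue // already added once
		}
		name, ok := known[id]
		if !ok {
			continue // peer no longer exists in the account
		}
		p := peerMin{ID: id, Name: name}
		seen[id] = p
		out = append(out, p)
	}
	return out, len(out)
}

func main() {
	peers, count := summarize(
		[]string{"A", "A", "B", "gone"},
		map[string]string{"A": "peer-a", "B": "peer-b"},
	)
	fmt.Println(peers, count) // [{A peer-a} {B peer-b}] 2
}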
recorder.Result() diff --git a/management/server/http/handlers/networks/handler.go b/management/server/http/handlers/networks/handler.go new file mode 100644 index 000000000..316b93611 --- /dev/null +++ b/management/server/http/handlers/networks/handler.go @@ -0,0 +1,321 @@ +package networks + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + "github.com/gorilla/mux" + + s "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/groups" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" + "github.com/netbirdio/netbird/management/server/http/util" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/networks" + "github.com/netbirdio/netbird/management/server/networks/resources" + "github.com/netbirdio/netbird/management/server/networks/routers" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + "github.com/netbirdio/netbird/management/server/networks/types" + "github.com/netbirdio/netbird/management/server/status" + nbtypes "github.com/netbirdio/netbird/management/server/types" +) + +// handler is a handler that returns networks of the account +type handler struct { + networksManager networks.Manager + resourceManager resources.Manager + routerManager routers.Manager + accountManager s.AccountManager + + groupsManager groups.Manager + extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) + claimsExtractor *jwtclaims.ClaimsExtractor +} + +func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager s.AccountManager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) { + addRouterEndpoints(routerManager, extractFromToken, authCfg, router) + addResourceEndpoints(resourceManager, groupsManager, extractFromToken, authCfg, router) + + networksHandler := newHandler(networksManager, resourceManager, routerManager, groupsManager, accountManager, extractFromToken, authCfg) + router.HandleFunc("/networks", networksHandler.getAllNetworks).Methods("GET", "OPTIONS") + router.HandleFunc("/networks", networksHandler.createNetwork).Methods("POST", "OPTIONS") + router.HandleFunc("/networks/{networkId}", networksHandler.getNetwork).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}", networksHandler.updateNetwork).Methods("PUT", "OPTIONS") + router.HandleFunc("/networks/{networkId}", networksHandler.deleteNetwork).Methods("DELETE", "OPTIONS") +} + +func newHandler(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager s.AccountManager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *handler { + return &handler{ + networksManager: networksManager, + resourceManager: resourceManager, + routerManager: routerManager, + groupsManager: groupsManager, + accountManager: accountManager, + extractFromToken: extractFromToken, + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(authCfg.Audience), + jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), + ), + } +} + +func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) { + claims 
:= h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + networks, err := h.networksManager.GetAllNetworks(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + resourceIDs, err := h.resourceManager.GetAllResourceIDsInAccount(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + groups, err := h.groupsManager.GetAllGroupsMap(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + routers, err := h.routerManager.GetAllRoutersInAccount(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + account, err := h.accountManager.GetAccount(r.Context(), accountID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, h.generateNetworkResponse(networks, routers, resourceIDs, groups, account)) +} + +func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var req api.NetworkRequest + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + network := &types.Network{} + network.FromAPIRequest(&req) + + network.AccountID = accountID + network, err = h.networksManager.CreateNetwork(r.Context(), userID, network) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + account, err := h.accountManager.GetAccount(r.Context(), accountID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + policyIDs := account.GetPoliciesAppliedInNetwork(network.ID) + + util.WriteJSONObject(r.Context(), w, network.ToAPIResponse([]string{}, []string{}, 0, policyIDs)) +} + +func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + vars := mux.Vars(r) + networkID := vars["networkId"] + if len(networkID) == 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid network ID"), w) + return + } + + network, err := h.networksManager.GetNetwork(r.Context(), accountID, userID, networkID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), accountID, userID, networkID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + account, err := h.accountManager.GetAccount(r.Context(), accountID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + policyIDs := account.GetPoliciesAppliedInNetwork(networkID) + + util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount, policyIDs)) +} + +func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + vars := mux.Vars(r) + networkID := vars["networkId"] + if 
len(networkID) == 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid network ID"), w) + return + } + + var req api.NetworkRequest + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + network := &types.Network{} + network.FromAPIRequest(&req) + + network.ID = networkID + network.AccountID = accountID + network, err = h.networksManager.UpdateNetwork(r.Context(), userID, network) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), accountID, userID, networkID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + account, err := h.accountManager.GetAccount(r.Context(), accountID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + policyIDs := account.GetPoliciesAppliedInNetwork(networkID) + + util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount, policyIDs)) +} + +func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + vars := mux.Vars(r) + networkID := vars["networkId"] + if len(networkID) == 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid network ID"), w) + return + } + + err = h.networksManager.DeleteNetwork(r.Context(), accountID, userID, networkID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) +} + +func (h *handler) collectIDsInNetwork(ctx context.Context, accountID, userID, networkID string) ([]string, []string, int, error) { + resources, err := h.resourceManager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID) + if err != nil { + return nil, nil, 0, fmt.Errorf("failed to get resources in network: %w", err) + } + + var resourceIDs []string + for _, resource := range resources { + resourceIDs = append(resourceIDs, resource.ID) + } + + routers, err := h.routerManager.GetAllRoutersInNetwork(ctx, accountID, userID, networkID) + if err != nil { + return nil, nil, 0, fmt.Errorf("failed to get routers in network: %w", err) + } + + groups, err := h.groupsManager.GetAllGroupsMap(ctx, accountID, userID) + if err != nil { + return nil, nil, 0, fmt.Errorf("failed to get groups: %w", err) + } + + peerCounter := 0 + var routerIDs []string + for _, router := range routers { + routerIDs = append(routerIDs, router.ID) + if router.Peer != "" { + peerCounter++ + } + if len(router.PeerGroups) > 0 { + for _, groupID := range router.PeerGroups { + peerCounter += len(groups[groupID].Peers) + } + } + } + + return routerIDs, resourceIDs, peerCounter, nil +} + +func (h *handler) generateNetworkResponse(networks []*types.Network, routers map[string][]*routerTypes.NetworkRouter, resourceIDs map[string][]string, groups map[string]*nbtypes.Group, account *nbtypes.Account) []*api.Network { + var networkResponse []*api.Network + for _, network := range networks { + routerIDs, peerCounter := getRouterIDs(network, routers, groups) + policyIDs := account.GetPoliciesAppliedInNetwork(network.ID) + networkResponse = append(networkResponse, network.ToAPIResponse(routerIDs, resourceIDs[network.ID], peerCounter, policyIDs)) + } + return networkResponse +} + +func 
getRouterIDs(network *types.Network, routers map[string][]*routerTypes.NetworkRouter, groups map[string]*nbtypes.Group) ([]string, int) { + routerIDs := []string{} + peerCounter := 0 + for _, router := range routers[network.ID] { + routerIDs = append(routerIDs, router.ID) + if router.Peer != "" { + peerCounter++ + } + if len(router.PeerGroups) > 0 { + for _, groupID := range router.PeerGroups { + group, ok := groups[groupID] + if !ok { + continue + } + peerCounter += len(group.Peers) + } + } + } + return routerIDs, peerCounter +} diff --git a/management/server/http/handlers/networks/resources_handler.go b/management/server/http/handlers/networks/resources_handler.go new file mode 100644 index 000000000..6499bd652 --- /dev/null +++ b/management/server/http/handlers/networks/resources_handler.go @@ -0,0 +1,223 @@ +package networks + +import ( + "context" + "encoding/json" + "net/http" + + "github.com/gorilla/mux" + + "github.com/netbirdio/netbird/management/server/groups" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" + "github.com/netbirdio/netbird/management/server/http/util" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/networks/resources" + "github.com/netbirdio/netbird/management/server/networks/resources/types" +) + +type resourceHandler struct { + resourceManager resources.Manager + groupsManager groups.Manager + extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) + claimsExtractor *jwtclaims.ClaimsExtractor +} + +func addResourceEndpoints(resourcesManager resources.Manager, groupsManager groups.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) { + resourceHandler := newResourceHandler(resourcesManager, groupsManager, extractFromToken, authCfg) + router.HandleFunc("/networks/resources", resourceHandler.getAllResourcesInAccount).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}/resources", resourceHandler.getAllResourcesInNetwork).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}/resources", resourceHandler.createResource).Methods("POST", "OPTIONS") + router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.getResource).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.updateResource).Methods("PUT", "OPTIONS") + router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.deleteResource).Methods("DELETE", "OPTIONS") +} + +func newResourceHandler(resourceManager resources.Manager, groupsManager groups.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *resourceHandler { + return &resourceHandler{ + resourceManager: resourceManager, + groupsManager: groupsManager, + extractFromToken: extractFromToken, + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(authCfg.Audience), + jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), + ), + } +} + +func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + networkID := 
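Both collectIDsInNetwork and getRouterIDs count peers the same way: one peer for each router with a direct Peer set, plus the member count of every referenced peer group. Note that getRouterIDs skips group IDs it cannot find in the map, while collectIDsInNetwork indexes the groups map directly. A self-contained sketch of the counting rule, following the defensive variant (router and group are simplified stand-ins for routerTypes.NetworkRouter and nbtypes.Group):

package main

import "fmt"

// router and group are reduced stand-ins for the real NetworkRouter and Group types.
type router struct {
	ID         string
	Peer       string
	PeerGroups []string
}
type group struct{ Peers []string }

// countNetworkPeers applies the handler's rule: a router with a direct peer
// contributes one, and each referenced peer group contributes its member count.
func countNetworkPeers(routers []router, groups map[string]group) ([]string, int) {
	ids := []string{}
	peers := 0
	for _, r := range routers {
		ids = append(ids, r.ID)
		if r.Peer != "" {
			peers++
		}
		for _, gid := range r.PeerGroups {
			g, ok := groups[gid]
			if !ok {
				continue // unknown group, skipped as in getRouterIDs
			}
			peers += len(g.Peers)
		}
	}
	return ids, peers
}

func main() {
	ids, n := countNetworkPeers(
		[]router{{ID: "r1", Peer: "p1"}, {ID: "r2", PeerGroups: []string{"g1"}}},
		map[string]group{"g1": {Peers: []string{"a", "b"}}},
	)
	fmt.Println(ids, n) // [r1 r2] 3
}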
mux.Vars(r)["networkId"] + resources, err := h.resourceManager.GetAllResourcesInNetwork(r.Context(), accountID, userID, networkID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var resourcesResponse []*api.NetworkResource + for _, resource := range resources { + groupMinimumInfo := groups.ToGroupsInfo(grps, resource.ID) + resourcesResponse = append(resourcesResponse, resource.ToAPIResponse(groupMinimumInfo)) + } + + util.WriteJSONObject(r.Context(), w, resourcesResponse) +} +func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + resources, err := h.resourceManager.GetAllResourcesInAccount(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var resourcesResponse []*api.NetworkResource + for _, resource := range resources { + groupMinimumInfo := groups.ToGroupsInfo(grps, resource.ID) + resourcesResponse = append(resourcesResponse, resource.ToAPIResponse(groupMinimumInfo)) + } + + util.WriteJSONObject(r.Context(), w, resourcesResponse) +} + +func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var req api.NetworkResourceRequest + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + resource := &types.NetworkResource{} + resource.FromAPIRequest(&req) + + resource.NetworkID = mux.Vars(r)["networkId"] + resource.AccountID = accountID + resource.Enabled = true + resource, err = h.resourceManager.CreateResource(r.Context(), userID, resource) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + groupMinimumInfo := groups.ToGroupsInfo(grps, resource.ID) + util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(groupMinimumInfo)) +} + +func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + networkID := mux.Vars(r)["networkId"] + resourceID := mux.Vars(r)["resourceId"] + resource, err := h.resourceManager.GetResource(r.Context(), accountID, userID, networkID, resourceID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + groupMinimumInfo := groups.ToGroupsInfo(grps, resource.ID) + util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(groupMinimumInfo)) +} + +func (h *resourceHandler) updateResource(w 
http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var req api.NetworkResourceRequest + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + resource := &types.NetworkResource{} + resource.FromAPIRequest(&req) + + resource.ID = mux.Vars(r)["resourceId"] + resource.NetworkID = mux.Vars(r)["networkId"] + resource.AccountID = accountID + resource, err = h.resourceManager.UpdateResource(r.Context(), userID, resource) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + groupMinimumInfo := groups.ToGroupsInfo(grps, resource.ID) + util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(groupMinimumInfo)) +} + +func (h *resourceHandler) deleteResource(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + networkID := mux.Vars(r)["networkId"] + resourceID := mux.Vars(r)["resourceId"] + err = h.resourceManager.DeleteResource(r.Context(), accountID, userID, networkID, resourceID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) +} diff --git a/management/server/http/handlers/networks/routers_handler.go b/management/server/http/handlers/networks/routers_handler.go new file mode 100644 index 000000000..7ca95d902 --- /dev/null +++ b/management/server/http/handlers/networks/routers_handler.go @@ -0,0 +1,165 @@ +package networks + +import ( + "context" + "encoding/json" + "net/http" + + "github.com/gorilla/mux" + + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" + "github.com/netbirdio/netbird/management/server/http/util" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/networks/routers" + "github.com/netbirdio/netbird/management/server/networks/routers/types" +) + +type routersHandler struct { + routersManager routers.Manager + extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) + claimsExtractor *jwtclaims.ClaimsExtractor +} + +func addRouterEndpoints(routersManager routers.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) { + routersHandler := newRoutersHandler(routersManager, extractFromToken, authCfg) + router.HandleFunc("/networks/{networkId}/routers", routersHandler.getAllRouters).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}/routers", routersHandler.createRouter).Methods("POST", "OPTIONS") + router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.getRouter).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.updateRouter).Methods("PUT", "OPTIONS") + router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.deleteRouter).Methods("DELETE", "OPTIONS") +} + +func 
newRoutersHandler(routersManager routers.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *routersHandler { + return &routersHandler{ + routersManager: routersManager, + extractFromToken: extractFromToken, + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(authCfg.Audience), + jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), + ), + } +} + +func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + networkID := mux.Vars(r)["networkId"] + routers, err := h.routersManager.GetAllRoutersInNetwork(r.Context(), accountID, userID, networkID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var routersResponse []*api.NetworkRouter + for _, router := range routers { + routersResponse = append(routersResponse, router.ToAPIResponse()) + } + + util.WriteJSONObject(r.Context(), w, routersResponse) +} + +func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + networkID := mux.Vars(r)["networkId"] + var req api.NetworkRouterRequest + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + router := &types.NetworkRouter{} + router.FromAPIRequest(&req) + + router.NetworkID = networkID + router.AccountID = accountID + router.Enabled = true + router, err = h.routersManager.CreateRouter(r.Context(), userID, router) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, router.ToAPIResponse()) +} + +func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + routerID := mux.Vars(r)["routerId"] + networkID := mux.Vars(r)["networkId"] + router, err := h.routersManager.GetRouter(r.Context(), accountID, userID, networkID, routerID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, router.ToAPIResponse()) +} + +func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var req api.NetworkRouterRequest + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + router := &types.NetworkRouter{} + router.FromAPIRequest(&req) + + router.NetworkID = mux.Vars(r)["networkId"] + router.ID = mux.Vars(r)["routerId"] + router.AccountID = accountID + + router, err = h.routersManager.UpdateRouter(r.Context(), userID, router) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, router.ToAPIResponse()) +} + +func (h *routersHandler) deleteRouter(w http.ResponseWriter, r 
*http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + routerID := mux.Vars(r)["routerId"] + networkID := mux.Vars(r)["networkId"] + err = h.routersManager.DeleteRouter(r.Context(), accountID, userID, networkID, routerID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, struct{}{}) +} diff --git a/management/server/http/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go similarity index 82% rename from management/server/http/peers_handler.go rename to management/server/http/handlers/peers/peers_handler.go index 235e744b3..76a0149c6 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -1,4 +1,4 @@ -package http +package peers import ( "context" @@ -10,23 +10,33 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server" - nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) -// PeersHandler is a handler that returns peers of the account -type PeersHandler struct { +// Handler is a handler that returns peers of the account +type Handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewPeersHandler creates a new PeersHandler HTTP handler -func NewPeersHandler(accountManager server.AccountManager, authCfg AuthCfg) *PeersHandler { - return &PeersHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + peersHandler := NewHandler(accountManager, authCfg) + router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS") + router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer). 
+ Methods("GET", "PUT", "DELETE", "OPTIONS") + router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS") +} + +// NewHandler creates a new peers Handler +func NewHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *Handler { + return &Handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -35,7 +45,7 @@ func NewPeersHandler(accountManager server.AccountManager, authCfg AuthCfg) *Pee } } -func (h *PeersHandler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) { +func (h *Handler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) { peerToReturn := peer.Copy() if peer.Status.Connected { // Although we have online status in store we do not yet have an updated channel so have to show it as disconnected @@ -48,7 +58,7 @@ func (h *PeersHandler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) return peerToReturn, nil } -func (h *PeersHandler) getPeer(ctx context.Context, accountID, peerID, userID string, w http.ResponseWriter) { +func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, w http.ResponseWriter) { peer, err := h.accountManager.GetPeer(ctx, accountID, peerID, userID) if err != nil { util.WriteError(ctx, err, w) @@ -62,12 +72,8 @@ func (h *PeersHandler) getPeer(ctx context.Context, accountID, peerID, userID st } dnsDomain := h.accountManager.GetDNSDomain() - peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID) - if err != nil { - util.WriteError(ctx, err, w) - return - } - groupsInfo := toGroupsInfo(peerGroups) + grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID) + groupsInfo := groups.ToGroupsInfo(grps, peerID) validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { @@ -80,7 +86,7 @@ func (h *PeersHandler) getPeer(ctx context.Context, accountID, peerID, userID st util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid)) } -func (h *PeersHandler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) { +func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) { req := &api.PeerRequest{} err := json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -116,7 +122,8 @@ func (h *PeersHandler) updatePeer(ctx context.Context, accountID, userID, peerID util.WriteError(ctx, err, w) return } - groupMinimumInfo := toGroupsInfo(peerGroups) + + groupMinimumInfo := groups.ToGroupsInfo(peerGroups, peer.ID) validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { @@ -130,18 +137,18 @@ func (h *PeersHandler) updatePeer(ctx context.Context, accountID, userID, peerID util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, valid)) } -func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) { +func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) { err := h.accountManager.DeletePeer(ctx, accountID, peerID, userID) if err != nil { log.WithContext(ctx).Errorf("failed to delete peer: %v", err) util.WriteError(ctx, err, w) return } - util.WriteJSONObject(ctx, w, emptyObject{}) + util.WriteJSONObject(ctx, w, util.EmptyObject{}) } // HandlePeer handles all peer requests for GET, PUT and DELETE operations -func (h 
*PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { +func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -171,7 +178,7 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { } // GetAllPeers returns a list of all peers associated with a provided account -func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { +func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -187,6 +194,8 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { dnsDomain := h.accountManager.GetDNSDomain() + grps, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + respBody := make([]*api.PeerBatch, 0, len(peers)) for _, peer := range peers { peerToReturn, err := h.checkPeerStatus(peer) @@ -194,13 +203,7 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { util.WriteError(r.Context(), err, w) return } - - peerGroups, err := h.accountManager.GetPeerGroups(r.Context(), accountID, peer.ID) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - groupMinimumInfo := toGroupsInfo(peerGroups) + groupMinimumInfo := groups.ToGroupsInfo(grps, peer.ID) respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0)) } @@ -216,7 +219,7 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, respBody) } -func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) { +func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) { for _, peer := range respBody { _, ok := approvedPeersMap[peer.Id] if !ok { @@ -226,7 +229,7 @@ func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approv } // GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network. 
-func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { +func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -278,12 +281,12 @@ func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request dnsDomain := h.accountManager.GetDNSDomain() customZone := account.GetPeersCustomZone(r.Context(), dnsDomain) - netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, nil) + netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain)) } -func toAccessiblePeers(netMap *server.NetworkMap, dnsDomain string) []api.AccessiblePeer { +func toAccessiblePeers(netMap *types.NetworkMap, dnsDomain string) []api.AccessiblePeer { accessiblePeers := make([]api.AccessiblePeer, 0, len(netMap.Peers)+len(netMap.OfflinePeers)) for _, p := range netMap.Peers { accessiblePeers = append(accessiblePeers, peerToAccessiblePeer(p, dnsDomain)) @@ -312,18 +315,6 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee } } -func toGroupsInfo(groups []*nbgroup.Group) []api.GroupMinimum { - groupsInfo := make([]api.GroupMinimum, 0, len(groups)) - for _, group := range groups { - groupsInfo = append(groupsInfo, api.GroupMinimum{ - Id: group.ID, - Name: group.Name, - PeersCount: len(group.Peers), - }) - } - return groupsInfo -} - func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool) *api.Peer { osVersion := peer.Meta.OSVersion if osVersion == "" { @@ -348,7 +339,7 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD UiVersion: peer.Meta.UIVersion, DnsLabel: fqdn(peer, dnsDomain), LoginExpirationEnabled: peer.LoginExpirationEnabled, - LastLogin: peer.LastLogin, + LastLogin: peer.GetLastLogin(), LoginExpired: peer.Status.LoginExpired, ApprovalRequired: !approved, CountryCode: peer.Location.CountryCode, @@ -382,7 +373,7 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn UiVersion: peer.Meta.UIVersion, DnsLabel: fqdn(peer, dnsDomain), LoginExpirationEnabled: peer.LoginExpirationEnabled, - LastLogin: peer.LastLogin, + LastLogin: peer.GetLastLogin(), LoginExpired: peer.Status.LoginExpired, AccessiblePeersCount: accessiblePeersCount, CountryCode: peer.Location.CountryCode, diff --git a/management/server/http/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go similarity index 94% rename from management/server/http/peers_handler_test.go rename to management/server/http/handlers/peers/peers_handler_test.go index 9279fc536..16065a677 100644 --- a/management/server/http/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -1,4 +1,4 @@ -package http +package peers import ( "bytes" @@ -15,11 +15,10 @@ import ( "github.com/gorilla/mux" "golang.org/x/exp/maps" - "github.com/netbirdio/netbird/management/server" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/types" 
"github.com/stretchr/testify/assert" @@ -38,19 +37,19 @@ const ( userIDKey ctxKey = "user_id" ) -func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { +func initTestMetaData(peers ...*nbpeer.Peer) *Handler { peersMap := make(map[string]*nbpeer.Peer) for _, peer := range peers { peersMap[peer.ID] = peer.Copy() } - policy := &server.Policy{ + policy := &types.Policy{ ID: "policy", AccountID: "test_id", Name: "policy", Enabled: true, - Rules: []*server.PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "rule", Name: "rule", @@ -65,19 +64,19 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { }, } - srvUser := server.NewRegularUser(serviceUser) + srvUser := types.NewRegularUser(serviceUser) srvUser.IsServiceUser = true - account := &server.Account{ + account := &types.Account{ Id: "test_id", Domain: "hotmail.com", Peers: peersMap, - Users: map[string]*server.User{ - adminUser: server.NewAdminUser(adminUser), - regularUser: server.NewRegularUser(regularUser), + Users: map[string]*types.User{ + adminUser: types.NewAdminUser(adminUser), + regularUser: types.NewRegularUser(regularUser), serviceUser: srvUser, }, - Groups: map[string]*nbgroup.Group{ + Groups: map[string]*types.Group{ "group1": { ID: "group1", AccountID: "test_id", @@ -86,12 +85,12 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { Peers: maps.Keys(peersMap), }, }, - Settings: &server.Settings{ + Settings: &types.Settings{ PeerLoginExpirationEnabled: true, PeerLoginExpiration: time.Hour, }, - Policies: []*server.Policy{policy}, - Network: &server.Network{ + Policies: []*types.Policy{policy}, + Network: &types.Network{ Identifier: "ciclqisab2ss43jdn8q0", Net: net.IPNet{ IP: net.ParseIP("100.67.0.0"), @@ -101,7 +100,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { }, } - return &PeersHandler{ + return &Handler{ accountManager: &mock_server.MockAccountManager{ UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { var p *nbpeer.Peer @@ -129,12 +128,12 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { GetPeersFunc: func(_ context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { return peers, nil }, - GetPeerGroupsFunc: func(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) { + GetPeerGroupsFunc: func(ctx context.Context, accountID, peerID string) ([]*types.Group, error) { peersID := make([]string, len(peers)) for _, peer := range peers { peersID = append(peersID, peer.ID) } - return []*nbgroup.Group{ + return []*types.Group{ { ID: "group1", AccountID: accountID, @@ -150,10 +149,10 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil }, - GetAccountFunc: func(ctx context.Context, accountID string) (*server.Account, error) { + GetAccountFunc: func(ctx context.Context, accountID string) (*types.Account, error) { return account, nil }, - GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) { + GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*types.Account, error) { return account, nil }, HasConnectedChannelFunc: func(peerID string) bool { diff --git a/management/server/http/geolocation_handler_test.go b/management/server/http/handlers/policies/geolocation_handler_test.go similarity index 91% rename from 
management/server/http/geolocation_handler_test.go rename to management/server/http/handlers/policies/geolocation_handler_test.go index 19c916dd2..fc5839baa 100644 --- a/management/server/http/geolocation_handler_test.go +++ b/management/server/http/handlers/policies/geolocation_handler_test.go @@ -1,4 +1,4 @@ -package http +package policies import ( "context" @@ -11,22 +11,22 @@ import ( "testing" "github.com/gorilla/mux" - "github.com/netbirdio/netbird/management/server" "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/util" ) -func initGeolocationTestData(t *testing.T) *GeolocationsHandler { +func initGeolocationTestData(t *testing.T) *geolocationsHandler { t.Helper() var ( - mmdbPath = "../testdata/GeoLite2-City_20240305.mmdb" - geonamesdbPath = "../testdata/geonames_20240305.db" + mmdbPath = "../../../testdata/GeoLite2-City_20240305.mmdb" + geonamesdbPath = "../../../testdata/geonames_20240305.db" ) tempDir := t.TempDir() @@ -41,13 +41,13 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler { assert.NoError(t, err) t.Cleanup(func() { _ = geo.Stop() }) - return &GeolocationsHandler{ + return &geolocationsHandler{ accountManager: &mock_server.MockAccountManager{ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil }, - GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) { - return server.NewAdminUser(id), nil + GetUserByIDFunc: func(ctx context.Context, id string) (*types.User, error) { + return types.NewAdminUser(id), nil }, }, geolocationManager: geo, @@ -114,7 +114,7 @@ func TestGetCitiesByCountry(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) router := mux.NewRouter() - router.HandleFunc("/api/locations/countries/{country}/cities", geolocationHandler.GetCitiesByCountry).Methods("GET") + router.HandleFunc("/api/locations/countries/{country}/cities", geolocationHandler.getCitiesByCountry).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -202,7 +202,7 @@ func TestGetAllCountries(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) router := mux.NewRouter() - router.HandleFunc("/api/locations/countries", geolocationHandler.GetAllCountries).Methods("GET") + router.HandleFunc("/api/locations/countries", geolocationHandler.getAllCountries).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/geolocations_handler.go b/management/server/http/handlers/policies/geolocations_handler.go similarity index 71% rename from management/server/http/geolocations_handler.go rename to management/server/http/handlers/policies/geolocations_handler.go index 418228abf..161d97402 100644 --- a/management/server/http/geolocations_handler.go +++ b/management/server/http/handlers/policies/geolocations_handler.go @@ -1,4 +1,4 @@ -package http +package policies import ( "net/http" @@ -9,6 +9,7 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" + 
"github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" @@ -18,16 +19,22 @@ var ( countryCodeRegex = regexp.MustCompile("^[a-zA-Z]{2}$") ) -// GeolocationsHandler is a handler that returns locations. -type GeolocationsHandler struct { +// geolocationsHandler is a handler that returns locations. +type geolocationsHandler struct { accountManager server.AccountManager - geolocationManager *geolocation.Geolocation + geolocationManager geolocation.Geolocation claimsExtractor *jwtclaims.ClaimsExtractor } -// NewGeolocationsHandlerHandler creates a new Geolocations handler -func NewGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg AuthCfg) *GeolocationsHandler { - return &GeolocationsHandler{ +func addLocationsEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { + locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager, authCfg) + router.HandleFunc("/locations/countries", locationHandler.getAllCountries).Methods("GET", "OPTIONS") + router.HandleFunc("/locations/countries/{country}/cities", locationHandler.getCitiesByCountry).Methods("GET", "OPTIONS") +} + +// newGeolocationsHandlerHandler creates a new Geolocations handler +func newGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation, authCfg configs.AuthCfg) *geolocationsHandler { + return &geolocationsHandler{ accountManager: accountManager, geolocationManager: geolocationManager, claimsExtractor: jwtclaims.NewClaimsExtractor( @@ -37,8 +44,8 @@ func NewGeolocationsHandlerHandler(accountManager server.AccountManager, geoloca } } -// GetAllCountries retrieves a list of all countries -func (l *GeolocationsHandler) GetAllCountries(w http.ResponseWriter, r *http.Request) { +// getAllCountries retrieves a list of all countries +func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request) { if err := l.authenticateUser(r); err != nil { util.WriteError(r.Context(), err, w) return @@ -63,8 +70,8 @@ func (l *GeolocationsHandler) GetAllCountries(w http.ResponseWriter, r *http.Req util.WriteJSONObject(r.Context(), w, countries) } -// GetCitiesByCountry retrieves a list of cities based on the given country code -func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.Request) { +// getCitiesByCountry retrieves a list of cities based on the given country code +func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request) { if err := l.authenticateUser(r); err != nil { util.WriteError(r.Context(), err, w) return @@ -96,7 +103,7 @@ func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http. 
util.WriteJSONObject(r.Context(), w, cities) } -func (l *GeolocationsHandler) authenticateUser(r *http.Request) error { +func (l *geolocationsHandler) authenticateUser(r *http.Request) error { claims := l.claimsExtractor.FromRequestContext(r) _, userID, err := l.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { diff --git a/management/server/http/policies_handler.go b/management/server/http/handlers/policies/policies_handler.go similarity index 64% rename from management/server/http/policies_handler.go rename to management/server/http/handlers/policies/policies_handler.go index 8255e4896..a748e73b8 100644 --- a/management/server/http/policies_handler.go +++ b/management/server/http/handlers/policies/policies_handler.go @@ -1,4 +1,4 @@ -package http +package policies import ( "encoding/json" @@ -6,23 +6,36 @@ import ( "strconv" "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/server" - nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) -// Policies is a handler that returns policy of the account -type Policies struct { +// handler is a handler that returns policy of the account +type handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewPoliciesHandler creates a new Policies handler -func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Policies { - return &Policies{ +func AddEndpoints(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { + policiesHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/policies", policiesHandler.getAllPolicies).Methods("GET", "OPTIONS") + router.HandleFunc("/policies", policiesHandler.createPolicy).Methods("POST", "OPTIONS") + router.HandleFunc("/policies/{policyId}", policiesHandler.updatePolicy).Methods("PUT", "OPTIONS") + router.HandleFunc("/policies/{policyId}", policiesHandler.getPolicy).Methods("GET", "OPTIONS") + router.HandleFunc("/policies/{policyId}", policiesHandler.deletePolicy).Methods("DELETE", "OPTIONS") + addPostureCheckEndpoint(accountManager, locationManager, authCfg, router) +} + +// newHandler creates a new policies handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -31,8 +44,8 @@ func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) * } } -// GetAllPolicies list for the account -func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) { +// getAllPolicies list for the account +func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -65,8 +78,8 @@ func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, policies) } -// UpdatePolicy handles update to a policy 
identified by a given ID -func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) { +// updatePolicy handles update to a policy identified by a given ID +func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -90,8 +103,8 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) { h.savePolicy(w, r, accountID, userID, policyID) } -// CreatePolicy handles policy creation request -func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) { +// createPolicy handles policy creation request +func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -103,7 +116,7 @@ func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) { } // savePolicy handles policy creation and update -func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID string, userID string, policyID string) { +func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID string, userID string, policyID string) { var req api.PutApiPoliciesPolicyIdJSONRequestBody if err := json.NewDecoder(r.Body).Decode(&req); err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) @@ -120,23 +133,74 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID return } - policy := &server.Policy{ + description := "" + if req.Description != nil { + description = *req.Description + } + + policy := &types.Policy{ ID: policyID, AccountID: accountID, Name: req.Name, Enabled: req.Enabled, - Description: req.Description, + Description: description, } for _, rule := range req.Rules { - pr := server.PolicyRule{ - ID: policyID, // TODO: when policy can contain multiple rules, need refactor + var ruleID string + if rule.Id != nil && policyID != "" { + ruleID = *rule.Id + } + + hasSources := rule.Sources != nil + hasSourceResource := rule.SourceResource != nil + + hasDestinations := rule.Destinations != nil + hasDestinationResource := rule.DestinationResource != nil + + if hasSources && hasSourceResource { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "specify either sources or source resources, not both"), w) + return + } + + if hasDestinations && hasDestinationResource { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "specify either destinations or destination resources, not both"), w) + return + } + + if !(hasSources || hasSourceResource) || !(hasDestinations || hasDestinationResource) { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "specify either sources or source resources and destinations or destination resources"), w) + return + } + + pr := types.PolicyRule{ + ID: ruleID, PolicyID: policyID, Name: rule.Name, - Destinations: rule.Destinations, - Sources: rule.Sources, Bidirectional: rule.Bidirectional, } + if hasSources { + pr.Sources = *rule.Sources + } + + if hasSourceResource { + // TODO: validate the resource id and type + sourceResource := &types.Resource{} + sourceResource.FromAPIRequest(rule.SourceResource) + pr.SourceResource = *sourceResource + } + + if hasDestinations { + pr.Destinations = *rule.Destinations + } + + if hasDestinationResource { + // TODO: validate the resource id and 
type + destinationResource := &types.Resource{} + destinationResource.FromAPIRequest(rule.DestinationResource) + pr.DestinationResource = *destinationResource + } + pr.Enabled = rule.Enabled if rule.Description != nil { pr.Description = *rule.Description @@ -144,9 +208,9 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID switch rule.Action { case api.PolicyRuleUpdateActionAccept: - pr.Action = server.PolicyTrafficActionAccept + pr.Action = types.PolicyTrafficActionAccept case api.PolicyRuleUpdateActionDrop: - pr.Action = server.PolicyTrafficActionDrop + pr.Action = types.PolicyTrafficActionDrop default: util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown action type"), w) return @@ -154,13 +218,13 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID switch rule.Protocol { case api.PolicyRuleUpdateProtocolAll: - pr.Protocol = server.PolicyRuleProtocolALL + pr.Protocol = types.PolicyRuleProtocolALL case api.PolicyRuleUpdateProtocolTcp: - pr.Protocol = server.PolicyRuleProtocolTCP + pr.Protocol = types.PolicyRuleProtocolTCP case api.PolicyRuleUpdateProtocolUdp: - pr.Protocol = server.PolicyRuleProtocolUDP + pr.Protocol = types.PolicyRuleProtocolUDP case api.PolicyRuleUpdateProtocolIcmp: - pr.Protocol = server.PolicyRuleProtocolICMP + pr.Protocol = types.PolicyRuleProtocolICMP default: util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown protocol type: %v", rule.Protocol), w) return @@ -187,7 +251,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "valid port value is in 1..65535 range"), w) return } - pr.PortRanges = append(pr.PortRanges, server.RulePortRange{ + pr.PortRanges = append(pr.PortRanges, types.RulePortRange{ Start: uint16(portRange.Start), End: uint16(portRange.End), }) @@ -196,7 +260,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID // validate policy object switch pr.Protocol { - case server.PolicyRuleProtocolALL, server.PolicyRuleProtocolICMP: + case types.PolicyRuleProtocolALL, types.PolicyRuleProtocolICMP: if len(pr.Ports) != 0 || len(pr.PortRanges) != 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol ports is not allowed"), w) return @@ -205,7 +269,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w) return } - case server.PolicyRuleProtocolTCP, server.PolicyRuleProtocolUDP: + case types.PolicyRuleProtocolTCP, types.PolicyRuleProtocolUDP: if !pr.Bidirectional && (len(pr.Ports) == 0 || len(pr.PortRanges) != 0) { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w) return @@ -240,8 +304,8 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID util.WriteJSONObject(r.Context(), w, resp) } -// DeletePolicy handles policy deletion request -func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { +// deletePolicy handles policy deletion request +func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -261,11 +325,11 @@ func (h 
*Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -// GetPolicy handles a group Get request identified by ID -func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) { +// getPolicy handles a group Get request identified by ID +func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -301,8 +365,8 @@ func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, resp) } -func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Policy { - groupsMap := make(map[string]*nbgroup.Group) +func toPolicyResponse(groups []*types.Group, policy *types.Policy) *api.Policy { + groupsMap := make(map[string]*types.Group) for _, group := range groups { groupsMap[group.ID] = group } @@ -311,7 +375,7 @@ func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Polic ap := &api.Policy{ Id: &policy.ID, Name: policy.Name, - Description: policy.Description, + Description: &policy.Description, Enabled: policy.Enabled, SourcePostureChecks: policy.SourcePostureChecks, } @@ -319,13 +383,15 @@ func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Polic rID := r.ID rDescription := r.Description rule := api.PolicyRule{ - Id: &rID, - Name: r.Name, - Enabled: r.Enabled, - Description: &rDescription, - Bidirectional: r.Bidirectional, - Protocol: api.PolicyRuleProtocol(r.Protocol), - Action: api.PolicyRuleAction(r.Action), + Id: &rID, + Name: r.Name, + Enabled: r.Enabled, + Description: &rDescription, + Bidirectional: r.Bidirectional, + Protocol: api.PolicyRuleProtocol(r.Protocol), + Action: api.PolicyRuleAction(r.Action), + SourceResource: r.SourceResource.ToAPIResponse(), + DestinationResource: r.DestinationResource.ToAPIResponse(), } if len(r.Ports) != 0 { @@ -344,26 +410,30 @@ func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Polic rule.PortRanges = &portRanges } + var sources []api.GroupMinimum for _, gid := range r.Sources { _, ok := cache[gid] if ok { continue } + if group, ok := groupsMap[gid]; ok { minimum := api.GroupMinimum{ Id: group.ID, Name: group.Name, PeersCount: len(group.Peers), } - rule.Sources = append(rule.Sources, minimum) + sources = append(sources, minimum) cache[gid] = minimum } } + rule.Sources = &sources + var destinations []api.GroupMinimum for _, gid := range r.Destinations { cachedMinimum, ok := cache[gid] if ok { - rule.Destinations = append(rule.Destinations, cachedMinimum) + destinations = append(destinations, cachedMinimum) continue } if group, ok := groupsMap[gid]; ok { @@ -372,10 +442,12 @@ func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Polic Name: group.Name, PeersCount: len(group.Peers), } - rule.Destinations = append(rule.Destinations, minimum) + destinations = append(destinations, minimum) cache[gid] = minimum } } + rule.Destinations = &destinations + ap.Rules = append(ap.Rules, rule) } return ap diff --git a/management/server/http/policies_handler_test.go b/management/server/http/handlers/policies/policies_handler_test.go similarity index 81% rename from management/server/http/policies_handler_test.go rename to management/server/http/handlers/policies/policies_handler_test.go index f8a897eb2..3e1be187c 100644 --- 
a/management/server/http/policies_handler_test.go +++ b/management/server/http/handlers/policies/policies_handler_test.go @@ -1,4 +1,4 @@ -package http +package policies import ( "bytes" @@ -10,9 +10,9 @@ import ( "strings" "testing" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" "github.com/gorilla/mux" @@ -20,50 +20,49 @@ import ( "github.com/magiconair/properties/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/mock_server" ) -func initPoliciesTestData(policies ...*server.Policy) *Policies { - testPolicies := make(map[string]*server.Policy, len(policies)) +func initPoliciesTestData(policies ...*types.Policy) *handler { + testPolicies := make(map[string]*types.Policy, len(policies)) for _, policy := range policies { testPolicies[policy.ID] = policy } - return &Policies{ + return &handler{ accountManager: &mock_server.MockAccountManager{ - GetPolicyFunc: func(_ context.Context, _, policyID, _ string) (*server.Policy, error) { + GetPolicyFunc: func(_ context.Context, _, policyID, _ string) (*types.Policy, error) { policy, ok := testPolicies[policyID] if !ok { return nil, status.Errorf(status.NotFound, "policy not found") } return policy, nil }, - SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) (*server.Policy, error) { + SavePolicyFunc: func(_ context.Context, _, _ string, policy *types.Policy) (*types.Policy, error) { if !strings.HasPrefix(policy.ID, "id-") { policy.ID = "id-was-set" policy.Rules[0].ID = "id-was-set" } return policy, nil }, - GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) { - return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil + GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*types.Group, error) { + return []*types.Group{{ID: "F"}, {ID: "G"}}, nil }, GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil }, - GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) { - user := server.NewAdminUser(userID) - return &server.Account{ + GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*types.Account, error) { + user := types.NewAdminUser(userID) + return &types.Account{ Id: accountID, Domain: "hotmail.com", - Policies: []*server.Policy{ + Policies: []*types.Policy{ {ID: "id-existed"}, }, - Groups: map[string]*nbgroup.Group{ + Groups: map[string]*types.Group{ "F": {ID: "F"}, "G": {ID: "G"}, }, - Users: map[string]*server.User{ + Users: map[string]*types.User{ "test_user": user, }, }, nil @@ -91,24 +90,24 @@ func TestPoliciesGetPolicy(t *testing.T) { requestBody io.Reader }{ { - name: "GetPolicy OK", + name: "getPolicy OK", expectedBody: true, requestType: http.MethodGet, requestPath: "/api/policies/idofthepolicy", expectedStatus: http.StatusOK, }, { - name: "GetPolicy not found", + name: "getPolicy not found", requestType: http.MethodGet, requestPath: "/api/policies/notexists", expectedStatus: http.StatusNotFound, }, } - policy := &server.Policy{ + policy := &types.Policy{ ID: "idofthepolicy", Name: "Rule", - Rules: []*server.PolicyRule{ + Rules: []*types.PolicyRule{ {ID: "idoftherule", Name: "Rule"}, }, } @@ -121,7 +120,7 @@ func TestPoliciesGetPolicy(t 
*testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/policies/{policyId}", p.GetPolicy).Methods("GET") + router.HandleFunc("/api/policies/{policyId}", p.getPolicy).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -155,6 +154,7 @@ func TestPoliciesGetPolicy(t *testing.T) { func TestPoliciesWritePolicy(t *testing.T) { str := func(s string) *string { return &s } + emptyString := "" tt := []struct { name string expectedStatus int @@ -177,14 +177,17 @@ func TestPoliciesWritePolicy(t *testing.T) { "Description": "Description", "Protocol": "tcp", "Action": "accept", - "Bidirectional":true + "Bidirectional":true, + "Sources": ["F"], + "Destinations": ["G"] } ]}`)), expectedStatus: http.StatusOK, expectedBody: true, expectedPolicy: &api.Policy{ - Id: str("id-was-set"), - Name: "Default POSTed Policy", + Id: str("id-was-set"), + Name: "Default POSTed Policy", + Description: &emptyString, Rules: []api.PolicyRule{ { Id: str("id-was-set"), @@ -193,6 +196,8 @@ func TestPoliciesWritePolicy(t *testing.T) { Protocol: "tcp", Action: "accept", Bidirectional: true, + Sources: &[]api.GroupMinimum{{Id: "F"}}, + Destinations: &[]api.GroupMinimum{{Id: "G"}}, }, }, }, @@ -221,14 +226,17 @@ func TestPoliciesWritePolicy(t *testing.T) { "Description": "Description", "Protocol": "tcp", "Action": "accept", - "Bidirectional":true + "Bidirectional":true, + "Sources": ["F"], + "Destinations": ["F"] } ]}`)), expectedStatus: http.StatusOK, expectedBody: true, expectedPolicy: &api.Policy{ - Id: str("id-existed"), - Name: "Default POSTed Policy", + Id: str("id-existed"), + Name: "Default POSTed Policy", + Description: &emptyString, Rules: []api.PolicyRule{ { Id: str("id-existed"), @@ -237,6 +245,8 @@ func TestPoliciesWritePolicy(t *testing.T) { Protocol: "tcp", Action: "accept", Bidirectional: true, + Sources: &[]api.GroupMinimum{{Id: "F"}}, + Destinations: &[]api.GroupMinimum{{Id: "F"}}, }, }, }, @@ -251,10 +261,10 @@ func TestPoliciesWritePolicy(t *testing.T) { }, } - p := initPoliciesTestData(&server.Policy{ + p := initPoliciesTestData(&types.Policy{ ID: "id-existed", Name: "Default POSTed Rule", - Rules: []*server.PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "id-existed", Name: "Default POSTed Rule", @@ -269,8 +279,8 @@ func TestPoliciesWritePolicy(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/policies", p.CreatePolicy).Methods("POST") - router.HandleFunc("/api/policies/{policyId}", p.UpdatePolicy).Methods("PUT") + router.HandleFunc("/api/policies", p.createPolicy).Methods("POST") + router.HandleFunc("/api/policies/{policyId}", p.updatePolicy).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/posture_checks_handler.go b/management/server/http/handlers/policies/posture_checks_handler.go similarity index 69% rename from management/server/http/posture_checks_handler.go rename to management/server/http/handlers/policies/posture_checks_handler.go index 2c8204292..ce0d4878c 100644 --- a/management/server/http/posture_checks_handler.go +++ b/management/server/http/handlers/policies/posture_checks_handler.go @@ -1,4 +1,4 @@ -package http +package policies import ( "encoding/json" @@ -9,22 +9,33 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" 
"github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" ) -// PostureChecksHandler is a handler that returns posture checks of the account. -type PostureChecksHandler struct { +// postureChecksHandler is a handler that returns posture checks of the account. +type postureChecksHandler struct { accountManager server.AccountManager - geolocationManager *geolocation.Geolocation + geolocationManager geolocation.Geolocation claimsExtractor *jwtclaims.ClaimsExtractor } -// NewPostureChecksHandler creates a new PostureChecks handler -func NewPostureChecksHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg AuthCfg) *PostureChecksHandler { - return &PostureChecksHandler{ +func addPostureCheckEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { + postureCheckHandler := newPostureChecksHandler(accountManager, locationManager, authCfg) + router.HandleFunc("/posture-checks", postureCheckHandler.getAllPostureChecks).Methods("GET", "OPTIONS") + router.HandleFunc("/posture-checks", postureCheckHandler.createPostureCheck).Methods("POST", "OPTIONS") + router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.updatePostureCheck).Methods("PUT", "OPTIONS") + router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.getPostureCheck).Methods("GET", "OPTIONS") + router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.deletePostureCheck).Methods("DELETE", "OPTIONS") + addLocationsEndpoint(accountManager, locationManager, authCfg, router) +} + +// newPostureChecksHandler creates a new PostureChecks handler +func newPostureChecksHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation, authCfg configs.AuthCfg) *postureChecksHandler { + return &postureChecksHandler{ accountManager: accountManager, geolocationManager: geolocationManager, claimsExtractor: jwtclaims.NewClaimsExtractor( @@ -34,8 +45,8 @@ func NewPostureChecksHandler(accountManager server.AccountManager, geolocationMa } } -// GetAllPostureChecks list for the account -func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) { +// getAllPostureChecks list for the account +func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -57,8 +68,8 @@ func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *htt util.WriteJSONObject(r.Context(), w, postureChecks) } -// UpdatePostureCheck handles update to a posture check identified by a given ID -func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) { +// updatePostureCheck handles update to a posture check identified by a given ID +func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -82,8 +93,8 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w 
http.ResponseWriter, r *http p.savePostureChecks(w, r, accountID, userID, postureChecksID) } -// CreatePostureCheck handles posture check creation request -func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http.Request) { +// createPostureCheck handles posture check creation request +func (p *postureChecksHandler) createPostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -94,8 +105,8 @@ func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http p.savePostureChecks(w, r, accountID, userID, "") } -// GetPostureCheck handles a posture check Get request identified by ID -func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) { +// getPostureCheck handles a posture check Get request identified by ID +func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -119,8 +130,8 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re util.WriteJSONObject(r.Context(), w, postureChecks.ToAPIResponse()) } -// DeletePostureCheck handles posture check deletion request -func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) { +// deletePostureCheck handles posture check deletion request +func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -140,11 +151,11 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } // savePostureChecks handles posture checks create and update -func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.Request, accountID, userID, postureChecksID string) { +func (p *postureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.Request, accountID, userID, postureChecksID string) { var ( err error req api.PostureCheckUpdate diff --git a/management/server/http/posture_checks_handler_test.go b/management/server/http/handlers/policies/posture_checks_handler_test.go similarity index 96% rename from management/server/http/posture_checks_handler_test.go rename to management/server/http/handlers/policies/posture_checks_handler_test.go index f400cec81..237687fd4 100644 --- a/management/server/http/posture_checks_handler_test.go +++ b/management/server/http/handlers/policies/posture_checks_handler_test.go @@ -1,4 +1,4 @@ -package http +package policies import ( "bytes" @@ -25,13 +25,13 @@ import ( var berlin = "Berlin" var losAngeles = "Los Angeles" -func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksHandler { +func initPostureChecksTestData(postureChecks ...*posture.Checks) *postureChecksHandler { testPostureChecks := make(map[string]*posture.Checks, len(postureChecks)) for _, postureCheck := range postureChecks { testPostureChecks[postureCheck.ID] = postureCheck } - return &PostureChecksHandler{ + return &postureChecksHandler{ accountManager: &mock_server.MockAccountManager{ GetPostureChecksFunc: func(_ 
context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { p, ok := testPostureChecks[postureChecksID] @@ -70,7 +70,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH return claims.AccountId, claims.UserId, nil }, }, - geolocationManager: &geolocation.Geolocation{}, + geolocationManager: &geolocation.Mock{}, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { return jwtclaims.AuthorizationClaims{ @@ -147,35 +147,35 @@ func TestGetPostureCheck(t *testing.T) { requestBody io.Reader }{ { - name: "GetPostureCheck NBVersion OK", + name: "getPostureCheck NBVersion OK", expectedBody: true, id: postureCheck.ID, checkName: postureCheck.Name, expectedStatus: http.StatusOK, }, { - name: "GetPostureCheck OSVersion OK", + name: "getPostureCheck OSVersion OK", expectedBody: true, id: osPostureCheck.ID, checkName: osPostureCheck.Name, expectedStatus: http.StatusOK, }, { - name: "GetPostureCheck GeoLocation OK", + name: "getPostureCheck GeoLocation OK", expectedBody: true, id: geoPostureCheck.ID, checkName: geoPostureCheck.Name, expectedStatus: http.StatusOK, }, { - name: "GetPostureCheck PrivateNetwork OK", + name: "getPostureCheck PrivateNetwork OK", expectedBody: true, id: privateNetworkCheck.ID, checkName: privateNetworkCheck.Name, expectedStatus: http.StatusOK, }, { - name: "GetPostureCheck Not Found", + name: "getPostureCheck Not Found", id: "not-exists", expectedStatus: http.StatusNotFound, }, @@ -189,7 +189,7 @@ func TestGetPostureCheck(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/api/posture-checks/"+tc.id, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/posture-checks/{postureCheckId}", p.GetPostureCheck).Methods("GET") + router.HandleFunc("/api/posture-checks/{postureCheckId}", p.getPostureCheck).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -231,7 +231,7 @@ func TestPostureCheckUpdate(t *testing.T) { requestType string requestPath string requestBody io.Reader - setupHandlerFunc func(handler *PostureChecksHandler) + setupHandlerFunc func(handler *postureChecksHandler) }{ { name: "Create Posture Checks NB version", @@ -286,7 +286,7 @@ func TestPostureCheckUpdate(t *testing.T) { }, }, }, - setupHandlerFunc: func(handler *PostureChecksHandler) { + setupHandlerFunc: func(handler *postureChecksHandler) { handler.geolocationManager = nil }, }, @@ -427,7 +427,7 @@ func TestPostureCheckUpdate(t *testing.T) { }`)), expectedStatus: http.StatusPreconditionFailed, expectedBody: false, - setupHandlerFunc: func(handler *PostureChecksHandler) { + setupHandlerFunc: func(handler *postureChecksHandler) { handler.geolocationManager = nil }, }, @@ -614,7 +614,7 @@ func TestPostureCheckUpdate(t *testing.T) { }, }, }, - setupHandlerFunc: func(handler *PostureChecksHandler) { + setupHandlerFunc: func(handler *postureChecksHandler) { handler.geolocationManager = nil }, }, @@ -677,7 +677,7 @@ func TestPostureCheckUpdate(t *testing.T) { }`)), expectedStatus: http.StatusPreconditionFailed, expectedBody: false, - setupHandlerFunc: func(handler *PostureChecksHandler) { + setupHandlerFunc: func(handler *postureChecksHandler) { handler.geolocationManager = nil }, }, @@ -842,8 +842,8 @@ func TestPostureCheckUpdate(t *testing.T) { } router := mux.NewRouter() - router.HandleFunc("/api/posture-checks", defaultHandler.CreatePostureCheck).Methods("POST") - router.HandleFunc("/api/posture-checks/{postureCheckId}", 
defaultHandler.UpdatePostureCheck).Methods("PUT") + router.HandleFunc("/api/posture-checks", defaultHandler.createPostureCheck).Methods("POST") + router.HandleFunc("/api/posture-checks/{postureCheckId}", defaultHandler.updatePostureCheck).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/routes_handler.go b/management/server/http/handlers/routes/routes_handler.go similarity index 83% rename from management/server/http/routes_handler.go rename to management/server/http/handlers/routes/routes_handler.go index f44a164e2..a29ba4562 100644 --- a/management/server/http/routes_handler.go +++ b/management/server/http/handlers/routes/routes_handler.go @@ -1,4 +1,4 @@ -package http +package routes import ( "encoding/json" @@ -14,6 +14,7 @@ import ( "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" @@ -23,15 +24,24 @@ import ( const maxDomains = 32 const failedToConvertRoute = "failed to convert route to response: %v" -// RoutesHandler is the routes handler of the account -type RoutesHandler struct { +// handler is the routes handler of the account +type handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewRoutesHandler returns a new instance of RoutesHandler handler -func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *RoutesHandler { - return &RoutesHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + routesHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/routes", routesHandler.getAllRoutes).Methods("GET", "OPTIONS") + router.HandleFunc("/routes", routesHandler.createRoute).Methods("POST", "OPTIONS") + router.HandleFunc("/routes/{routeId}", routesHandler.updateRoute).Methods("PUT", "OPTIONS") + router.HandleFunc("/routes/{routeId}", routesHandler.getRoute).Methods("GET", "OPTIONS") + router.HandleFunc("/routes/{routeId}", routesHandler.deleteRoute).Methods("DELETE", "OPTIONS") +} + +// newHandler returns a new instance of routes handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -40,8 +50,8 @@ func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ro } } -// GetAllRoutes returns the list of routes for the account -func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) { +// getAllRoutes returns the list of routes for the account +func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -67,8 +77,8 @@ func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, apiRoutes) } -// CreateRoute handles route creation request -func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { +// createRoute handles route creation request +func (h *handler) createRoute(w 
http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -139,7 +149,7 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, routes) } -func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) error { +func (h *handler) validateRoute(req api.PostApiRoutesJSONRequestBody) error { if req.Network != nil && req.Domains != nil { return status.Errorf(status.InvalidArgument, "only one of 'network' or 'domains' should be provided") } @@ -164,8 +174,8 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro return nil } -// UpdateRoute handles update to a route identified by a given ID -func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { +// updateRoute handles update to a route identified by a given ID +func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -257,8 +267,8 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, routes) } -// DeleteRoute handles route deletion request -func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { +// deleteRoute handles route deletion request +func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -278,11 +288,11 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -// GetRoute handles a route Get request identified by ID -func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) { +// getRoute handles a route Get request identified by ID +func (h *handler) getRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -350,7 +360,7 @@ func validateDomains(domains []string) (domain.List, error) { return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains) } - domainRegex := regexp.MustCompile(`^(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`) + domainRegex := regexp.MustCompile(`^(?:\*\.)?(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`) var domainList domain.List diff --git a/management/server/http/routes_handler_test.go b/management/server/http/handlers/routes/routes_handler_test.go similarity index 92% rename from management/server/http/routes_handler_test.go rename to management/server/http/handlers/routes/routes_handler_test.go index 83bd7004d..45c465587 100644 --- a/management/server/http/routes_handler_test.go +++ b/management/server/http/handlers/routes/routes_handler_test.go @@ -1,4 +1,4 @@ -package http +package routes import ( "bytes" @@ -11,18 +11,19 @@ import ( "net/netip" "testing" + "github.com/netbirdio/netbird/management/server/util" "github.com/stretchr/testify/require" 
"github.com/netbirdio/netbird/management/server/http/api" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" "github.com/gorilla/mux" "github.com/magiconair/properties/assert" "github.com/netbirdio/netbird/management/domain" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" ) @@ -61,7 +62,7 @@ var baseExistingRoute = &route.Route{ Groups: []string{existingGroupID}, } -var testingAccount = &server.Account{ +var testingAccount = &types.Account{ Id: testAccountID, Domain: "hotmail.com", Peers: map[string]*nbpeer.Peer{ @@ -82,13 +83,13 @@ var testingAccount = &server.Account{ }, }, }, - Users: map[string]*server.User{ - "test_user": server.NewAdminUser("test_user"), + Users: map[string]*types.User{ + "test_user": types.NewAdminUser("test_user"), }, } -func initRoutesTestData() *RoutesHandler { - return &RoutesHandler{ +func initRoutesTestData() *handler { + return &handler{ accountManager: &mock_server.MockAccountManager{ GetRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) (*route.Route, error) { if routeID == existingRouteID { @@ -152,7 +153,7 @@ func initRoutesTestData() *RoutesHandler { return nil }, GetAccountIDFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) { - //return testingAccount, testingAccount.Users["test_user"], nil + // return testingAccount, testingAccount.Users["test_user"], nil return testingAccount.Id, testingAccount.Users["test_user"].Id, nil }, }, @@ -239,7 +240,7 @@ func TestRoutesHandlers(t *testing.T) { Id: existingRouteID, Description: "Post", NetworkId: "awesomeNet", - Network: toPtr("192.168.0.0/16"), + Network: util.ToPtr("192.168.0.0/16"), Peer: &existingPeerID, NetworkType: route.IPv4NetworkString, Masquerade: false, @@ -259,7 +260,7 @@ func TestRoutesHandlers(t *testing.T) { Id: existingRouteID, Description: "Post", NetworkId: "domainNet", - Network: toPtr("invalid Prefix"), + Network: util.ToPtr("invalid Prefix"), KeepRoute: true, Domains: &[]string{existingDomain}, Peer: &existingPeerID, @@ -281,7 +282,7 @@ func TestRoutesHandlers(t *testing.T) { Id: existingRouteID, Description: "Post", NetworkId: "awesomeNet", - Network: toPtr("192.168.0.0/16"), + Network: util.ToPtr("192.168.0.0/16"), Peer: &existingPeerID, NetworkType: route.IPv4NetworkString, Masquerade: false, @@ -330,6 +331,14 @@ func TestRoutesHandlers(t *testing.T) { expectedStatus: http.StatusUnprocessableEntity, expectedBody: false, }, + { + name: "POST Wildcard Domain", + requestType: http.MethodPost, + requestPath: "/api/routes", + requestBody: bytes.NewBufferString(fmt.Sprintf(`{"Description":"Post","domains":["*.example.com"],"network_id":"awesomeNet","Peer":"%s","groups":["%s"]}`, existingPeerID, existingGroupID)), + expectedStatus: http.StatusOK, + expectedBody: false, + }, { name: "POST UnprocessableEntity when both network and domains are provided", requestType: http.MethodPost, @@ -377,7 +386,7 @@ func TestRoutesHandlers(t *testing.T) { Id: existingRouteID, Description: "Post", NetworkId: "awesomeNet", - Network: toPtr("192.168.0.0/16"), + Network: util.ToPtr("192.168.0.0/16"), Peer: &existingPeerID, NetworkType: route.IPv4NetworkString, Masquerade: false, @@ -396,7 +405,7 @@ func TestRoutesHandlers(t *testing.T) { Id: existingRouteID, 
 Description: "Post",
 NetworkId: "awesomeNet",
- Network: toPtr("invalid Prefix"),
+ Network: util.ToPtr("invalid Prefix"),
 Domains: &[]string{existingDomain},
 Peer: &existingPeerID,
 NetworkType: route.DomainNetworkString,
@@ -417,7 +426,7 @@ func TestRoutesHandlers(t *testing.T) {
 Id: existingRouteID,
 Description: "Post",
 NetworkId: "awesomeNet",
- Network: toPtr("192.168.0.0/16"),
+ Network: util.ToPtr("192.168.0.0/16"),
 Peer: &emptyString,
 PeerGroups: &[]string{existingGroupID},
 NetworkType: route.IPv4NetworkString,
@@ -521,10 +530,10 @@ func TestRoutesHandlers(t *testing.T) {
 req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
 router := mux.NewRouter()
- router.HandleFunc("/api/routes/{routeId}", p.GetRoute).Methods("GET")
- router.HandleFunc("/api/routes/{routeId}", p.DeleteRoute).Methods("DELETE")
- router.HandleFunc("/api/routes", p.CreateRoute).Methods("POST")
- router.HandleFunc("/api/routes/{routeId}", p.UpdateRoute).Methods("PUT")
+ router.HandleFunc("/api/routes/{routeId}", p.getRoute).Methods("GET")
+ router.HandleFunc("/api/routes/{routeId}", p.deleteRoute).Methods("DELETE")
+ router.HandleFunc("/api/routes", p.createRoute).Methods("POST")
+ router.HandleFunc("/api/routes/{routeId}", p.updateRoute).Methods("PUT")
 router.ServeHTTP(recorder, req)
 res := recorder.Result()
@@ -609,6 +618,30 @@ func TestValidateDomains(t *testing.T) {
 expected: domain.List{"google.com"},
 wantErr: true,
 },
+ {
+ name: "Valid wildcard domain",
+ domains: []string{"*.example.com"},
+ expected: domain.List{"*.example.com"},
+ wantErr: false,
+ },
+ {
+ name: "Wildcard with dot domain",
+ domains: []string{".*.example.com"},
+ expected: nil,
+ wantErr: true,
+ },
+ {
+ name: "Invalid wildcard domain",
+ domains: []string{"a.*.example.com"},
+ expected: nil,
+ wantErr: true,
+ },
 }
 for _, tt := range tests {
@@ -631,7 +664,3 @@ func toApiRoute(t *testing.T, r *route.Route) *api.Route {
 require.NoError(t, err, "Failed to convert route")
 return apiRoute
 }
-
-func toPtr[T any](v T) *T {
- return &v
-}
diff --git a/management/server/http/setupkeys_handler.go b/management/server/http/handlers/setup_keys/setupkeys_handler.go
similarity index 70%
rename from management/server/http/setupkeys_handler.go
rename to management/server/http/handlers/setup_keys/setupkeys_handler.go
index 31859f59b..67e296901 100644
--- a/management/server/http/setupkeys_handler.go
+++ b/management/server/http/handlers/setup_keys/setupkeys_handler.go
@@ -1,4 +1,4 @@
-package http
+package setup_keys
 import (
 "context"
@@ -10,20 +10,31 @@ import (
 "github.com/netbirdio/netbird/management/server"
 "github.com/netbirdio/netbird/management/server/http/api"
+ "github.com/netbirdio/netbird/management/server/http/configs"
 "github.com/netbirdio/netbird/management/server/http/util"
 "github.com/netbirdio/netbird/management/server/jwtclaims"
 "github.com/netbirdio/netbird/management/server/status"
+ "github.com/netbirdio/netbird/management/server/types"
 )
-// SetupKeysHandler is a handler that returns a list of setup keys of the account
-type SetupKeysHandler struct {
+// handler is a handler that returns a list of setup keys of the account
+type handler struct {
 accountManager server.AccountManager
 claimsExtractor *jwtclaims.ClaimsExtractor
 }
-// NewSetupKeysHandler creates a new SetupKeysHandler HTTP handler
-func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg) *SetupKeysHandler {
-
return &SetupKeysHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + keysHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/setup-keys", keysHandler.getAllSetupKeys).Methods("GET", "OPTIONS") + router.HandleFunc("/setup-keys", keysHandler.createSetupKey).Methods("POST", "OPTIONS") + router.HandleFunc("/setup-keys/{keyId}", keysHandler.getSetupKey).Methods("GET", "OPTIONS") + router.HandleFunc("/setup-keys/{keyId}", keysHandler.updateSetupKey).Methods("PUT", "OPTIONS") + router.HandleFunc("/setup-keys/{keyId}", keysHandler.deleteSetupKey).Methods("DELETE", "OPTIONS") +} + +// newHandler creates a new setup key handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -32,8 +43,8 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg) } } -// CreateSetupKey is a POST requests that creates a new SetupKey -func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request) { +// createSetupKey is a POST requests that creates a new SetupKey +func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -53,8 +64,8 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request return } - if !(server.SetupKeyType(req.Type) == server.SetupKeyReusable || - server.SetupKeyType(req.Type) == server.SetupKeyOneOff) { + if !(types.SetupKeyType(req.Type) == types.SetupKeyReusable || + types.SetupKeyType(req.Type) == types.SetupKeyOneOff) { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown setup key type %s", req.Type), w) return } @@ -75,22 +86,22 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request ephemeral = *req.Ephemeral } - setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, server.SetupKeyType(req.Type), expiresIn, + setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, types.SetupKeyType(req.Type), expiresIn, req.AutoGroups, req.UsageLimit, userID, ephemeral) if err != nil { util.WriteError(r.Context(), err, w) return } - apiSetupKeys := toResponseBody(setupKey) + apiSetupKeys := ToResponseBody(setupKey) // for the creation we need to send the plain key apiSetupKeys.Key = setupKey.Key util.WriteJSONObject(r.Context(), w, apiSetupKeys) } -// GetSetupKey is a GET request to get a SetupKey by ID -func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { +// getSetupKey is a GET request to get a SetupKey by ID +func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -114,8 +125,8 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { writeSuccess(r.Context(), w, key) } -// UpdateSetupKey is a PUT request to update server.SetupKey -func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request) { +// updateSetupKey is a PUT request to update server.SetupKey +func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) { claims := 
h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -137,20 +148,14 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request return } - if req.Name == "" { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key name field is invalid: %s", req.Name), w) - return - } - if req.AutoGroups == nil { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w) return } - newKey := &server.SetupKey{} + newKey := &types.SetupKey{} newKey.AutoGroups = req.AutoGroups newKey.Revoked = req.Revoked - newKey.Name = req.Name newKey.Id = keyID newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID) @@ -161,8 +166,8 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request writeSuccess(r.Context(), w, newKey) } -// GetAllSetupKeys is a GET request that returns a list of SetupKey -func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Request) { +// getAllSetupKeys is a GET request that returns a list of SetupKey +func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -178,13 +183,13 @@ func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Reques apiSetupKeys := make([]*api.SetupKey, 0) for _, key := range setupKeys { - apiSetupKeys = append(apiSetupKeys, toResponseBody(key)) + apiSetupKeys = append(apiSetupKeys, ToResponseBody(key)) } util.WriteJSONObject(r.Context(), w, apiSetupKeys) } -func (h *SetupKeysHandler) DeleteSetupKey(w http.ResponseWriter, r *http.Request) { +func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -205,20 +210,20 @@ func (h *SetupKeysHandler) DeleteSetupKey(w http.ResponseWriter, r *http.Request return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -func writeSuccess(ctx context.Context, w http.ResponseWriter, key *server.SetupKey) { +func writeSuccess(ctx context.Context, w http.ResponseWriter, key *types.SetupKey) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(200) - err := json.NewEncoder(w).Encode(toResponseBody(key)) + err := json.NewEncoder(w).Encode(ToResponseBody(key)) if err != nil { util.WriteError(ctx, err, w) return } } -func toResponseBody(key *server.SetupKey) *api.SetupKey { +func ToResponseBody(key *types.SetupKey) *api.SetupKey { var state string switch { case key.IsExpired(): @@ -235,12 +240,12 @@ func toResponseBody(key *server.SetupKey) *api.SetupKey { Id: key.Id, Key: key.KeySecret, Name: key.Name, - Expires: key.ExpiresAt, + Expires: key.GetExpiresAt(), Type: string(key.Type), Valid: key.IsValid(), Revoked: key.Revoked, UsedTimes: key.UsedTimes, - LastUsed: key.LastUsed, + LastUsed: key.GetLastUsed(), State: state, AutoGroups: key.AutoGroups, UpdatedAt: key.UpdatedAt, diff --git a/management/server/http/setupkeys_handler_test.go b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go similarity index 82% rename from management/server/http/setupkeys_handler_test.go rename to 
management/server/http/handlers/setup_keys/setupkeys_handler_test.go index 09256d0ea..f56227c10 100644 --- a/management/server/http/setupkeys_handler_test.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go @@ -1,4 +1,4 @@ -package http +package setup_keys import ( "bytes" @@ -14,11 +14,11 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) const ( @@ -28,17 +28,17 @@ const ( notFoundSetupKeyID = "notFoundSetupKeyID" ) -func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey, - user *server.User, -) *SetupKeysHandler { - return &SetupKeysHandler{ +func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKey, updatedSetupKey *types.SetupKey, + user *types.User, +) *handler { + return &handler{ accountManager: &mock_server.MockAccountManager{ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil }, - CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string, + CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ types.SetupKeyType, _ time.Duration, _ []string, _ int, _ string, ephemeral bool, - ) (*server.SetupKey, error) { + ) (*types.SetupKey, error) { if keyName == newKey.Name || typ != newKey.Type { nk := newKey.Copy() nk.Ephemeral = ephemeral @@ -46,7 +46,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup } return nil, fmt.Errorf("failed creating setup key") }, - GetSetupKeyFunc: func(_ context.Context, accountID, userID, keyID string) (*server.SetupKey, error) { + GetSetupKeyFunc: func(_ context.Context, accountID, userID, keyID string) (*types.SetupKey, error) { switch keyID { case defaultKey.Id: return defaultKey, nil @@ -57,15 +57,15 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup } }, - SaveSetupKeyFunc: func(_ context.Context, accountID string, key *server.SetupKey, _ string) (*server.SetupKey, error) { + SaveSetupKeyFunc: func(_ context.Context, accountID string, key *types.SetupKey, _ string) (*types.SetupKey, error) { if key.Id == updatedSetupKey.Id { return updatedSetupKey, nil } return nil, status.Errorf(status.NotFound, "key %s not found", key.Id) }, - ListSetupKeysFunc: func(_ context.Context, accountID, userID string) ([]*server.SetupKey, error) { - return []*server.SetupKey{defaultKey}, nil + ListSetupKeysFunc: func(_ context.Context, accountID, userID string) ([]*types.SetupKey, error) { + return []*types.SetupKey{defaultKey}, nil }, DeleteSetupKeyFunc: func(_ context.Context, accountID, userID, keyID string) error { @@ -80,7 +80,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup return jwtclaims.AuthorizationClaims{ UserId: user.Id, Domain: "hotmail.com", - AccountId: testAccountID, + AccountId: "testAccountId", } }), ), @@ -88,20 +88,20 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup } func TestSetupKeysHandlers(t *testing.T) { - defaultSetupKey, _ := 
server.GenerateDefaultSetupKey() + defaultSetupKey, _ := types.GenerateDefaultSetupKey() defaultSetupKey.Id = existingSetupKeyID - adminUser := server.NewAdminUser("test_user") + adminUser := types.NewAdminUser("test_user") - newSetupKey, plainKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"}, - server.SetupKeyUnlimitedUsage, true) + newSetupKey, plainKey := types.GenerateSetupKey(newSetupKeyName, types.SetupKeyReusable, 0, []string{"group-1"}, + types.SetupKeyUnlimitedUsage, true) newSetupKey.Key = plainKey updatedDefaultSetupKey := defaultSetupKey.Copy() updatedDefaultSetupKey.AutoGroups = []string{"group-1"} updatedDefaultSetupKey.Name = updatedSetupKeyName updatedDefaultSetupKey.Revoked = true - expectedNewKey := toResponseBody(newSetupKey) + expectedNewKey := ToResponseBody(newSetupKey) expectedNewKey.Key = plainKey tt := []struct { name string @@ -119,7 +119,7 @@ func TestSetupKeysHandlers(t *testing.T) { requestPath: "/api/setup-keys", expectedStatus: http.StatusOK, expectedBody: true, - expectedSetupKeys: []*api.SetupKey{toResponseBody(defaultSetupKey)}, + expectedSetupKeys: []*api.SetupKey{ToResponseBody(defaultSetupKey)}, }, { name: "Get Existing Setup Key", @@ -127,7 +127,7 @@ func TestSetupKeysHandlers(t *testing.T) { requestPath: "/api/setup-keys/" + existingSetupKeyID, expectedStatus: http.StatusOK, expectedBody: true, - expectedSetupKey: toResponseBody(defaultSetupKey), + expectedSetupKey: ToResponseBody(defaultSetupKey), }, { name: "Get Not Existing Setup Key", @@ -158,7 +158,7 @@ func TestSetupKeysHandlers(t *testing.T) { ))), expectedStatus: http.StatusOK, expectedBody: true, - expectedSetupKey: toResponseBody(updatedDefaultSetupKey), + expectedSetupKey: ToResponseBody(updatedDefaultSetupKey), }, { name: "Delete Setup Key", @@ -178,11 +178,11 @@ func TestSetupKeysHandlers(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/setup-keys", handler.GetAllSetupKeys).Methods("GET", "OPTIONS") - router.HandleFunc("/api/setup-keys", handler.CreateSetupKey).Methods("POST", "OPTIONS") - router.HandleFunc("/api/setup-keys/{keyId}", handler.GetSetupKey).Methods("GET", "OPTIONS") - router.HandleFunc("/api/setup-keys/{keyId}", handler.UpdateSetupKey).Methods("PUT", "OPTIONS") - router.HandleFunc("/api/setup-keys/{keyId}", handler.DeleteSetupKey).Methods("DELETE", "OPTIONS") + router.HandleFunc("/api/setup-keys", handler.getAllSetupKeys).Methods("GET", "OPTIONS") + router.HandleFunc("/api/setup-keys", handler.createSetupKey).Methods("POST", "OPTIONS") + router.HandleFunc("/api/setup-keys/{keyId}", handler.getSetupKey).Methods("GET", "OPTIONS") + router.HandleFunc("/api/setup-keys/{keyId}", handler.updateSetupKey).Methods("PUT", "OPTIONS") + router.HandleFunc("/api/setup-keys/{keyId}", handler.deleteSetupKey).Methods("DELETE", "OPTIONS") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -227,7 +227,7 @@ func TestSetupKeysHandlers(t *testing.T) { func assertKeys(t *testing.T, got *api.SetupKey, expected *api.SetupKey) { t.Helper() // this comparison is done manually because when converting to JSON dates formatted differently - // assert.Equal(t, got.UpdatedAt, tc.expectedSetupKey.UpdatedAt) //doesn't work + // assert.Equal(t, got.UpdatedAt, tc.expectedResponse.UpdatedAt) //doesn't work assert.WithinDurationf(t, got.UpdatedAt, expected.UpdatedAt, 0, "") assert.WithinDurationf(t, got.Expires, expected.Expires, 0, "") assert.Equal(t, got.Name, 
expected.Name) diff --git a/management/server/http/pat_handler.go b/management/server/http/handlers/users/pat_handler.go similarity index 69% rename from management/server/http/pat_handler.go rename to management/server/http/handlers/users/pat_handler.go index dfa9563e3..7b93d2ae1 100644 --- a/management/server/http/pat_handler.go +++ b/management/server/http/handlers/users/pat_handler.go @@ -1,28 +1,37 @@ -package http +package users import ( "encoding/json" "net/http" - "time" "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) -// PATHandler is the nameserver group handler of the account -type PATHandler struct { +// patHandler is the nameserver group handler of the account +type patHandler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewPATsHandler creates a new PATHandler HTTP handler -func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATHandler { - return &PATHandler{ +func addUsersTokensEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + tokenHandler := newPATsHandler(accountManager, authCfg) + router.HandleFunc("/users/{userId}/tokens", tokenHandler.getAllTokens).Methods("GET", "OPTIONS") + router.HandleFunc("/users/{userId}/tokens", tokenHandler.createToken).Methods("POST", "OPTIONS") + router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.getToken).Methods("GET", "OPTIONS") + router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.deleteToken).Methods("DELETE", "OPTIONS") +} + +// newPATsHandler creates a new patHandler HTTP handler +func newPATsHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *patHandler { + return &patHandler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -31,8 +40,8 @@ func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATH } } -// GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user -func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { +// getAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user +func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -61,8 +70,8 @@ func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, patResponse) } -// GetToken is HTTP GET handler that returns a personal access token for the given user -func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { +// getToken is HTTP GET handler that returns a personal access token for the given user +func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -92,8 +101,8 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r 
*http.Request) { util.WriteJSONObject(r.Context(), w, toPATResponse(pat)) } -// CreateToken is HTTP POST handler that creates a personal access token for the given user -func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { +// createToken is HTTP POST handler that creates a personal access token for the given user +func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -124,8 +133,8 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toPATGeneratedResponse(pat)) } -// DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user -func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { +// deleteToken is HTTP DELETE handler that deletes a personal access token for the given user +func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -152,25 +161,21 @@ func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken { - var lastUsed *time.Time - if !pat.LastUsed.IsZero() { - lastUsed = &pat.LastUsed - } +func toPATResponse(pat *types.PersonalAccessToken) *api.PersonalAccessToken { return &api.PersonalAccessToken{ CreatedAt: pat.CreatedAt, CreatedBy: pat.CreatedBy, Name: pat.Name, - ExpirationDate: pat.ExpirationDate, + ExpirationDate: pat.GetExpirationDate(), Id: pat.ID, - LastUsed: lastUsed, + LastUsed: pat.LastUsed, } } -func toPATGeneratedResponse(pat *server.PersonalAccessTokenGenerated) *api.PersonalAccessTokenGenerated { +func toPATGeneratedResponse(pat *types.PersonalAccessTokenGenerated) *api.PersonalAccessTokenGenerated { return &api.PersonalAccessTokenGenerated{ PlainToken: pat.PlainToken, PersonalAccessToken: *toPATResponse(&pat.PersonalAccessToken), diff --git a/management/server/http/pat_handler_test.go b/management/server/http/handlers/users/pat_handler_test.go similarity index 82% rename from management/server/http/pat_handler_test.go rename to management/server/http/handlers/users/pat_handler_test.go index c28228a50..9388067a4 100644 --- a/management/server/http/pat_handler_test.go +++ b/management/server/http/handlers/users/pat_handler_test.go @@ -1,4 +1,4 @@ -package http +package users import ( "bytes" @@ -12,13 +12,14 @@ import ( "github.com/google/go-cmp/cmp" "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/server/util" "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) const ( @@ -31,49 +32,49 @@ const ( testDomain = "hotmail.com" ) -var testAccount = &server.Account{ +var testAccount = &types.Account{ Id: existingAccountID, Domain: testDomain, - Users: map[string]*server.User{ + Users: map[string]*types.User{ existingUserID: { Id: 
existingUserID, - PATs: map[string]*server.PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ existingTokenID: { ID: existingTokenID, Name: "My first token", HashedToken: "someHash", - ExpirationDate: time.Now().UTC().AddDate(0, 0, 7), + ExpirationDate: util.ToPtr(time.Now().UTC().AddDate(0, 0, 7)), CreatedBy: existingUserID, CreatedAt: time.Now().UTC(), - LastUsed: time.Now().UTC(), + LastUsed: util.ToPtr(time.Now().UTC()), }, "token2": { ID: "token2", Name: "My second token", HashedToken: "someOtherHash", - ExpirationDate: time.Now().UTC().AddDate(0, 0, 7), + ExpirationDate: util.ToPtr(time.Now().UTC().AddDate(0, 0, 7)), CreatedBy: existingUserID, CreatedAt: time.Now().UTC(), - LastUsed: time.Now().UTC(), + LastUsed: util.ToPtr(time.Now().UTC()), }, }, }, }, } -func initPATTestData() *PATHandler { - return &PATHandler{ +func initPATTestData() *patHandler { + return &patHandler{ accountManager: &mock_server.MockAccountManager{ - CreatePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) { + CreatePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) { if accountID != existingAccountID { return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) } if targetUserID != existingUserID { return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID) } - return &server.PersonalAccessTokenGenerated{ + return &types.PersonalAccessTokenGenerated{ PlainToken: "nbp_z1pvsg2wP3EzmEou4S679KyTNhov632eyrXe", - PersonalAccessToken: server.PersonalAccessToken{}, + PersonalAccessToken: types.PersonalAccessToken{}, }, nil }, @@ -92,7 +93,7 @@ func initPATTestData() *PATHandler { } return nil }, - GetPATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) { + GetPATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error) { if accountID != existingAccountID { return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) } @@ -104,14 +105,14 @@ func initPATTestData() *PATHandler { } return testAccount.Users[existingUserID].PATs[existingTokenID], nil }, - GetAllPATsFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) { + GetAllPATsFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error) { if accountID != existingAccountID { return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) } if targetUserID != existingUserID { return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID) } - return []*server.PersonalAccessToken{testAccount.Users[existingUserID].PATs[existingTokenID], testAccount.Users[existingUserID].PATs["token2"]}, nil + return []*types.PersonalAccessToken{testAccount.Users[existingUserID].PATs[existingTokenID], testAccount.Users[existingUserID].PATs["token2"]}, nil }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( @@ -186,10 +187,10 @@ func TestTokenHandlers(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - 
router.HandleFunc("/api/users/{userId}/tokens", p.GetAllTokens).Methods("GET") - router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.GetToken).Methods("GET") - router.HandleFunc("/api/users/{userId}/tokens", p.CreateToken).Methods("POST") - router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.DeleteToken).Methods("DELETE") + router.HandleFunc("/api/users/{userId}/tokens", p.getAllTokens).Methods("GET") + router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.getToken).Methods("GET") + router.HandleFunc("/api/users/{userId}/tokens", p.createToken).Methods("POST") + router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.deleteToken).Methods("DELETE") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -217,7 +218,7 @@ func TestTokenHandlers(t *testing.T) { t.Fatalf("Sent content is not in correct json format; %v", err) } assert.NotEmpty(t, got.PlainToken) - assert.Equal(t, server.PATLength, len(got.PlainToken)) + assert.Equal(t, types.PATLength, len(got.PlainToken)) case "Get All Tokens": expectedTokens := []api.PersonalAccessToken{ toTokenResponse(*testAccount.Users[existingUserID].PATs[existingTokenID]), @@ -243,13 +244,13 @@ func TestTokenHandlers(t *testing.T) { } } -func toTokenResponse(serverToken server.PersonalAccessToken) api.PersonalAccessToken { +func toTokenResponse(serverToken types.PersonalAccessToken) api.PersonalAccessToken { return api.PersonalAccessToken{ Id: serverToken.ID, Name: serverToken.Name, CreatedAt: serverToken.CreatedAt, - LastUsed: &serverToken.LastUsed, + LastUsed: serverToken.LastUsed, CreatedBy: serverToken.CreatedBy, - ExpirationDate: serverToken.ExpirationDate, + ExpirationDate: serverToken.GetExpirationDate(), } } diff --git a/management/server/http/users_handler.go b/management/server/http/handlers/users/users_handler.go similarity index 76% rename from management/server/http/users_handler.go rename to management/server/http/handlers/users/users_handler.go index 6e151a0da..7380dd97e 100644 --- a/management/server/http/users_handler.go +++ b/management/server/http/handlers/users/users_handler.go @@ -1,4 +1,4 @@ -package http +package users import ( "encoding/json" @@ -9,22 +9,34 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/jwtclaims" ) -// UsersHandler is a handler that returns users of the account -type UsersHandler struct { +// handler is a handler that returns users of the account +type handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewUsersHandler creates a new UsersHandler HTTP handler -func NewUsersHandler(accountManager server.AccountManager, authCfg AuthCfg) *UsersHandler { - return &UsersHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + userHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/users", userHandler.getAllUsers).Methods("GET", "OPTIONS") + router.HandleFunc("/users/{userId}", userHandler.updateUser).Methods("PUT", "OPTIONS") + router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS") + router.HandleFunc("/users", userHandler.createUser).Methods("POST", 
"OPTIONS") + router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS") + addUsersTokensEndpoint(accountManager, authCfg, router) +} + +// newHandler creates a new UsersHandler HTTP handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -33,8 +45,8 @@ func NewUsersHandler(accountManager server.AccountManager, authCfg AuthCfg) *Use } } -// UpdateUser is a PUT requests to update User data -func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { +// updateUser is a PUT requests to update User data +func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPut { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return @@ -72,13 +84,13 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { return } - userRole := server.StrRoleToUserRole(req.Role) - if userRole == server.UserRoleUnknown { + userRole := types.StrRoleToUserRole(req.Role) + if userRole == types.UserRoleUnknown { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user role"), w) return } - newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &server.User{ + newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &types.User{ Id: targetUserID, Role: userRole, AutoGroups: req.AutoGroups, @@ -94,8 +106,8 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId)) } -// DeleteUser is a DELETE request to delete a user -func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) { +// deleteUser is a DELETE request to delete a user +func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodDelete { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return @@ -121,11 +133,11 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -// CreateUser creates a User in the system with a status "invited" (effectively this is a user invite). -func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { +// createUser creates a User in the system with a status "invited" (effectively this is a user invite). 
+func (h *handler) createUser(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return @@ -145,7 +157,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { return } - if server.StrRoleToUserRole(req.Role) == server.UserRoleUnknown { + if types.StrRoleToUserRole(req.Role) == types.UserRoleUnknown { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown user role %s", req.Role), w) return } @@ -160,13 +172,13 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { name = *req.Name } - newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &server.UserInfo{ + newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &types.UserInfo{ Email: email, Name: name, Role: req.Role, AutoGroups: req.AutoGroups, IsServiceUser: req.IsServiceUser, - Issued: server.UserIssuedAPI, + Issued: types.UserIssuedAPI, }) if err != nil { util.WriteError(r.Context(), err, w) @@ -175,9 +187,9 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId)) } -// GetAllUsers returns a list of users of the account this user belongs to. +// getAllUsers returns a list of users of the account this user belongs to. // It also gathers additional user data (like email and name) from the IDP manager. -func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) { +func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return @@ -222,9 +234,9 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, users) } -// InviteUser resend invitations to users who haven't activated their accounts, +// inviteUser resend invitations to users who haven't activated their accounts, // prior to the expiration period. 
-func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) { +func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return @@ -250,10 +262,10 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -func toUserResponse(user *server.UserInfo, currenUserID string) *api.User { +func toUserResponse(user *types.UserInfo, currenUserID string) *api.User { autoGroups := user.AutoGroups if autoGroups == nil { autoGroups = []string{} diff --git a/management/server/http/users_handler_test.go b/management/server/http/handlers/users/users_handler_test.go similarity index 92% rename from management/server/http/users_handler_test.go rename to management/server/http/handlers/users/users_handler_test.go index f3d989da1..90081830a 100644 --- a/management/server/http/users_handler_test.go +++ b/management/server/http/handlers/users/users_handler_test.go @@ -1,4 +1,4 @@ -package http +package users import ( "bytes" @@ -13,11 +13,11 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) const ( @@ -26,54 +26,54 @@ const ( regularUserID = "regularUserID" ) -var usersTestAccount = &server.Account{ +var usersTestAccount = &types.Account{ Id: existingAccountID, Domain: testDomain, - Users: map[string]*server.User{ + Users: map[string]*types.User{ existingUserID: { Id: existingUserID, Role: "admin", IsServiceUser: false, AutoGroups: []string{"group_1"}, - Issued: server.UserIssuedAPI, + Issued: types.UserIssuedAPI, }, regularUserID: { Id: regularUserID, Role: "user", IsServiceUser: false, AutoGroups: []string{"group_1"}, - Issued: server.UserIssuedAPI, + Issued: types.UserIssuedAPI, }, serviceUserID: { Id: serviceUserID, Role: "user", IsServiceUser: true, AutoGroups: []string{"group_1"}, - Issued: server.UserIssuedAPI, + Issued: types.UserIssuedAPI, }, nonDeletableServiceUserID: { Id: serviceUserID, Role: "admin", IsServiceUser: true, NonDeletable: true, - Issued: server.UserIssuedIntegration, + Issued: types.UserIssuedIntegration, }, }, } -func initUsersTestData() *UsersHandler { - return &UsersHandler{ +func initUsersTestData() *handler { + return &handler{ accountManager: &mock_server.MockAccountManager{ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return usersTestAccount.Id, claims.UserId, nil }, - GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) { + GetUserByIDFunc: func(ctx context.Context, id string) (*types.User, error) { return usersTestAccount.Users[id], nil }, - GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) { - users := make([]*server.UserInfo, 0) + GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*types.UserInfo, error) { + users := make([]*types.UserInfo, 0) for _, v := range usersTestAccount.Users { - users = append(users, &server.UserInfo{ + users = append(users, &types.UserInfo{ 
ID: v.Id, Role: string(v.Role), Name: "", @@ -85,7 +85,7 @@ func initUsersTestData() *UsersHandler { } return users, nil }, - CreateUserFunc: func(_ context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) { + CreateUserFunc: func(_ context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error) { if userID != existingUserID { return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID) } @@ -100,7 +100,7 @@ func initUsersTestData() *UsersHandler { } return nil }, - SaveUserFunc: func(_ context.Context, accountID, userID string, update *server.User) (*server.UserInfo, error) { + SaveUserFunc: func(_ context.Context, accountID, userID string, update *types.User) (*types.UserInfo, error) { if update.Id == notFoundUserID { return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", update.Id) } @@ -109,7 +109,7 @@ func initUsersTestData() *UsersHandler { return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID) } - info, err := update.Copy().ToUserInfo(nil, &server.Settings{RegularUsersViewBlocked: false}) + info, err := update.Copy().ToUserInfo(nil, &types.Settings{RegularUsersViewBlocked: false}) if err != nil { return nil, err } @@ -147,7 +147,7 @@ func TestGetUsers(t *testing.T) { requestPath string expectedUserIDs []string }{ - {name: "GetAllUsers", requestType: http.MethodGet, requestPath: "/api/users", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID, serviceUserID}}, + {name: "getAllUsers", requestType: http.MethodGet, requestPath: "/api/users", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID, serviceUserID}}, {name: "GetOnlyServiceUsers", requestType: http.MethodGet, requestPath: "/api/users?service_user=true", expectedStatus: http.StatusOK, expectedUserIDs: []string{serviceUserID}}, {name: "GetOnlyRegularUsers", requestType: http.MethodGet, requestPath: "/api/users?service_user=false", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID}}, } @@ -159,7 +159,7 @@ func TestGetUsers(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) - userHandler.GetAllUsers(recorder, req) + userHandler.getAllUsers(recorder, req) res := recorder.Result() defer res.Body.Close() @@ -175,7 +175,7 @@ func TestGetUsers(t *testing.T) { return } - respBody := []*server.UserInfo{} + respBody := []*types.UserInfo{} err = json.Unmarshal(content, &respBody) if err != nil { t.Fatalf("Sent content is not in correct json format; %v", err) @@ -265,7 +265,7 @@ func TestUpdateUser(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/users/{userId}", userHandler.UpdateUser).Methods("PUT") + router.HandleFunc("/api/users/{userId}", userHandler.updateUser).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -342,7 +342,7 @@ func TestCreateUser(t *testing.T) { requestType string requestPath string requestBody io.Reader - expectedResult []*server.User + expectedResult []*types.User }{ {name: "CreateServiceUser", requestType: http.MethodPost, requestPath: "/api/users", expectedStatus: http.StatusOK, requestBody: bytes.NewBuffer(serviceUserString)}, // right now creation is blocked in AC middleware, will be refactored in the future @@ -356,7 +356,7 @@ func TestCreateUser(t *testing.T) { req := 
httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) rr := httptest.NewRecorder() - userHandler.CreateUser(rr, req) + userHandler.createUser(rr, req) res := rr.Result() defer res.Body.Close() @@ -401,7 +401,7 @@ func TestInviteUser(t *testing.T) { req = mux.SetURLVars(req, tc.requestVars) rr := httptest.NewRecorder() - userHandler.InviteUser(rr, req) + userHandler.inviteUser(rr, req) res := rr.Result() defer res.Body.Close() @@ -454,7 +454,7 @@ func TestDeleteUser(t *testing.T) { req = mux.SetURLVars(req, tc.requestVars) rr := httptest.NewRecorder() - userHandler.DeleteUser(rr, req) + userHandler.deleteUser(rr, req) res := rr.Result() defer res.Body.Close() diff --git a/management/server/http/middleware/access_control.go b/management/server/http/middleware/access_control.go index 0ad250f43..c5bdf5fe7 100644 --- a/management/server/http/middleware/access_control.go +++ b/management/server/http/middleware/access_control.go @@ -7,16 +7,16 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/jwtclaims" ) // GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims -type GetUser func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) +type GetUser func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) // AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only type AccessControl struct { diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 006a7872c..dcf73259a 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -11,16 +11,16 @@ import ( "github.com/golang-jwt/jwt" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/server" nbContext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) // GetAccountInfoFromPATFunc function -type GetAccountInfoFromPATFunc func(ctx context.Context, token string) (user *server.User, pat *server.PersonalAccessToken, domain string, category string, err error) +type GetAccountInfoFromPATFunc func(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) // ValidateAndParseTokenFunc function type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error) @@ -159,7 +159,7 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ if err != nil { return fmt.Errorf("invalid Token: %w", err) } - if time.Now().After(pat.ExpirationDate) { + if time.Now().After(pat.GetExpirationDate()) { return fmt.Errorf("token expired") } @@ -173,6 +173,7 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ claimMaps[m.audience+jwtclaims.AccountIDSuffix] = user.AccountID 
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = accDomain claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = accCategory + claimMaps[jwtclaims.IsToken] = true jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint // Update the current request with the new context information. diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index 0e0872d31..7297e6ced 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -9,10 +9,11 @@ import ( "time" "github.com/golang-jwt/jwt" + "github.com/netbirdio/netbird/management/server/util" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/types" ) const ( @@ -28,10 +29,10 @@ const ( wrongToken = "wrongToken" ) -var testAccount = &server.Account{ +var testAccount = &types.Account{ Id: accountID, Domain: domain, - Users: map[string]*server.User{ + Users: map[string]*types.User{ userID: { Id: userID, AccountID: accountID, @@ -40,17 +41,17 @@ var testAccount = &server.Account{ ID: tokenID, Name: "My first token", HashedToken: "someHash", - ExpirationDate: time.Now().UTC().AddDate(0, 0, 7), + ExpirationDate: util.ToPtr(time.Now().UTC().AddDate(0, 0, 7)), CreatedBy: userID, CreatedAt: time.Now().UTC(), - LastUsed: time.Now().UTC(), + LastUsed: util.ToPtr(time.Now().UTC()), }, }, }, }, } -func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *server.User, pat *server.PersonalAccessToken, domain string, category string, err error) { +func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) { if token == PAT { return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], testAccount.Domain, testAccount.DomainCategory, nil } diff --git a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go new file mode 100644 index 000000000..a4098f5d4 --- /dev/null +++ b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go @@ -0,0 +1,178 @@ +//go:build benchmark +// +build benchmark + +package benchmarks + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "strconv" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" +) + +// Map to store peers, groups, users, and setupKeys by name +var benchCasesPeers = map[string]testing_tools.BenchmarkCase{ + "Peers - XS": {Peers: 5, Groups: 10000, Users: 10000, SetupKeys: 10000}, + "Peers - S": {Peers: 100, Groups: 5, Users: 5, SetupKeys: 5}, + "Peers - M": {Peers: 1000, Groups: 20, Users: 20, SetupKeys: 100}, + "Peers - L": {Peers: 5000, Groups: 5, Users: 5, SetupKeys: 5}, + "Groups - L": {Peers: 5000, Groups: 10000, Users: 5, SetupKeys: 5}, + "Users - L": {Peers: 5000, Groups: 5, Users: 10000, SetupKeys: 5}, + "Setup Keys - L": {Peers: 5000, Groups: 5, Users: 5, SetupKeys: 10000}, + 
"Peers - XL": {Peers: 25000, Groups: 50, Users: 100, SetupKeys: 500}, +} + +func BenchmarkUpdatePeer(b *testing.B) { + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Peers - XS": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 600, MaxMsPerOpCICD: 3500}, + "Peers - S": {MinMsPerOpLocal: 100, MaxMsPerOpLocal: 130, MinMsPerOpCICD: 80, MaxMsPerOpCICD: 200}, + "Peers - M": {MinMsPerOpLocal: 130, MaxMsPerOpLocal: 150, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300}, + "Peers - L": {MinMsPerOpLocal: 230, MaxMsPerOpLocal: 270, MinMsPerOpCICD: 200, MaxMsPerOpCICD: 500}, + "Groups - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 650, MaxMsPerOpCICD: 3500}, + "Users - L": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 400, MinMsPerOpCICD: 300, MaxMsPerOpCICD: 600}, + "Setup Keys - L": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 400, MinMsPerOpCICD: 300, MaxMsPerOpCICD: 600}, + "Peers - XL": {MinMsPerOpLocal: 600, MaxMsPerOpLocal: 1000, MinMsPerOpCICD: 600, MaxMsPerOpCICD: 2000}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesPeers { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + requestBody := api.PeerRequest{ + Name: "peer" + strconv.Itoa(i), + } + + // the time marshal will be recorded as well but for our use case that is ok + body, err := json.Marshal(requestBody) + assert.NoError(b, err) + + req := testing_tools.BuildRequest(b, body, http.MethodPut, "/api/peers/"+testing_tools.TestPeerId, testing_tools.TestAdminId) + apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} + +func BenchmarkGetOnePeer(b *testing.B) { + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Peers - XS": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 60, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 70}, + "Peers - S": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 30}, + "Peers - M": {MinMsPerOpLocal: 9, MaxMsPerOpLocal: 18, MinMsPerOpCICD: 15, MaxMsPerOpCICD: 50}, + "Peers - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130}, + "Groups - L": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 130, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 200}, + "Users - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130}, + "Setup Keys - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130}, + "Peers - XL": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 400, MinMsPerOpCICD: 200, MaxMsPerOpCICD: 750}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesPeers { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/peers/"+testing_tools.TestPeerId, testing_tools.TestAdminId) + 
apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} + +func BenchmarkGetAllPeers(b *testing.B) { + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Peers - XS": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 70, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 150}, + "Peers - S": {MinMsPerOpLocal: 2, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 30}, + "Peers - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 70}, + "Peers - L": {MinMsPerOpLocal: 130, MaxMsPerOpLocal: 170, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300}, + "Groups - L": {MinMsPerOpLocal: 4800, MaxMsPerOpLocal: 5300, MinMsPerOpCICD: 5000, MaxMsPerOpCICD: 8000}, + "Users - L": {MinMsPerOpLocal: 100, MaxMsPerOpLocal: 170, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 400}, + "Setup Keys - L": {MinMsPerOpLocal: 100, MaxMsPerOpLocal: 170, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 400}, + "Peers - XL": {MinMsPerOpLocal: 900, MaxMsPerOpLocal: 1300, MinMsPerOpCICD: 800, MaxMsPerOpCICD: 2300}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesPeers { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/peers", testing_tools.TestAdminId) + apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} + +func BenchmarkDeletePeer(b *testing.B) { + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Peers - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 4, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, + "Peers - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 4, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, + "Peers - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 4, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, + "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 4, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, + "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 4, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, + "Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 4, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, + "Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 4, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, + "Peers - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 4, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesPeers { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), 1000, bc.Groups, bc.Users, bc.SetupKeys) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + req := testing_tools.BuildRequest(b, nil, http.MethodDelete, "/api/peers/"+"oldpeer-"+strconv.Itoa(i), testing_tools.TestAdminId) + apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} diff --git 
a/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go b/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go new file mode 100644 index 000000000..ed643f75e --- /dev/null +++ b/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go @@ -0,0 +1,229 @@ +//go:build benchmark +// +build benchmark + +package benchmarks + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "strconv" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" +) + +// Map to store peers, groups, users, and setupKeys by name +var benchCasesSetupKeys = map[string]testing_tools.BenchmarkCase{ + "Setup Keys - XS": {Peers: 10000, Groups: 10000, Users: 10000, SetupKeys: 5}, + "Setup Keys - S": {Peers: 5, Groups: 5, Users: 5, SetupKeys: 100}, + "Setup Keys - M": {Peers: 100, Groups: 20, Users: 20, SetupKeys: 1000}, + "Setup Keys - L": {Peers: 5, Groups: 5, Users: 5, SetupKeys: 5000}, + "Peers - L": {Peers: 10000, Groups: 5, Users: 5, SetupKeys: 5000}, + "Groups - L": {Peers: 5, Groups: 10000, Users: 5, SetupKeys: 5000}, + "Users - L": {Peers: 5, Groups: 5, Users: 10000, SetupKeys: 5000}, + "Setup Keys - XL": {Peers: 500, Groups: 50, Users: 100, SetupKeys: 25000}, +} + +func BenchmarkCreateSetupKey(b *testing.B) { + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Setup Keys - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 17}, + "Setup Keys - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 17}, + "Setup Keys - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 17}, + "Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 17}, + "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 17}, + "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 17}, + "Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 17}, + "Setup Keys - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 17}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesSetupKeys { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + requestBody := api.CreateSetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId}, + ExpiresIn: testing_tools.ExpiresIn, + Name: testing_tools.NewKeyName + strconv.Itoa(i), + Type: "reusable", + UsageLimit: 0, + } + + // the time marshal will be recorded as well but for our use case that is ok + body, err := json.Marshal(requestBody) + assert.NoError(b, err) + + req := testing_tools.BuildRequest(b, body, http.MethodPost, "/api/setup-keys", testing_tools.TestAdminId) + apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} + +func 
BenchmarkUpdateSetupKey(b *testing.B) { + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Setup Keys - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 19}, + "Setup Keys - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 19}, + "Setup Keys - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 19}, + "Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 19}, + "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 19}, + "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 19}, + "Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 19}, + "Setup Keys - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 19}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesSetupKeys { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + groupId := testing_tools.TestGroupId + if i%2 == 0 { + groupId = testing_tools.NewGroupId + } + requestBody := api.SetupKeyRequest{ + AutoGroups: []string{groupId}, + Revoked: false, + } + + // the time marshal will be recorded as well but for our use case that is ok + body, err := json.Marshal(requestBody) + assert.NoError(b, err) + + req := testing_tools.BuildRequest(b, body, http.MethodPut, "/api/setup-keys/"+testing_tools.TestKeyId, testing_tools.TestAdminId) + apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} + +func BenchmarkGetOneSetupKey(b *testing.B) { + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Setup Keys - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 16}, + "Setup Keys - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 16}, + "Setup Keys - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 16}, + "Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 16}, + "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 16}, + "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 16}, + "Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 16}, + "Setup Keys - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 16}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesSetupKeys { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/setup-keys/"+testing_tools.TestKeyId, testing_tools.TestAdminId) + 
apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} + +func BenchmarkGetAllSetupKeys(b *testing.B) { + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 12}, + "Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 15}, + "Setup Keys - M": {MinMsPerOpLocal: 5, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 40}, + "Setup Keys - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 150}, + "Peers - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 150}, + "Groups - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 150}, + "Users - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 150}, + "Setup Keys - XL": {MinMsPerOpLocal: 140, MaxMsPerOpLocal: 220, MinMsPerOpCICD: 150, MaxMsPerOpCICD: 500}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesSetupKeys { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/setup-keys", testing_tools.TestAdminId) + apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} + +func BenchmarkDeleteSetupKey(b *testing.B) { + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Setup Keys - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesSetupKeys { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, 1000) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + req := testing_tools.BuildRequest(b, nil, http.MethodDelete, "/api/setup-keys/"+"oldkey-"+strconv.Itoa(i), testing_tools.TestAdminId) + apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} diff --git 
a/management/server/http/testing/benchmarks/users_handler_benchmark_test.go b/management/server/http/testing/benchmarks/users_handler_benchmark_test.go new file mode 100644 index 000000000..549a51c0e --- /dev/null +++ b/management/server/http/testing/benchmarks/users_handler_benchmark_test.go @@ -0,0 +1,185 @@ +//go:build benchmark +// +build benchmark + +package benchmarks + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "strconv" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" +) + +// Map to store peers, groups, users, and setupKeys by name +var benchCasesUsers = map[string]testing_tools.BenchmarkCase{ + "Users - XS": {Peers: 10000, Groups: 10000, Users: 5, SetupKeys: 10000}, + "Users - S": {Peers: 5, Groups: 5, Users: 10, SetupKeys: 5}, + "Users - M": {Peers: 100, Groups: 20, Users: 1000, SetupKeys: 1000}, + "Users - L": {Peers: 5, Groups: 5, Users: 5000, SetupKeys: 5}, + "Peers - L": {Peers: 10000, Groups: 5, Users: 5000, SetupKeys: 5}, + "Groups - L": {Peers: 5, Groups: 10000, Users: 5000, SetupKeys: 5}, + "Setup Keys - L": {Peers: 5, Groups: 5, Users: 5000, SetupKeys: 10000}, + "Users - XL": {Peers: 500, Groups: 50, Users: 25000, SetupKeys: 3000}, +} + +func BenchmarkUpdateUser(b *testing.B) { + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Users - XS": {MinMsPerOpLocal: 700, MaxMsPerOpLocal: 1000, MinMsPerOpCICD: 1300, MaxMsPerOpCICD: 8000}, + "Users - S": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 4, MaxMsPerOpCICD: 50}, + "Users - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 250}, + "Users - L": {MinMsPerOpLocal: 60, MaxMsPerOpLocal: 100, MinMsPerOpCICD: 90, MaxMsPerOpCICD: 700}, + "Peers - L": {MinMsPerOpLocal: 300, MaxMsPerOpLocal: 500, MinMsPerOpCICD: 550, MaxMsPerOpCICD: 2400}, + "Groups - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 750, MaxMsPerOpCICD: 5000}, + "Setup Keys - L": {MinMsPerOpLocal: 50, MaxMsPerOpLocal: 200, MinMsPerOpCICD: 130, MaxMsPerOpCICD: 1000}, + "Users - XL": {MinMsPerOpLocal: 350, MaxMsPerOpLocal: 550, MinMsPerOpCICD: 650, MaxMsPerOpCICD: 3500}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesUsers { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + groupId := testing_tools.TestGroupId + if i%2 == 0 { + groupId = testing_tools.NewGroupId + } + requestBody := api.UserRequest{ + AutoGroups: []string{groupId}, + IsBlocked: false, + Role: "admin", + } + + // the time marshal will be recorded as well but for our use case that is ok + body, err := json.Marshal(requestBody) + assert.NoError(b, err) + + req := testing_tools.BuildRequest(b, body, http.MethodPut, "/api/users/"+testing_tools.TestUserId, testing_tools.TestAdminId) + apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} + +func BenchmarkGetOneUser(b *testing.B) { + 
b.Skip("Skipping benchmark as endpoint is missing") + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Users - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 12}, + "Users - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 12}, + "Users - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 12}, + "Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 12}, + "Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 12}, + "Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 12}, + "Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 12}, + "Users - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 12}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesUsers { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users/"+testing_tools.TestUserId, testing_tools.TestAdminId) + apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} + +func BenchmarkGetAllUsers(b *testing.B) { + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Users - XS": {MinMsPerOpLocal: 50, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 60, MaxMsPerOpCICD: 180}, + "Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30}, + "Users - M": {MinMsPerOpLocal: 5, MaxMsPerOpLocal: 12, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30}, + "Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30}, + "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30}, + "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30}, + "Setup Keys - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 140, MinMsPerOpCICD: 60, MaxMsPerOpCICD: 200}, + "Users - XL": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 90}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesUsers { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/setup-keys", testing_tools.TestAdminId) + apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} + +func BenchmarkDeleteUsers(b *testing.B) { + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Users - XS": {MinMsPerOpLocal: 1000, MaxMsPerOpLocal: 1600, MinMsPerOpCICD: 1900, MaxMsPerOpCICD: 11000}, + "Users - 
S": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 200}, + "Users - M": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 70, MinMsPerOpCICD: 15, MaxMsPerOpCICD: 230}, + "Users - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 45, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 190}, + "Peers - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 650, MaxMsPerOpCICD: 1800}, + "Groups - L": {MinMsPerOpLocal: 600, MaxMsPerOpLocal: 800, MinMsPerOpCICD: 1200, MaxMsPerOpCICD: 7500}, + "Setup Keys - L": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 200, MinMsPerOpCICD: 40, MaxMsPerOpCICD: 600}, + "Users - XL": {MinMsPerOpLocal: 50, MaxMsPerOpLocal: 150, MinMsPerOpCICD: 80, MaxMsPerOpCICD: 400}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesUsers { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, 1000, bc.SetupKeys) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + req := testing_tools.BuildRequest(b, nil, http.MethodDelete, "/api/users/"+"olduser-"+strconv.Itoa(i), testing_tools.TestAdminId) + apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} diff --git a/management/server/http/testing/integration/setupkeys_handler_integration_test.go b/management/server/http/testing/integration/setupkeys_handler_integration_test.go new file mode 100644 index 000000000..ed6e642a2 --- /dev/null +++ b/management/server/http/testing/integration/setupkeys_handler_integration_test.go @@ -0,0 +1,1149 @@ +//go:build integration +// +build integration + +package integration + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "sort" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/handlers/setup_keys" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" +) + +func Test_SetupKeys_Create(t *testing.T) { + truePointer := true + + users := []struct { + name string + userId string + expectResponse bool + }{ + { + name: "Regular user", + userId: testing_tools.TestUserId, + expectResponse: false, + }, + { + name: "Admin user", + userId: testing_tools.TestAdminId, + expectResponse: true, + }, + { + name: "Owner user", + userId: testing_tools.TestOwnerId, + expectResponse: true, + }, + { + name: "Regular service user", + userId: testing_tools.TestServiceUserId, + expectResponse: false, + }, + { + name: "Admin service user", + userId: testing_tools.TestServiceAdminId, + expectResponse: true, + }, + { + name: "Blocked user", + userId: testing_tools.BlockedUserId, + expectResponse: false, + }, + { + name: "Other user", + userId: testing_tools.OtherUserId, + expectResponse: false, + }, + { + name: "Invalid token", + userId: testing_tools.InvalidToken, + expectResponse: false, + }, + } + + tt := []struct { + name string + expectedStatus int + expectedResponse *api.SetupKey + requestBody *api.CreateSetupKeyRequest + requestType string + requestPath string + userId string + }{ + { + name: "Create Setup Key", + requestType: http.MethodPost, + requestPath: "/api/setup-keys", + requestBody: &api.CreateSetupKeyRequest{ + AutoGroups: nil, + 
ExpiresIn: testing_tools.ExpiresIn, + Name: testing_tools.NewKeyName, + Type: "reusable", + UsageLimit: 0, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.NewKeyName, + Revoked: false, + State: "valid", + Type: "reusable", + UpdatedAt: time.Now(), + UsageLimit: 0, + UsedTimes: 0, + Valid: true, + }, + }, + { + name: "Create Setup Key with already existing name", + requestType: http.MethodPost, + requestPath: "/api/setup-keys", + requestBody: &api.CreateSetupKeyRequest{ + AutoGroups: nil, + ExpiresIn: testing_tools.ExpiresIn, + Name: testing_tools.ExistingKeyName, + Type: "one-off", + UsageLimit: 0, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "valid", + Type: "one-off", + UpdatedAt: time.Now(), + UsageLimit: 1, + UsedTimes: 0, + Valid: true, + }, + }, + { + name: "Create Setup Key as on-off with more than one usage", + requestType: http.MethodPost, + requestPath: "/api/setup-keys", + requestBody: &api.CreateSetupKeyRequest{ + AutoGroups: nil, + ExpiresIn: testing_tools.ExpiresIn, + Name: testing_tools.NewKeyName, + Type: "one-off", + UsageLimit: 3, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.NewKeyName, + Revoked: false, + State: "valid", + Type: "one-off", + UpdatedAt: time.Now(), + UsageLimit: 1, + UsedTimes: 0, + Valid: true, + }, + }, + { + name: "Create Setup Key with expiration in the past", + requestType: http.MethodPost, + requestPath: "/api/setup-keys", + requestBody: &api.CreateSetupKeyRequest{ + AutoGroups: nil, + ExpiresIn: -testing_tools.ExpiresIn, + Name: testing_tools.NewKeyName, + Type: "one-off", + UsageLimit: 0, + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedResponse: nil, + }, + { + name: "Create Setup Key with AutoGroups that do exist", + requestType: http.MethodPost, + requestPath: "/api/setup-keys", + requestBody: &api.CreateSetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId}, + ExpiresIn: testing_tools.ExpiresIn, + Name: testing_tools.NewKeyName, + Type: "reusable", + UsageLimit: 1, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.NewKeyName, + Revoked: false, + State: "valid", + Type: "reusable", + UpdatedAt: time.Now(), + UsageLimit: 1, + UsedTimes: 0, + Valid: true, + }, + }, + { + name: "Create Setup Key for ephemeral peers", + requestType: http.MethodPost, + requestPath: "/api/setup-keys", + requestBody: &api.CreateSetupKeyRequest{ + AutoGroups: []string{}, + ExpiresIn: testing_tools.ExpiresIn, + Name: testing_tools.NewKeyName, + Type: "reusable", + Ephemeral: &truePointer, + UsageLimit: 1, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{}, + Ephemeral: true, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.NewKeyName, + Revoked: false, + State: "valid", + Type: "reusable", + UpdatedAt: time.Now(), + UsageLimit: 1, + UsedTimes: 
0, + Valid: true, + }, + }, + { + name: "Create Setup Key with AutoGroups that do not exist", + requestType: http.MethodPost, + requestPath: "/api/setup-keys", + requestBody: &api.CreateSetupKeyRequest{ + AutoGroups: []string{"someGroupID"}, + ExpiresIn: testing_tools.ExpiresIn, + Name: testing_tools.NewKeyName, + Type: "reusable", + UsageLimit: 0, + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedResponse: nil, + }, + { + name: "Create Setup Key", + requestType: http.MethodPost, + requestPath: "/api/setup-keys", + requestBody: &api.CreateSetupKeyRequest{ + AutoGroups: nil, + ExpiresIn: testing_tools.ExpiresIn, + Name: testing_tools.NewKeyName, + Type: "reusable", + UsageLimit: 0, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.NewKeyName, + Revoked: false, + State: "valid", + Type: "reusable", + UpdatedAt: time.Now(), + UsageLimit: 0, + UsedTimes: 0, + Valid: true, + }, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + req := testing_tools.BuildRequest(t, body, tc.requestType, tc.requestPath, user.userId) + + recorder := httptest.NewRecorder() + + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + got := &api.SetupKey{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + validateCreatedKey(t, tc.expectedResponse, got) + + key, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got.Id) + if err != nil { + return + } + + validateCreatedKey(t, tc.expectedResponse, setup_keys.ToResponseBody(key)) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_SetupKeys_Update(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + { + name: "Regular user", + userId: testing_tools.TestUserId, + expectResponse: false, + }, + { + name: "Admin user", + userId: testing_tools.TestAdminId, + expectResponse: true, + }, + { + name: "Owner user", + userId: testing_tools.TestOwnerId, + expectResponse: true, + }, + { + name: "Regular service user", + userId: testing_tools.TestServiceUserId, + expectResponse: false, + }, + { + name: "Admin service user", + userId: testing_tools.TestServiceAdminId, + expectResponse: true, + }, + { + name: "Blocked user", + userId: testing_tools.BlockedUserId, + expectResponse: false, + }, + { + name: "Other user", + userId: testing_tools.OtherUserId, + expectResponse: false, + }, + { + name: "Invalid token", + userId: testing_tools.InvalidToken, + expectResponse: false, + }, + } + + tt := []struct { + name string + expectedStatus int + expectedResponse *api.SetupKey + requestBody *api.SetupKeyRequest + requestType string + requestPath string + requestId string + }{ + { + name: "Add existing Group to existing Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: 
testing_tools.TestKeyId, + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId, testing_tools.NewGroupId}, + Revoked: false, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId, testing_tools.NewGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "valid", + Type: "one-off", + UpdatedAt: time.Now(), + UsageLimit: 1, + UsedTimes: 0, + Valid: true, + }, + }, + { + name: "Add non-existing Group to existing Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.TestKeyId, + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId, "someGroupId"}, + Revoked: false, + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedResponse: nil, + }, + { + name: "Add existing Group to non-existing Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: "someId", + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId, testing_tools.NewGroupId}, + Revoked: false, + }, + expectedStatus: http.StatusNotFound, + expectedResponse: nil, + }, + { + name: "Remove existing Group from existing Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.TestKeyId, + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{}, + Revoked: false, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "valid", + Type: "one-off", + UpdatedAt: time.Now(), + UsageLimit: 1, + UsedTimes: 0, + Valid: true, + }, + }, + { + name: "Remove existing Group to non-existing Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: "someID", + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{}, + Revoked: false, + }, + expectedStatus: http.StatusNotFound, + expectedResponse: nil, + }, + { + name: "Revoke existing valid Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.TestKeyId, + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId}, + Revoked: true, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: true, + State: "revoked", + Type: "one-off", + UpdatedAt: time.Now(), + UsageLimit: 1, + UsedTimes: 0, + Valid: false, + }, + }, + { + name: "Revoke existing revoked Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.RevokedKeyId, + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId}, + Revoked: true, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: true, + State: "revoked", + Type: "reusable", + UpdatedAt: time.Now(), + UsageLimit: 3, + UsedTimes: 0, + Valid: false, + }, + }, + { + 
name: "Un-Revoke existing revoked Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.RevokedKeyId, + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId}, + Revoked: false, + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedResponse: nil, + }, + { + name: "Revoke existing expired Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.ExpiredKeyId, + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId}, + Revoked: true, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: true, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: true, + State: "expired", + Type: "reusable", + UpdatedAt: time.Now(), + UsageLimit: 5, + UsedTimes: 1, + Valid: false, + }, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(tc.name, func(t *testing.T) { + apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, tc.requestType, strings.Replace(tc.requestPath, "{id}", tc.requestId, 1), user.userId) + + recorder := httptest.NewRecorder() + + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + got := &api.SetupKey{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + validateCreatedKey(t, tc.expectedResponse, got) + + key, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got.Id) + if err != nil { + return + } + + validateCreatedKey(t, tc.expectedResponse, setup_keys.ToResponseBody(key)) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_SetupKeys_Get(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + { + name: "Regular user", + userId: testing_tools.TestUserId, + expectResponse: false, + }, + { + name: "Admin user", + userId: testing_tools.TestAdminId, + expectResponse: true, + }, + { + name: "Owner user", + userId: testing_tools.TestOwnerId, + expectResponse: true, + }, + { + name: "Regular service user", + userId: testing_tools.TestServiceUserId, + expectResponse: false, + }, + { + name: "Admin service user", + userId: testing_tools.TestServiceAdminId, + expectResponse: true, + }, + { + name: "Blocked user", + userId: testing_tools.BlockedUserId, + expectResponse: false, + }, + { + name: "Other user", + userId: testing_tools.OtherUserId, + expectResponse: false, + }, + { + name: "Invalid token", + userId: testing_tools.InvalidToken, + expectResponse: false, + }, + } + + tt := []struct { + name string + expectedStatus int + expectedResponse *api.SetupKey + requestType string + requestPath string + requestId string + }{ + { + name: "Get existing valid Setup Key", + requestType: http.MethodGet, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.TestKeyId, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ 
+ AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "valid", + Type: "one-off", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 1, + UsedTimes: 0, + Valid: true, + }, + }, + { + name: "Get existing expired Setup Key", + requestType: http.MethodGet, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.ExpiredKeyId, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: true, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "expired", + Type: "reusable", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 5, + UsedTimes: 1, + Valid: false, + }, + }, + { + name: "Get existing revoked Setup Key", + requestType: http.MethodGet, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.RevokedKeyId, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: true, + State: "revoked", + Type: "reusable", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 3, + UsedTimes: 0, + Valid: false, + }, + }, + { + name: "Get non-existing Setup Key", + requestType: http.MethodGet, + requestPath: "/api/setup-keys/{id}", + requestId: "someId", + expectedStatus: http.StatusNotFound, + expectedResponse: nil, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(tc.name, func(t *testing.T) { + apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{id}", tc.requestId, 1), user.userId) + + recorder := httptest.NewRecorder() + + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + got := &api.SetupKey{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + validateCreatedKey(t, tc.expectedResponse, got) + + key, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got.Id) + if err != nil { + return + } + + validateCreatedKey(t, tc.expectedResponse, setup_keys.ToResponseBody(key)) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_SetupKeys_GetAll(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + { + name: "Regular user", + userId: testing_tools.TestUserId, + expectResponse: false, + }, + { + name: "Admin user", + userId: testing_tools.TestAdminId, + expectResponse: true, + }, + { + name: "Owner user", + userId: testing_tools.TestOwnerId, + expectResponse: true, + }, + { + name: "Regular service user", + userId: testing_tools.TestServiceUserId, + expectResponse: false, + }, + { + name: "Admin service user", + userId: testing_tools.TestServiceAdminId, + expectResponse: true, + }, + { + name: "Blocked user", + 
userId: testing_tools.BlockedUserId, + expectResponse: false, + }, + { + name: "Other user", + userId: testing_tools.OtherUserId, + expectResponse: false, + }, + { + name: "Invalid token", + userId: testing_tools.InvalidToken, + expectResponse: false, + }, + } + + tt := []struct { + name string + expectedStatus int + expectedResponse []*api.SetupKey + requestType string + requestPath string + }{ + { + name: "Get all Setup Keys", + requestType: http.MethodGet, + requestPath: "/api/setup-keys", + expectedStatus: http.StatusOK, + expectedResponse: []*api.SetupKey{ + { + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "valid", + Type: "one-off", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 1, + UsedTimes: 0, + Valid: true, + }, + { + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: true, + State: "revoked", + Type: "reusable", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 3, + UsedTimes: 0, + Valid: false, + }, + { + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: true, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "expired", + Type: "reusable", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 5, + UsedTimes: 1, + Valid: false, + }, + }, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(tc.name, func(t *testing.T) { + apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, tc.requestPath, user.userId) + + recorder := httptest.NewRecorder() + + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + got := []api.SetupKey{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + sort.Slice(got, func(i, j int) bool { + return got[i].UsageLimit < got[j].UsageLimit + }) + + sort.Slice(tc.expectedResponse, func(i, j int) bool { + return tc.expectedResponse[i].UsageLimit < tc.expectedResponse[j].UsageLimit + }) + + for i := range tc.expectedResponse { + validateCreatedKey(t, tc.expectedResponse[i], &got[i]) + + key, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got[i].Id) + if err != nil { + return + } + + validateCreatedKey(t, tc.expectedResponse[i], setup_keys.ToResponseBody(key)) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_SetupKeys_Delete(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + { + name: "Regular user", + userId: testing_tools.TestUserId, + expectResponse: false, + }, + { + name: "Admin user", + userId: testing_tools.TestAdminId, + expectResponse: true, + }, + { + name: "Owner user", + userId: testing_tools.TestOwnerId, + expectResponse: true, + }, + { + name: "Regular service user", + userId: 
testing_tools.TestServiceUserId, + expectResponse: false, + }, + { + name: "Admin service user", + userId: testing_tools.TestServiceAdminId, + expectResponse: true, + }, + { + name: "Blocked user", + userId: testing_tools.BlockedUserId, + expectResponse: false, + }, + { + name: "Other user", + userId: testing_tools.OtherUserId, + expectResponse: false, + }, + { + name: "Invalid token", + userId: testing_tools.InvalidToken, + expectResponse: false, + }, + } + + tt := []struct { + name string + expectedStatus int + expectedResponse *api.SetupKey + requestType string + requestPath string + requestId string + }{ + { + name: "Delete existing valid Setup Key", + requestType: http.MethodDelete, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.TestKeyId, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "valid", + Type: "one-off", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 1, + UsedTimes: 0, + Valid: true, + }, + }, + { + name: "Delete existing expired Setup Key", + requestType: http.MethodDelete, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.ExpiredKeyId, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: true, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "expired", + Type: "reusable", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 5, + UsedTimes: 1, + Valid: false, + }, + }, + { + name: "Delete existing revoked Setup Key", + requestType: http.MethodDelete, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.RevokedKeyId, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: true, + State: "revoked", + Type: "reusable", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 3, + UsedTimes: 0, + Valid: false, + }, + }, + { + name: "Delete non-existing Setup Key", + requestType: http.MethodDelete, + requestPath: "/api/setup-keys/{id}", + requestId: "someId", + expectedStatus: http.StatusNotFound, + expectedResponse: nil, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(tc.name, func(t *testing.T) { + apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{id}", tc.requestId, 1), user.userId) + + recorder := httptest.NewRecorder() + + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + got := &api.SetupKey{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + _, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got.Id) + assert.Errorf(t, err, "Expected error when trying to get deleted key") + + select { + case 
<-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func validateCreatedKey(t *testing.T, expectedKey *api.SetupKey, got *api.SetupKey) { + t.Helper() + + if got.Expires.After(time.Now().Add(-1*time.Minute)) && got.Expires.Before(time.Now().Add(testing_tools.ExpiresIn*time.Second)) || + got.Expires.After(time.Date(2300, 01, 01, 0, 0, 0, 0, time.Local)) || + got.Expires.Before(time.Date(1950, 01, 01, 0, 0, 0, 0, time.Local)) { + got.Expires = time.Time{} + expectedKey.Expires = time.Time{} + } + + if got.Id == "" { + t.Fatalf("Expected key to have an ID") + } + got.Id = "" + + if got.Key == "" { + t.Fatalf("Expected key to have a key") + } + got.Key = "" + + if got.UpdatedAt.After(time.Now().Add(-1*time.Minute)) && got.UpdatedAt.Before(time.Now().Add(+1*time.Minute)) { + got.UpdatedAt = time.Time{} + expectedKey.UpdatedAt = time.Time{} + } + + expectedKey.UpdatedAt = expectedKey.UpdatedAt.In(time.UTC) + got.UpdatedAt = got.UpdatedAt.In(time.UTC) + + assert.Equal(t, expectedKey, got) +} diff --git a/management/server/http/testing/testdata/peers.sql b/management/server/http/testing/testdata/peers.sql new file mode 100644 index 000000000..863eda520 --- /dev/null +++ b/management/server/http/testing/testdata/peers.sql @@ -0,0 +1,22 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); + +INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users 
VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','[]',0,''); +INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO setup_keys VALUES('revokedKeyId','testAccountId','revokedKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',1,0,NULL,'["testGroupId"]',3,0); +INSERT INTO setup_keys VALUES('expiredKeyId','testAccountId','expiredKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','1921-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,1,NULL,'["testGroupId"]',5,1); + +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); + +INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,0,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); diff --git a/management/server/http/testing/testdata/setup_keys.sql b/management/server/http/testing/testdata/setup_keys.sql new file mode 100644 index 000000000..6d30fb5fe --- /dev/null +++ 
b/management/server/http/testing/testdata/setup_keys.sql @@ -0,0 +1,24 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime DEFAULT NULL,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); + +INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 
16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','[]',0,''); +INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); + + +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime DEFAULT NULL,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); + +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO setup_keys VALUES('revokedKeyId','testAccountId','revokedKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',1,0,NULL,'["testGroupId"]',3,0); +INSERT INTO setup_keys VALUES('expiredKeyId','testAccountId','expiredKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','1921-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,1,NULL,'["testGroupId"]',5,1); + diff --git a/management/server/http/testing/testdata/users.sql b/management/server/http/testing/testdata/users.sql new file mode 100644 index 000000000..346f7b7ac --- /dev/null +++ b/management/server/http/testing/testdata/users.sql @@ -0,0 +1,23 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` 
integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); + +INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','[]',0,''); +INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO setup_keys VALUES('revokedKeyId','testAccountId','revokedKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',1,0,NULL,'["testGroupId"]',3,0); +INSERT INTO setup_keys VALUES('expiredKeyId','testAccountId','expiredKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','1921-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,1,NULL,'["testGroupId"]',5,1); +INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); + + +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); + +INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users 
VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); diff --git a/management/server/http/testing/testing_tools/tools.go b/management/server/http/testing/testing_tools/tools.go new file mode 100644 index 000000000..006d5679c --- /dev/null +++ b/management/server/http/testing/testing_tools/tools.go @@ -0,0 +1,311 @@ +package testing_tools + +import ( + "bytes" + "context" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "os" + "strconv" + "testing" + "time" + + "github.com/netbirdio/netbird/management/server/util" + "github.com/stretchr/testify/assert" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/geolocation" + "github.com/netbirdio/netbird/management/server/groups" + nbhttp "github.com/netbirdio/netbird/management/server/http" + "github.com/netbirdio/netbird/management/server/http/configs" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/networks" + "github.com/netbirdio/netbird/management/server/networks/resources" + "github.com/netbirdio/netbird/management/server/networks/routers" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" +) + +const ( + TestAccountId = "testAccountId" + TestPeerId = "testPeerId" + TestGroupId = "testGroupId" + TestKeyId = "testKeyId" + + TestUserId = "testUserId" + TestAdminId = "testAdminId" + TestOwnerId = "testOwnerId" + TestServiceUserId = "testServiceUserId" + TestServiceAdminId = "testServiceAdminId" + BlockedUserId = "blockedUserId" + OtherUserId = "otherUserId" + InvalidToken = "invalidToken" + + NewKeyName = "newKey" + NewGroupId = "newGroupId" + ExpiresIn = 3600 + RevokedKeyId = "revokedKeyId" + ExpiredKeyId = "expiredKeyId" + + ExistingKeyName = "existingKey" +) + +type TB interface { + Cleanup(func()) + Helper() + Errorf(format string, args ...any) + Fatalf(format string, args ...any) + TempDir() string +} + +// BenchmarkCase defines a single benchmark test case +type BenchmarkCase struct { + Peers int + Groups int + Users int + SetupKeys int +} + +// PerformanceMetrics holds the performance expectations +type PerformanceMetrics struct { + MinMsPerOpLocal float64 + MaxMsPerOpLocal float64 + MinMsPerOpCICD float64 + MaxMsPerOpCICD float64 +} + +func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage, validateUpdate bool) (http.Handler, server.AccountManager, chan struct{}) { + store, cleanup, err := 
store.NewTestStoreFromSQL(context.Background(), sqlFile, t.TempDir()) + if err != nil { + t.Fatalf("Failed to create test store: %v", err) + } + t.Cleanup(cleanup) + + metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + if err != nil { + t.Fatalf("Failed to create metrics: %v", err) + } + + peersUpdateManager := server.NewPeersUpdateManager(nil) + updMsg := peersUpdateManager.CreateChannel(context.Background(), TestPeerId) + done := make(chan struct{}) + if validateUpdate { + go func() { + if expectedPeerUpdate != nil { + peerShouldReceiveUpdate(t, updMsg, expectedPeerUpdate) + } else { + peerShouldNotReceiveUpdate(t, updMsg) + } + close(done) + }() + } + + geoMock := &geolocation.Mock{} + validatorMock := server.MocIntegratedValidator{} + am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + + networksManagerMock := networks.NewManagerMock() + resourcesManagerMock := resources.NewManagerMock() + routersManagerMock := routers.NewManagerMock() + groupsManagerMock := groups.NewManagerMock() + apiHandler, err := nbhttp.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, &jwtclaims.JwtValidatorMock{}, metrics, configs.AuthCfg{}, validatorMock) + if err != nil { + t.Fatalf("Failed to create API handler: %v", err) + } + + return apiHandler, am, done +} + +func peerShouldNotReceiveUpdate(t TB, updateMessage <-chan *server.UpdateMessage) { + t.Helper() + select { + case msg := <-updateMessage: + t.Errorf("Unexpected message received: %+v", msg) + case <-time.After(500 * time.Millisecond): + return + } +} + +func peerShouldReceiveUpdate(t TB, updateMessage <-chan *server.UpdateMessage, expected *server.UpdateMessage) { + t.Helper() + + select { + case msg := <-updateMessage: + if msg == nil { + t.Errorf("Received nil update message, expected valid message") + } + assert.Equal(t, expected, msg) + case <-time.After(500 * time.Millisecond): + t.Errorf("Timed out waiting for update message") + } +} + +func BuildRequest(t TB, requestBody []byte, requestType, requestPath, user string) *http.Request { + t.Helper() + + req := httptest.NewRequest(requestType, requestPath, bytes.NewBuffer(requestBody)) + req.Header.Set("Authorization", "Bearer "+user) + + return req +} + +func ReadResponse(t *testing.T, recorder *httptest.ResponseRecorder, expectedStatus int, expectResponse bool) ([]byte, bool) { + t.Helper() + + res := recorder.Result() + defer res.Body.Close() + + content, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("Failed to read response body: %v", err) + } + + if !expectResponse { + return nil, false + } + + if status := recorder.Code; status != expectedStatus { + t.Fatalf("handler returned wrong status code: got %v want %v, content: %s", + status, expectedStatus, string(content)) + } + + return content, expectedStatus == http.StatusOK +} + +func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, groups, users, setupKeys int) { + b.Helper() + + ctx := context.Background() + account, err := am.GetAccount(ctx, TestAccountId) + if err != nil { + b.Fatalf("Failed to get account: %v", err) + } + + // Create peers + for i := 0; i < peers; i++ { + peerKey, _ := wgtypes.GeneratePrivateKey() + peer := &nbpeer.Peer{ + ID: fmt.Sprintf("oldpeer-%d", i), + DNSLabel: fmt.Sprintf("oldpeer-%d", i), + Key: 
peerKey.PublicKey().String(), + IP: net.ParseIP(fmt.Sprintf("100.64.%d.%d", i/256, i%256)), + Status: &nbpeer.PeerStatus{LastSeen: time.Now().UTC(), Connected: true}, + UserID: TestUserId, + } + account.Peers[peer.ID] = peer + } + + // Create users + for i := 0; i < users; i++ { + user := &types.User{ + Id: fmt.Sprintf("olduser-%d", i), + AccountID: account.Id, + Role: types.UserRoleUser, + } + account.Users[user.Id] = user + } + + for i := 0; i < setupKeys; i++ { + key := &types.SetupKey{ + Id: fmt.Sprintf("oldkey-%d", i), + AccountID: account.Id, + AutoGroups: []string{"someGroupID"}, + UpdatedAt: time.Now().UTC(), + ExpiresAt: util.ToPtr(time.Now().Add(ExpiresIn * time.Second)), + Name: NewKeyName + strconv.Itoa(i), + Type: "reusable", + UsageLimit: 0, + } + account.SetupKeys[key.Id] = key + } + + // Create groups and policies + account.Policies = make([]*types.Policy, 0, groups) + for i := 0; i < groups; i++ { + groupID := fmt.Sprintf("group-%d", i) + group := &types.Group{ + ID: groupID, + Name: fmt.Sprintf("Group %d", i), + } + for j := 0; j < peers/groups; j++ { + peerIndex := i*(peers/groups) + j + group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex)) + } + account.Groups[groupID] = group + + // Create a policy for this group + policy := &types.Policy{ + ID: fmt.Sprintf("policy-%d", i), + Name: fmt.Sprintf("Policy for Group %d", i), + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: fmt.Sprintf("rule-%d", i), + Name: fmt.Sprintf("Rule for Group %d", i), + Enabled: true, + Sources: []string{groupID}, + Destinations: []string{groupID}, + Bidirectional: true, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, + }, + }, + } + account.Policies = append(account.Policies, policy) + } + + account.PostureChecks = []*posture.Checks{ + { + ID: "PostureChecksAll", + Name: "All", + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.0.1", + }, + }, + }, + } + + err = am.Store.SaveAccount(context.Background(), account) + if err != nil { + b.Fatalf("Failed to save account: %v", err) + } + +} + +func EvaluateBenchmarkResults(b *testing.B, name string, duration time.Duration, perfMetrics PerformanceMetrics, recorder *httptest.ResponseRecorder) { + b.Helper() + + if recorder.Code != http.StatusOK { + b.Fatalf("Benchmark %s failed: unexpected status code %d", name, recorder.Code) + } + + msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 + b.ReportMetric(msPerOp, "ms/op") + + minExpected := perfMetrics.MinMsPerOpLocal + maxExpected := perfMetrics.MaxMsPerOpLocal + if os.Getenv("CI") == "true" { + minExpected = perfMetrics.MinMsPerOpCICD + maxExpected = perfMetrics.MaxMsPerOpCICD + } + + if msPerOp < minExpected { + b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", name, msPerOp, minExpected) + } + + if msPerOp > maxExpected { + b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", name, msPerOp, maxExpected) + } +} diff --git a/management/server/http/util/util.go b/management/server/http/util/util.go index 603c1c696..3d7eed498 100644 --- a/management/server/http/util/util.go +++ b/management/server/http/util/util.go @@ -14,6 +14,10 @@ import ( "github.com/netbirdio/netbird/management/server/status" ) +// EmptyObject is an empty struct used to return empty JSON object +type EmptyObject struct { +} + type ErrorResponse struct { Message string `json:"message"` Code int `json:"code"` diff --git a/management/server/integrated_validator.go 
b/management/server/integrated_validator.go index 1692507da..dcb9400f6 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -4,11 +4,12 @@ import ( "context" "errors" - nbgroup "github.com/netbirdio/netbird/management/server/group" - nbpeer "github.com/netbirdio/netbird/management/server/peer" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/account" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" ) // UpdateIntegratedValidatorGroups updates the integrated validator groups for a specified account. @@ -59,9 +60,9 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID return true, nil } - err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { for _, groupID := range groupIDs { - _, err := transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + _, err := transaction.GetGroupByID(context.Background(), store.LockingStrengthShare, accountID, groupID) if err != nil { return err } @@ -77,37 +78,70 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) { var err error - var groups []*nbgroup.Group + var groups []*types.Group var peers []*nbpeer.Peer - var settings *Settings + var settings *types.Settings - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - groups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) if err != nil { return err } - peers, err = transaction.GetAccountPeers(ctx, LockingStrengthShare, accountID) + peers, err = transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID) if err != nil { return err } - settings, err = transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID) + settings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) return err }) if err != nil { return nil, err } - groupsMap := make(map[string]*nbgroup.Group, len(groups)) - for _, group := range groups { - groupsMap[group.ID] = group - } - - peersMap := make(map[string]*nbpeer.Peer, len(peers)) - for _, peer := range peers { - peersMap[peer.ID] = peer - } - - return am.integratedPeerValidator.GetValidatedPeers(accountID, groupsMap, peersMap, settings.Extra) + return am.integratedPeerValidator.GetValidatedPeers(accountID, groups, peers, settings.Extra) +} + +type MocIntegratedValidator struct { + ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) +} + +func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { + return nil +} + +func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup 
[]string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) { + if a.ValidatePeerFunc != nil { + return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings) + } + return update, false, nil +} + +func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { + validatedPeers := make(map[string]struct{}) + for _, peer := range peers { + validatedPeers[peer.ID] = struct{}{} + } + return validatedPeers, nil +} + +func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer { + return peer +} + +func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) { + return false, false, nil +} + +func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error { + return nil +} + +func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) { + // just a dummy +} + +func (MocIntegratedValidator) Stop(_ context.Context) { + // just a dummy } diff --git a/management/server/integrated_validator/interface.go b/management/server/integrated_validator/interface.go index 03be9d039..ff179e3c0 100644 --- a/management/server/integrated_validator/interface.go +++ b/management/server/integrated_validator/interface.go @@ -4,8 +4,8 @@ import ( "context" "github.com/netbirdio/netbird/management/server/account" - nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/types" ) // IntegratedValidator interface exists to avoid the circle dependencies @@ -14,7 +14,7 @@ type IntegratedValidator interface { ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) - GetValidatedPeers(accountID string, groups map[string]*nbgroup.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) + GetValidatedPeers(accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) PeerDeleted(ctx context.Context, accountID, peerID string) error SetPeerInvalidationListener(fn func(accountID string)) Stop(ctx context.Context) diff --git a/management/server/jwtclaims/extractor.go b/management/server/jwtclaims/extractor.go index c441650e9..18214b434 100644 --- a/management/server/jwtclaims/extractor.go +++ b/management/server/jwtclaims/extractor.go @@ -22,6 +22,8 @@ const ( LastLoginSuffix = "nb_last_login" // Invited claim indicates that an incoming JWT is from a user that just accepted an invitation Invited = "nb_invited" + // IsToken claim indicates that auth type from the user is a token + IsToken = "is_token" ) // ExtractClaims Extract function type diff --git a/management/server/jwtclaims/jwtValidator.go 
b/management/server/jwtclaims/jwtValidator.go index d5c1e7c9e..79e59e76f 100644 --- a/management/server/jwtclaims/jwtValidator.go +++ b/management/server/jwtclaims/jwtValidator.go @@ -72,13 +72,19 @@ type JSONWebKey struct { X5c []string `json:"x5c"` } -// JWTValidator struct to handle token validation and parsing -type JWTValidator struct { +type JWTValidator interface { + ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) +} + +// jwtValidatorImpl struct to handle token validation and parsing +type jwtValidatorImpl struct { options Options } +var keyNotFound = errors.New("unable to find appropriate key") + // NewJWTValidator constructor -func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (*JWTValidator, error) { +func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (JWTValidator, error) { keys, err := getPemKeys(ctx, keysLocation) if err != nil { return nil, err @@ -124,12 +130,18 @@ func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, } publicKey, err := getPublicKey(ctx, token, keys) - if err != nil { - log.WithContext(ctx).Errorf("getPublicKey error: %s", err) - return nil, err + if err == nil { + return publicKey, nil } - return publicKey, nil + msg := fmt.Sprintf("getPublicKey error: %s", err) + if errors.Is(err, keyNotFound) && !idpSignkeyRefreshEnabled { + msg = fmt.Sprintf("getPublicKey error: %s. You can enable key refresh by setting HttpServerConfig.IdpSignKeyRefreshEnabled to true in your management.json file and restart the service", err) + } + + log.WithContext(ctx).Error(msg) + + return nil, err }, EnableAuthOnOptions: false, } @@ -138,13 +150,13 @@ func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, options.UserProperty = "user" } - return &JWTValidator{ + return &jwtValidatorImpl{ options: options, }, nil } // ValidateAndParse validates the token and returns the parsed token -func (m *JWTValidator) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) { +func (m *jwtValidatorImpl) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) { // If the token is empty... 
if token == "" { // Check if it was required @@ -229,7 +241,7 @@ func getPublicKey(ctx context.Context, token *jwt.Token, jwks *Jwks) (interface{ log.WithContext(ctx).Debugf("Key Type: %s not yet supported, please raise ticket!", jwks.Keys[k].Kty) } - return nil, errors.New("unable to find appropriate key") + return nil, keyNotFound } func getPublicKeyFromECDSA(jwk JSONWebKey) (publicKey *ecdsa.PublicKey, err error) { @@ -311,3 +323,27 @@ func getMaxAgeFromCacheHeader(ctx context.Context, cacheControl string) int { return 0 } +type JwtValidatorMock struct{} + +func (j *JwtValidatorMock) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) { + claimMaps := jwt.MapClaims{} + + switch token { + case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId": + claimMaps[UserIDClaim] = token + claimMaps[AccountIDSuffix] = "testAccountId" + claimMaps[DomainIDSuffix] = "test.com" + claimMaps[DomainCategorySuffix] = "private" + case "otherUserId": + claimMaps[UserIDClaim] = "otherUserId" + claimMaps[AccountIDSuffix] = "otherAccountId" + claimMaps[DomainIDSuffix] = "other.com" + claimMaps[DomainCategorySuffix] = "private" + case "invalidToken": + return nil, errors.New("invalid token") + } + + jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) + return jwtToken, nil +} + diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 57ad968b3..0df2462f4 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -23,7 +23,10 @@ import ( "github.com/netbirdio/netbird/formatter" mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/util" ) @@ -413,7 +416,7 @@ func startManagementForTest(t *testing.T, testFile string, config *Config) (*grp } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), testFile, t.TempDir()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir()) if err != nil { t.Fatal(err) } @@ -437,7 +440,7 @@ func startManagementForTest(t *testing.T, testFile string, config *Config) (*grp secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) ephemeralMgr := NewEphemeralManager(store, accountManager) - mgmtServer, err := NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, ephemeralMgr) + mgmtServer, err := NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, ephemeralMgr) if err != nil { return nil, nil, "", cleanup, err } @@ -472,8 +475,14 @@ func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.Clie func Test_SyncStatusRace(t *testing.T) { t.Skip() - if os.Getenv("CI") == "true" && os.Getenv("NETBIRD_STORE_ENGINE") == "postgres" { - t.Skip("Skipping on CI and Postgres store") + if os.Getenv("CI") == "true" { + if os.Getenv("NETBIRD_STORE_ENGINE") == "postgres" { + t.Skip("Skipping on CI and Postgres store") + } + + if os.Getenv("NETBIRD_STORE_ENGINE") == "mysql" { + 
t.Skip("Skipping on CI and MySQL store") + } } for i := 0; i < 500; i++ { t.Run(fmt.Sprintf("TestRun-%d", i), func(t *testing.T) { @@ -618,7 +627,7 @@ func testSyncStatusRace(t *testing.T) { } time.Sleep(10 * time.Millisecond) - peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peerWithInvalidStatus.PublicKey().String()) + peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, peerWithInvalidStatus.PublicKey().String()) if err != nil { t.Fatal(err) return @@ -705,7 +714,7 @@ func Test_LoginPerformance(t *testing.T) { return } - setupKey, err := am.CreateSetupKey(context.Background(), account.Id, fmt.Sprintf("key-%d", j), SetupKeyReusable, time.Hour, nil, 0, fmt.Sprintf("user-%d", j), false) + setupKey, err := am.CreateSetupKey(context.Background(), account.Id, fmt.Sprintf("key-%d", j), types.SetupKeyReusable, time.Hour, nil, 0, fmt.Sprintf("user-%d", j), false) if err != nil { t.Logf("error creating setup key: %v", err) return diff --git a/management/server/management_test.go b/management/server/management_test.go index 5361da53f..cfa2c138f 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -21,10 +21,9 @@ import ( "github.com/netbirdio/netbird/encryption" mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/group" - nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/util" ) @@ -446,43 +445,6 @@ var _ = Describe("Management service", func() { }) }) -type MocIntegratedValidator struct { -} - -func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { - return nil -} - -func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) { - return update, false, nil -} - -func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { - validatedPeers := make(map[string]struct{}) - for p := range peers { - validatedPeers[p] = struct{}{} - } - return validatedPeers, nil -} - -func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer { - return peer -} - -func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) { - return false, false, nil -} - -func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error { - return nil -} - -func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) { - -} - -func (MocIntegratedValidator) Stop(_ context.Context) {} - func loginPeerWithValidSetupKey(serverPubKey wgtypes.Key, 
key wgtypes.Key, client mgmtProto.ManagementServiceClient) *mgmtProto.LoginResponse { defer GinkgoRecover() @@ -532,7 +494,7 @@ func startServer(config *server.Config, dataDir string, testFile string) (*grpc. Expect(err).NotTo(HaveOccurred()) s := grpc.NewServer() - store, _, err := server.NewTestStoreFromSQL(context.Background(), testFile, dataDir) + store, _, err := store.NewTestStoreFromSQL(context.Background(), testFile, dataDir) if err != nil { log.Fatalf("failed creating a store: %s: %v", config.Datadir, err) } @@ -545,13 +507,13 @@ func startServer(config *server.Config, dataDir string, testFile string) (*grpc. log.Fatalf("failed creating metrics: %v", err) } - accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics) + accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, server.MocIntegratedValidator{}, metrics) if err != nil { log.Fatalf("failed creating a manager: %v", err) } secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil) Expect(err).NotTo(HaveOccurred()) mgmtProto.RegisterManagementServiceServer(s, mgmtServer) go func() { diff --git a/management/server/metrics/selfhosted.go b/management/server/metrics/selfhosted.go index 843fa575e..03cb21af1 100644 --- a/management/server/metrics/selfhosted.go +++ b/management/server/metrics/selfhosted.go @@ -15,7 +15,8 @@ import ( "github.com/hashicorp/go-version" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" nbversion "github.com/netbirdio/netbird/version" ) @@ -47,8 +48,8 @@ type properties map[string]interface{} // DataSource metric data source type DataSource interface { - GetAllAccounts(ctx context.Context) []*server.Account - GetStoreEngine() server.StoreEngine + GetAllAccounts(ctx context.Context) []*types.Account + GetStoreEngine() store.Engine } // ConnManager peer connection manager that holds state for current active connections @@ -194,6 +195,10 @@ func (w *Worker) generateProperties(ctx context.Context) properties { groups int routes int routesWithRGGroups int + networks int + networkResources int + networkRouters int + networkRoutersWithPG int nameservers int uiClient int version string @@ -218,6 +223,16 @@ func (w *Worker) generateProperties(ctx context.Context) properties { } groups += len(account.Groups) + networks += len(account.Networks) + networkResources += len(account.NetworkResources) + + networkRouters += len(account.NetworkRouters) + for _, router := range account.NetworkRouters { + if len(router.PeerGroups) > 0 { + networkRoutersWithPG++ + } + } + routes += len(account.Routes) for _, route := range account.Routes { if len(route.PeerGroups) > 0 { @@ -311,6 +326,10 @@ func (w *Worker) generateProperties(ctx context.Context) properties { metricsProperties["rules_with_src_posture_checks"] = rulesWithSrcPostureChecks metricsProperties["posture_checks"] = postureChecks metricsProperties["groups"] = groups + 
metricsProperties["networks"] = networks + metricsProperties["network_resources"] = networkResources + metricsProperties["network_routers"] = networkRouters + metricsProperties["network_routers_with_groups"] = networkRoutersWithPG metricsProperties["routes"] = routes metricsProperties["routes_with_routing_groups"] = routesWithRGGroups metricsProperties["nameservers"] = nameservers diff --git a/management/server/metrics/selfhosted_test.go b/management/server/metrics/selfhosted_test.go index 2ac2d68a0..4894c1ac4 100644 --- a/management/server/metrics/selfhosted_test.go +++ b/management/server/metrics/selfhosted_test.go @@ -5,10 +5,13 @@ import ( "testing" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/management/server/group" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) @@ -22,19 +25,19 @@ func (mockDatasource) GetAllConnectedPeers() map[string]struct{} { } // GetAllAccounts returns a list of *server.Account for use in tests with predefined information -func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account { - return []*server.Account{ +func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account { + return []*types.Account{ { Id: "1", - Settings: &server.Settings{PeerLoginExpirationEnabled: true}, - SetupKeys: map[string]*server.SetupKey{ + Settings: &types.Settings{PeerLoginExpirationEnabled: true}, + SetupKeys: map[string]*types.SetupKey{ "1": { Id: "1", Ephemeral: true, UsedTimes: 1, }, }, - Groups: map[string]*group.Group{ + Groups: map[string]*types.Group{ "1": {}, "2": {}, }, @@ -49,20 +52,20 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account { Meta: nbpeer.PeerSystemMeta{GoOS: "linux", WtVersion: "0.0.1"}, }, }, - Policies: []*server.Policy{ + Policies: []*types.Policy{ { - Rules: []*server.PolicyRule{ + Rules: []*types.PolicyRule{ { Bidirectional: true, - Protocol: server.PolicyRuleProtocolTCP, + Protocol: types.PolicyRuleProtocolTCP, }, }, }, { - Rules: []*server.PolicyRule{ + Rules: []*types.PolicyRule{ { Bidirectional: false, - Protocol: server.PolicyRuleProtocolTCP, + Protocol: types.PolicyRuleProtocolTCP, }, }, SourcePostureChecks: []string{"1"}, @@ -94,16 +97,16 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account { }, }, }, - Users: map[string]*server.User{ + Users: map[string]*types.User{ "1": { IsServiceUser: true, - PATs: map[string]*server.PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ "1": {}, }, }, "2": { IsServiceUser: false, - PATs: map[string]*server.PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ "1": {}, }, }, @@ -111,15 +114,15 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account { }, { Id: "2", - Settings: &server.Settings{PeerLoginExpirationEnabled: true}, - SetupKeys: map[string]*server.SetupKey{ + Settings: &types.Settings{PeerLoginExpirationEnabled: true}, + SetupKeys: map[string]*types.SetupKey{ "1": { Id: "1", Ephemeral: true, UsedTimes: 1, }, }, - Groups: 
map[string]*group.Group{ + Groups: map[string]*types.Group{ "1": {}, "2": {}, }, @@ -134,20 +137,20 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account { Meta: nbpeer.PeerSystemMeta{GoOS: "linux", WtVersion: "0.0.1"}, }, }, - Policies: []*server.Policy{ + Policies: []*types.Policy{ { - Rules: []*server.PolicyRule{ + Rules: []*types.PolicyRule{ { Bidirectional: true, - Protocol: server.PolicyRuleProtocolTCP, + Protocol: types.PolicyRuleProtocolTCP, }, }, }, { - Rules: []*server.PolicyRule{ + Rules: []*types.PolicyRule{ { Bidirectional: false, - Protocol: server.PolicyRuleProtocolTCP, + Protocol: types.PolicyRuleProtocolTCP, }, }, }, @@ -158,27 +161,52 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account { PeerGroups: make([]string, 1), }, }, - Users: map[string]*server.User{ + Users: map[string]*types.User{ "1": { IsServiceUser: true, - PATs: map[string]*server.PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ "1": {}, }, }, "2": { IsServiceUser: false, - PATs: map[string]*server.PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ "1": {}, }, }, }, + Networks: []*networkTypes.Network{ + { + ID: "1", + AccountID: "1", + }, + }, + NetworkResources: []*resourceTypes.NetworkResource{ + { + ID: "1", + AccountID: "1", + NetworkID: "1", + }, + { + ID: "2", + AccountID: "1", + NetworkID: "1", + }, + }, + NetworkRouters: []*routerTypes.NetworkRouter{ + { + ID: "1", + AccountID: "1", + NetworkID: "1", + }, + }, }, } } // GetStoreEngine returns FileStoreEngine -func (mockDatasource) GetStoreEngine() server.StoreEngine { - return server.FileStoreEngine +func (mockDatasource) GetStoreEngine() store.Engine { + return store.FileStoreEngine } // TestGenerateProperties tests and validate the properties generation by using the mockDatasource for the Worker.generateProperties @@ -200,6 +228,15 @@ func TestGenerateProperties(t *testing.T) { if properties["routes"] != 2 { t.Errorf("expected 2 routes, got %d", properties["routes"]) } + if properties["networks"] != 1 { + t.Errorf("expected 1 networks, got %d", properties["networks"]) + } + if properties["network_resources"] != 2 { + t.Errorf("expected 2 network_resources, got %d", properties["network_resources"]) + } + if properties["network_routers"] != 1 { + t.Errorf("expected 1 network_routers, got %d", properties["network_routers"]) + } if properties["rules"] != 4 { t.Errorf("expected 4 rules, got %d", properties["rules"]) } @@ -267,7 +304,7 @@ func TestGenerateProperties(t *testing.T) { t.Errorf("expected 2 user_peers, got %d", properties["user_peers"]) } - if properties["store_engine"] != server.FileStoreEngine { + if properties["store_engine"] != store.FileStoreEngine { t.Errorf("expected JsonFile, got %s", properties["store_engine"]) } diff --git a/management/server/migration/migration.go b/management/server/migration/migration.go index 6f12d94b4..8986d77b5 100644 --- a/management/server/migration/migration.go +++ b/management/server/migration/migration.go @@ -17,12 +17,19 @@ import ( "gorm.io/gorm" ) +func GetColumnName(db *gorm.DB, column string) string { + if db.Name() == "mysql" { + return fmt.Sprintf("`%s`", column) + } + return column +} + // MigrateFieldFromGobToJSON migrates a column from Gob encoding to JSON encoding. // T is the type of the model that contains the field to be migrated. // S is the type of the field to be migrated. 
func MigrateFieldFromGobToJSON[T any, S any](ctx context.Context, db *gorm.DB, fieldName string) error { - - oldColumnName := fieldName + orgColumnName := fieldName + oldColumnName := GetColumnName(db, orgColumnName) newColumnName := fieldName + "_tmp" var model T @@ -72,7 +79,7 @@ func MigrateFieldFromGobToJSON[T any, S any](ctx context.Context, db *gorm.DB, f for _, row := range rows { var field S - str, ok := row[oldColumnName].(string) + str, ok := row[orgColumnName].(string) if !ok { return fmt.Errorf("type assertion failed") } @@ -111,7 +118,8 @@ func MigrateFieldFromGobToJSON[T any, S any](ctx context.Context, db *gorm.DB, f // MigrateNetIPFieldFromBlobToJSON migrates a Net IP column from Blob encoding to JSON encoding. // T is the type of the model that contains the field to be migrated. func MigrateNetIPFieldFromBlobToJSON[T any](ctx context.Context, db *gorm.DB, fieldName string, indexName string) error { - oldColumnName := fieldName + orgColumnName := fieldName + oldColumnName := GetColumnName(db, orgColumnName) newColumnName := fieldName + "_tmp" var model T @@ -163,7 +171,7 @@ func MigrateNetIPFieldFromBlobToJSON[T any](ctx context.Context, db *gorm.DB, fi for _, row := range rows { var blobValue string - if columnValue := row[oldColumnName]; columnValue != nil { + if columnValue := row[orgColumnName]; columnValue != nil { value, ok := columnValue.(string) if !ok { return fmt.Errorf("type assertion failed") @@ -210,7 +218,8 @@ func MigrateNetIPFieldFromBlobToJSON[T any](ctx context.Context, db *gorm.DB, fi } func MigrateSetupKeyToHashedSetupKey[T any](ctx context.Context, db *gorm.DB) error { - oldColumnName := "key" + orgColumnName := "key" + oldColumnName := GetColumnName(db, orgColumnName) newColumnName := "key_secret" var model T @@ -250,8 +259,9 @@ func MigrateSetupKeyToHashedSetupKey[T any](ctx context.Context, db *gorm.DB) er } for _, row := range rows { + var plainKey string - if columnValue := row[oldColumnName]; columnValue != nil { + if columnValue := row[orgColumnName]; columnValue != nil { value, ok := columnValue.(string) if !ok { return fmt.Errorf("type assertion failed") @@ -295,3 +305,53 @@ func hiddenKey(key string, length int) string { } return prefix + strings.Repeat("*", length) } + +func MigrateNewField[T any](ctx context.Context, db *gorm.DB, columnName string, defaultValue any) error { + var model T + + if !db.Migrator().HasTable(&model) { + log.WithContext(ctx).Debugf("Table for %T does not exist, no migration needed", model) + return nil + } + + stmt := &gorm.Statement{DB: db} + err := stmt.Parse(&model) + if err != nil { + return fmt.Errorf("parse model: %w", err) + } + tableName := stmt.Schema.Table + + if err := db.Transaction(func(tx *gorm.DB) error { + if !tx.Migrator().HasColumn(&model, columnName) { + log.WithContext(ctx).Infof("Column %s does not exist in table %s, adding it", columnName, tableName) + if err := tx.Migrator().AddColumn(&model, columnName); err != nil { + return fmt.Errorf("add column %s: %w", columnName, err) + } + } + + var rows []map[string]any + if err := tx.Table(tableName). + Select("id", columnName). + Where(columnName + " IS NULL OR " + columnName + " = ''"). 
+ Find(&rows).Error; err != nil { + return fmt.Errorf("failed to find rows with empty %s: %w", columnName, err) + } + + if len(rows) == 0 { + log.WithContext(ctx).Infof("No rows with empty %s found in table %s, no migration needed", columnName, tableName) + return nil + } + + for _, row := range rows { + if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(columnName, defaultValue).Error; err != nil { + return fmt.Errorf("failed to update row with id %v: %w", row["id"], err) + } + } + return nil + }); err != nil { + return err + } + + log.WithContext(ctx).Infof("Migration of empty %s to default value in table %s completed", columnName, tableName) + return nil +} diff --git a/management/server/migration/migration_test.go b/management/server/migration/migration_test.go index 51358c7ad..a645ae325 100644 --- a/management/server/migration/migration_test.go +++ b/management/server/migration/migration_test.go @@ -12,9 +12,9 @@ import ( "gorm.io/driver/sqlite" "gorm.io/gorm" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/migration" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) @@ -31,64 +31,64 @@ func setupDatabase(t *testing.T) *gorm.DB { func TestMigrateFieldFromGobToJSON_EmptyDB(t *testing.T) { db := setupDatabase(t) - err := migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](context.Background(), db, "network_net") + err := migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](context.Background(), db, "network_net") require.NoError(t, err, "Migration should not fail for an empty database") } func TestMigrateFieldFromGobToJSON_WithGobData(t *testing.T) { db := setupDatabase(t) - err := db.AutoMigrate(&server.Account{}, &route.Route{}) + err := db.AutoMigrate(&types.Account{}, &route.Route{}) require.NoError(t, err, "Failed to auto-migrate tables") _, ipnet, err := net.ParseCIDR("10.0.0.0/24") require.NoError(t, err, "Failed to parse CIDR") type network struct { - server.Network + types.Network Net net.IPNet `gorm:"serializer:gob"` } type account struct { - server.Account + types.Account Network *network `gorm:"embedded;embeddedPrefix:network_"` } - err = db.Save(&account{Account: server.Account{Id: "123"}, Network: &network{Net: *ipnet}}).Error + err = db.Save(&account{Account: types.Account{Id: "123"}, Network: &network{Net: *ipnet}}).Error require.NoError(t, err, "Failed to insert Gob data") var gobStr string - err = db.Model(&server.Account{}).Select("network_net").First(&gobStr).Error + err = db.Model(&types.Account{}).Select("network_net").First(&gobStr).Error assert.NoError(t, err, "Failed to fetch Gob data") err = gob.NewDecoder(strings.NewReader(gobStr)).Decode(&ipnet) require.NoError(t, err, "Failed to decode Gob data") - err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](context.Background(), db, "network_net") + err = migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](context.Background(), db, "network_net") require.NoError(t, err, "Migration should not fail with Gob data") var jsonStr string - db.Model(&server.Account{}).Select("network_net").First(&jsonStr) + db.Model(&types.Account{}).Select("network_net").First(&jsonStr) assert.JSONEq(t, `{"IP":"10.0.0.0","Mask":"////AA=="}`, jsonStr, "Data should be migrated") } func TestMigrateFieldFromGobToJSON_WithJSONData(t *testing.T) { db := setupDatabase(t) - err := db.AutoMigrate(&server.Account{}, &route.Route{}) + 
err := db.AutoMigrate(&types.Account{}, &route.Route{}) require.NoError(t, err, "Failed to auto-migrate tables") _, ipnet, err := net.ParseCIDR("10.0.0.0/24") require.NoError(t, err, "Failed to parse CIDR") - err = db.Save(&server.Account{Network: &server.Network{Net: *ipnet}}).Error + err = db.Save(&types.Account{Network: &types.Network{Net: *ipnet}}).Error require.NoError(t, err, "Failed to insert JSON data") - err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](context.Background(), db, "network_net") + err = migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](context.Background(), db, "network_net") require.NoError(t, err, "Migration should not fail with JSON data") var jsonStr string - db.Model(&server.Account{}).Select("network_net").First(&jsonStr) + db.Model(&types.Account{}).Select("network_net").First(&jsonStr) assert.JSONEq(t, `{"IP":"10.0.0.0","Mask":"////AA=="}`, jsonStr, "Data should be unchanged") } @@ -101,7 +101,7 @@ func TestMigrateNetIPFieldFromBlobToJSON_EmptyDB(t *testing.T) { func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) { db := setupDatabase(t) - err := db.AutoMigrate(&server.Account{}, &nbpeer.Peer{}) + err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{}) require.NoError(t, err, "Failed to auto-migrate tables") type location struct { @@ -115,12 +115,12 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) { } type account struct { - server.Account + types.Account Peers []peer `gorm:"foreignKey:AccountID;references:id"` } err = db.Save(&account{ - Account: server.Account{Id: "123"}, + Account: types.Account{Id: "123"}, Peers: []peer{ {Location: location{ConnectionIP: net.IP{10, 0, 0, 1}}}, }}, @@ -142,10 +142,10 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) { func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) { db := setupDatabase(t) - err := db.AutoMigrate(&server.Account{}, &nbpeer.Peer{}) + err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{}) require.NoError(t, err, "Failed to auto-migrate tables") - err = db.Save(&server.Account{ + err = db.Save(&types.Account{ Id: "1234", PeersG: []nbpeer.Peer{ {Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}}}, @@ -164,20 +164,20 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) { func TestMigrateSetupKeyToHashedSetupKey_ForPlainKey(t *testing.T) { db := setupDatabase(t) - err := db.AutoMigrate(&server.SetupKey{}) + err := db.AutoMigrate(&types.SetupKey{}) require.NoError(t, err, "Failed to auto-migrate tables") - err = db.Save(&server.SetupKey{ + err = db.Save(&types.SetupKey{ Id: "1", Key: "EEFDAB47-C1A5-4472-8C05-71DE9A1E8382", }).Error require.NoError(t, err, "Failed to insert setup key") - err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db) + err = migration.MigrateSetupKeyToHashedSetupKey[types.SetupKey](context.Background(), db) require.NoError(t, err, "Migration should not fail to migrate setup key") - var key server.SetupKey - err = db.Model(&server.SetupKey{}).First(&key).Error + var key types.SetupKey + err = db.Model(&types.SetupKey{}).First(&key).Error assert.NoError(t, err, "Failed to fetch setup key") assert.Equal(t, "EEFDA****", key.KeySecret, "Key should be secret") @@ -187,21 +187,21 @@ func TestMigrateSetupKeyToHashedSetupKey_ForPlainKey(t *testing.T) { func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case1(t *testing.T) { db := setupDatabase(t) - err := db.AutoMigrate(&server.SetupKey{}) + err := 
db.AutoMigrate(&types.SetupKey{}) require.NoError(t, err, "Failed to auto-migrate tables") - err = db.Save(&server.SetupKey{ + err = db.Save(&types.SetupKey{ Id: "1", Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", KeySecret: "EEFDA****", }).Error require.NoError(t, err, "Failed to insert setup key") - err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db) + err = migration.MigrateSetupKeyToHashedSetupKey[types.SetupKey](context.Background(), db) require.NoError(t, err, "Migration should not fail to migrate setup key") - var key server.SetupKey - err = db.Model(&server.SetupKey{}).First(&key).Error + var key types.SetupKey + err = db.Model(&types.SetupKey{}).First(&key).Error assert.NoError(t, err, "Failed to fetch setup key") assert.Equal(t, "EEFDA****", key.KeySecret, "Key should be secret") @@ -211,20 +211,20 @@ func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case1(t *testing. func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case2(t *testing.T) { db := setupDatabase(t) - err := db.AutoMigrate(&server.SetupKey{}) + err := db.AutoMigrate(&types.SetupKey{}) require.NoError(t, err, "Failed to auto-migrate tables") - err = db.Save(&server.SetupKey{ + err = db.Save(&types.SetupKey{ Id: "1", Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", }).Error require.NoError(t, err, "Failed to insert setup key") - err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db) + err = migration.MigrateSetupKeyToHashedSetupKey[types.SetupKey](context.Background(), db) require.NoError(t, err, "Migration should not fail to migrate setup key") - var key server.SetupKey - err = db.Model(&server.SetupKey{}).First(&key).Error + var key types.SetupKey + err = db.Model(&types.SetupKey{}).First(&key).Error assert.NoError(t, err, "Failed to fetch setup key") assert.Equal(t, "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", key.Key, "Key should be hashed") diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 3e465e32e..bcb7f0642 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -13,48 +13,48 @@ import ( "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) type MockAccountManager struct { - GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*server.Account, error) - GetAccountFunc func(ctx context.Context, accountID string) (*server.Account, error) - CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType, - expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) - GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) + GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*types.Account, error) + GetAccountFunc func(ctx context.Context, accountID string) (*types.Account, error) + CreateSetupKeyFunc func(ctx 
context.Context, accountId string, keyName string, keyType types.SetupKeyType, + expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*types.SetupKey, error) + GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) AccountExistsFunc func(ctx context.Context, accountID string) (bool, error) GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error) - GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) - ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error) + GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) + ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error) GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error - SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) + SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error - GetNetworkMapFunc func(ctx context.Context, peerKey string) (*server.NetworkMap, error) - GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*server.Network, error) - AddPeerFunc func(ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) - GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*group.Group, error) - GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*group.Group, error) - GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*group.Group, error) - SaveGroupFunc func(ctx context.Context, accountID, userID string, group *group.Group) error - SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error + GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error) + GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error) + AddPeerFunc func(ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error) + GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error) + GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*types.Group, error) + SaveGroupFunc func(ctx context.Context, accountID, userID string, group *types.Group) error + SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*types.Group) error DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error - GetPeerGroupsFunc func(ctx context.Context, accountID, peerID string) ([]*group.Group, error) + 
GetPeerGroupsFunc func(ctx context.Context, accountID, peerID string) ([]*types.Group, error) DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error - GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error) - SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error) + GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) + SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error) DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error - ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error) - GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error) - GetAccountInfoFromPATFunc func(ctx context.Context, token string) (*server.User, *server.PersonalAccessToken, string, string, error) + ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*types.Policy, error) + GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*types.UserInfo, error) + GetAccountInfoFromPATFunc func(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, string, string, error) MarkPATUsedFunc func(ctx context.Context, pat string) error UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) @@ -63,35 +63,35 @@ type MockAccountManager struct { SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error ListRoutesFunc func(ctx context.Context, accountID, userID string) ([]*route.Route, error) - SaveSetupKeyFunc func(ctx context.Context, accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) - ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*server.SetupKey, error) - SaveUserFunc func(ctx context.Context, accountID, userID string, user *server.User) (*server.UserInfo, error) - SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) - SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*server.User, addIfNotExists bool) ([]*server.UserInfo, error) + SaveSetupKeyFunc func(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error) + ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) + SaveUserFunc func(ctx context.Context, accountID, userID string, user *types.User) (*types.UserInfo, error) + SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *types.User, addIfNotExists bool) (*types.UserInfo, error) + SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*types.User, addIfNotExists bool) ([]*types.UserInfo, error) DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error - CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, 
tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) + CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error - GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error) - GetAllPATsFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error) + GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*types.PersonalAccessToken, error) + GetAllPATsFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string) ([]*types.PersonalAccessToken, error) GetNameServerGroupFunc func(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) CreateNameServerGroupFunc func(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) SaveNameServerGroupFunc func(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) - CreateUserFunc func(ctx context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) + CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error) GetAccountIDFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error DeleteAccountFunc func(ctx context.Context, accountID, userID string) error GetDNSDomainFunc func() string StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) GetEventsFunc func(ctx context.Context, accountID, userID string) ([]*activity.Event, error) - GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*server.DNSSettings, error) - SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *server.DNSSettings) error + GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*types.DNSSettings, error) + SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *types.DNSSettings) error GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) - UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) - LoginPeerFunc func(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) - SyncPeerFunc func(ctx context.Context, sync server.PeerSync, accountID string) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) + UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) 
(*types.Account, error) + LoginPeerFunc func(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + SyncPeerFunc func(ctx context.Context, sync server.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error GetAllConnectedPeersFunc func() (map[string]struct{}, error) HasConnectedChannelFunc func(peerID string) bool @@ -106,12 +106,16 @@ type MockAccountManager struct { SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error) - GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*server.Account, error) - GetUserByIDFunc func(ctx context.Context, id string) (*server.User, error) - GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*server.Settings, error) + GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*types.Account, error) + GetUserByIDFunc func(ctx context.Context, id string) (*types.User, error) + GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*types.Settings, error) DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error } +func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { + // do nothing +} + func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error { if am.DeleteSetupKeyFunc != nil { return am.DeleteSetupKeyFunc(ctx, accountID, userID, keyID) @@ -119,7 +123,7 @@ func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, use return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented") } -func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { +func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { if am.SyncAndMarkPeerFunc != nil { return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP) } @@ -145,7 +149,7 @@ func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID s } // GetGroup mock implementation of GetGroup from server.AccountManager interface -func (am *MockAccountManager) GetGroup(ctx context.Context, accountId, groupID, userID string) (*group.Group, error) { +func (am *MockAccountManager) GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error) { if am.GetGroupFunc != nil { return am.GetGroupFunc(ctx, accountId, groupID, userID) } @@ -153,7 +157,7 @@ func (am *MockAccountManager) GetGroup(ctx context.Context, accountId, groupID, } // GetAllGroups mock implementation of GetAllGroups from server.AccountManager interface -func (am *MockAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*group.Group, error) { +func (am *MockAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) { if am.GetAllGroupsFunc != nil { return 
am.GetAllGroupsFunc(ctx, accountID, userID) } @@ -161,7 +165,7 @@ func (am *MockAccountManager) GetAllGroups(ctx context.Context, accountID, userI } // GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface -func (am *MockAccountManager) GetUsersFromAccount(ctx context.Context, accountID string, userID string) ([]*server.UserInfo, error) { +func (am *MockAccountManager) GetUsersFromAccount(ctx context.Context, accountID string, userID string) ([]*types.UserInfo, error) { if am.GetUsersFromAccountFunc != nil { return am.GetUsersFromAccountFunc(ctx, accountID, userID) } @@ -179,7 +183,7 @@ func (am *MockAccountManager) DeletePeer(ctx context.Context, accountID, peerID, // GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface func (am *MockAccountManager) GetOrCreateAccountByUser( ctx context.Context, userId, domain string, -) (*server.Account, error) { +) (*types.Account, error) { if am.GetOrCreateAccountByUserFunc != nil { return am.GetOrCreateAccountByUserFunc(ctx, userId, domain) } @@ -194,13 +198,13 @@ func (am *MockAccountManager) CreateSetupKey( ctx context.Context, accountID string, keyName string, - keyType server.SetupKeyType, + keyType types.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool, -) (*server.SetupKey, error) { +) (*types.SetupKey, error) { if am.CreateSetupKeyFunc != nil { return am.CreateSetupKeyFunc(ctx, accountID, keyName, keyType, expiresIn, autoGroups, usageLimit, userID, ephemeral) } @@ -235,7 +239,7 @@ func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey str } // GetAccountInfoFromPAT mock implementation of GetAccountInfoFromPAT from server.AccountManager interface -func (am *MockAccountManager) GetAccountInfoFromPAT(ctx context.Context, pat string) (*server.User, *server.PersonalAccessToken, string, string, error) { +func (am *MockAccountManager) GetAccountInfoFromPAT(ctx context.Context, pat string) (*types.User, *types.PersonalAccessToken, string, string, error) { if am.GetAccountInfoFromPATFunc != nil { return am.GetAccountInfoFromPATFunc(ctx, pat) } @@ -259,7 +263,7 @@ func (am *MockAccountManager) MarkPATUsed(ctx context.Context, pat string) error } // CreatePAT mock implementation of GetPAT from server.AccountManager interface -func (am *MockAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) { +func (am *MockAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, name string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) { if am.CreatePATFunc != nil { return am.CreatePATFunc(ctx, accountID, initiatorUserID, targetUserID, name, expiresIn) } @@ -275,7 +279,7 @@ func (am *MockAccountManager) DeletePAT(ctx context.Context, accountID string, i } // GetPAT mock implementation of GetPAT from server.AccountManager interface -func (am *MockAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) { +func (am *MockAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error) { if am.GetPATFunc != nil { return am.GetPATFunc(ctx, accountID, initiatorUserID, targetUserID, tokenID) } @@ -283,7 +287,7 
@@ func (am *MockAccountManager) GetPAT(ctx context.Context, accountID string, init } // GetAllPATs mock implementation of GetAllPATs from server.AccountManager interface -func (am *MockAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) { +func (am *MockAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error) { if am.GetAllPATsFunc != nil { return am.GetAllPATsFunc(ctx, accountID, initiatorUserID, targetUserID) } @@ -291,7 +295,7 @@ func (am *MockAccountManager) GetAllPATs(ctx context.Context, accountID string, } // GetNetworkMap mock implementation of GetNetworkMap from server.AccountManager interface -func (am *MockAccountManager) GetNetworkMap(ctx context.Context, peerKey string) (*server.NetworkMap, error) { +func (am *MockAccountManager) GetNetworkMap(ctx context.Context, peerKey string) (*types.NetworkMap, error) { if am.GetNetworkMapFunc != nil { return am.GetNetworkMapFunc(ctx, peerKey) } @@ -299,7 +303,7 @@ func (am *MockAccountManager) GetNetworkMap(ctx context.Context, peerKey string) } // GetPeerNetwork mock implementation of GetPeerNetwork from server.AccountManager interface -func (am *MockAccountManager) GetPeerNetwork(ctx context.Context, peerKey string) (*server.Network, error) { +func (am *MockAccountManager) GetPeerNetwork(ctx context.Context, peerKey string) (*types.Network, error) { if am.GetPeerNetworkFunc != nil { return am.GetPeerNetworkFunc(ctx, peerKey) } @@ -312,7 +316,7 @@ func (am *MockAccountManager) AddPeer( setupKey string, userId string, peer *nbpeer.Peer, -) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { +) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { if am.AddPeerFunc != nil { return am.AddPeerFunc(ctx, setupKey, userId, peer) } @@ -320,7 +324,7 @@ func (am *MockAccountManager) AddPeer( } // GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface -func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, groupName string) (*group.Group, error) { +func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, groupName string) (*types.Group, error) { if am.GetGroupFunc != nil { return am.GetGroupByNameFunc(ctx, accountID, groupName) } @@ -328,7 +332,7 @@ func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, gro } // SaveGroup mock implementation of SaveGroup from server.AccountManager interface -func (am *MockAccountManager) SaveGroup(ctx context.Context, accountID, userID string, group *group.Group) error { +func (am *MockAccountManager) SaveGroup(ctx context.Context, accountID, userID string, group *types.Group) error { if am.SaveGroupFunc != nil { return am.SaveGroupFunc(ctx, accountID, userID, group) } @@ -336,7 +340,7 @@ func (am *MockAccountManager) SaveGroup(ctx context.Context, accountID, userID s } // SaveGroups mock implementation of SaveGroups from server.AccountManager interface -func (am *MockAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*group.Group) error { +func (am *MockAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error { if am.SaveGroupsFunc != nil { return am.SaveGroupsFunc(ctx, accountID, userID, groups) } @@ -384,7 +388,7 @@ func (am *MockAccountManager) DeleteRule(ctx context.Context, accountID, ruleID, } // GetPolicy mock implementation 
of GetPolicy from server.AccountManager interface -func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error) { +func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) { if am.GetPolicyFunc != nil { return am.GetPolicyFunc(ctx, accountID, policyID, userID) } @@ -392,7 +396,7 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID } // SavePolicy mock implementation of SavePolicy from server.AccountManager interface -func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error) { +func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error) { if am.SavePolicyFunc != nil { return am.SavePolicyFunc(ctx, accountID, userID, policy) } @@ -408,7 +412,7 @@ func (am *MockAccountManager) DeletePolicy(ctx context.Context, accountID, polic } // ListPolicies mock implementation of ListPolicies from server.AccountManager interface -func (am *MockAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*server.Policy, error) { +func (am *MockAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error) { if am.ListPoliciesFunc != nil { return am.ListPoliciesFunc(ctx, accountID, userID) } @@ -424,14 +428,14 @@ func (am *MockAccountManager) UpdatePeerMeta(ctx context.Context, peerID string, } // GetUser mock implementation of GetUser from server.AccountManager interface -func (am *MockAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) { +func (am *MockAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) { if am.GetUserFunc != nil { return am.GetUserFunc(ctx, claims) } return nil, status.Errorf(codes.Unimplemented, "method GetUser is not implemented") } -func (am *MockAccountManager) ListUsers(ctx context.Context, accountID string) ([]*server.User, error) { +func (am *MockAccountManager) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) { if am.ListUsersFunc != nil { return am.ListUsersFunc(ctx, accountID) } @@ -487,7 +491,7 @@ func (am *MockAccountManager) ListRoutes(ctx context.Context, accountID, userID } // SaveSetupKey mocks SaveSetupKey of the AccountManager interface -func (am *MockAccountManager) SaveSetupKey(ctx context.Context, accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) { +func (am *MockAccountManager) SaveSetupKey(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error) { if am.SaveSetupKeyFunc != nil { return am.SaveSetupKeyFunc(ctx, accountID, key, userID) } @@ -496,7 +500,7 @@ func (am *MockAccountManager) SaveSetupKey(ctx context.Context, accountID string } // GetSetupKey mocks GetSetupKey of the AccountManager interface -func (am *MockAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) { +func (am *MockAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) { if am.GetSetupKeyFunc != nil { return am.GetSetupKeyFunc(ctx, accountID, userID, keyID) } @@ -505,7 +509,7 @@ func (am *MockAccountManager) GetSetupKey(ctx context.Context, accountID, userID } // ListSetupKeys mocks ListSetupKeys of the AccountManager interface -func (am 
*MockAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*server.SetupKey, error) { +func (am *MockAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) { if am.ListSetupKeysFunc != nil { return am.ListSetupKeysFunc(ctx, accountID, userID) } @@ -514,7 +518,7 @@ func (am *MockAccountManager) ListSetupKeys(ctx context.Context, accountID, user } // SaveUser mocks SaveUser of the AccountManager interface -func (am *MockAccountManager) SaveUser(ctx context.Context, accountID, userID string, user *server.User) (*server.UserInfo, error) { +func (am *MockAccountManager) SaveUser(ctx context.Context, accountID, userID string, user *types.User) (*types.UserInfo, error) { if am.SaveUserFunc != nil { return am.SaveUserFunc(ctx, accountID, userID, user) } @@ -522,7 +526,7 @@ func (am *MockAccountManager) SaveUser(ctx context.Context, accountID, userID st } // SaveOrAddUser mocks SaveOrAddUser of the AccountManager interface -func (am *MockAccountManager) SaveOrAddUser(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) { +func (am *MockAccountManager) SaveOrAddUser(ctx context.Context, accountID, userID string, user *types.User, addIfNotExists bool) (*types.UserInfo, error) { if am.SaveOrAddUserFunc != nil { return am.SaveOrAddUserFunc(ctx, accountID, userID, user, addIfNotExists) } @@ -530,7 +534,7 @@ func (am *MockAccountManager) SaveOrAddUser(ctx context.Context, accountID, user } // SaveOrAddUsers mocks SaveOrAddUsers of the AccountManager interface -func (am *MockAccountManager) SaveOrAddUsers(ctx context.Context, accountID, userID string, users []*server.User, addIfNotExists bool) ([]*server.UserInfo, error) { +func (am *MockAccountManager) SaveOrAddUsers(ctx context.Context, accountID, userID string, users []*types.User, addIfNotExists bool) ([]*types.UserInfo, error) { if am.SaveOrAddUsersFunc != nil { return am.SaveOrAddUsersFunc(ctx, accountID, userID, users, addIfNotExists) } @@ -601,7 +605,7 @@ func (am *MockAccountManager) ListNameServerGroups(ctx context.Context, accountI } // CreateUser mocks CreateUser of the AccountManager interface -func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID string, invite *server.UserInfo) (*server.UserInfo, error) { +func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID string, invite *types.UserInfo) (*types.UserInfo, error) { if am.CreateUserFunc != nil { return am.CreateUserFunc(ctx, accountID, userID, invite) } @@ -648,7 +652,7 @@ func (am *MockAccountManager) GetEvents(ctx context.Context, accountID, userID s } // GetDNSSettings mocks GetDNSSettings of the AccountManager interface -func (am *MockAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*server.DNSSettings, error) { +func (am *MockAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) { if am.GetDNSSettingsFunc != nil { return am.GetDNSSettingsFunc(ctx, accountID, userID) } @@ -656,7 +660,7 @@ func (am *MockAccountManager) GetDNSSettings(ctx context.Context, accountID stri } // SaveDNSSettings mocks SaveDNSSettings of the AccountManager interface -func (am *MockAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error { +func (am *MockAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave 
*types.DNSSettings) error { if am.SaveDNSSettingsFunc != nil { return am.SaveDNSSettingsFunc(ctx, accountID, userID, dnsSettingsToSave) } @@ -672,7 +676,7 @@ func (am *MockAccountManager) GetPeer(ctx context.Context, accountID, peerID, us } // UpdateAccountSettings mocks UpdateAccountSettings of the AccountManager interface -func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) { +func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) { if am.UpdateAccountSettingsFunc != nil { return am.UpdateAccountSettingsFunc(ctx, accountID, userID, newSettings) } @@ -680,7 +684,7 @@ func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, account } // LoginPeer mocks LoginPeer of the AccountManager interface -func (am *MockAccountManager) LoginPeer(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { +func (am *MockAccountManager) LoginPeer(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { if am.LoginPeerFunc != nil { return am.LoginPeerFunc(ctx, login) } @@ -688,7 +692,7 @@ func (am *MockAccountManager) LoginPeer(ctx context.Context, login server.PeerLo } // SyncPeer mocks SyncPeer of the AccountManager interface -func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, accountID string) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { +func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { if am.SyncPeerFunc != nil { return am.SyncPeerFunc(ctx, sync, accountID) } @@ -809,7 +813,7 @@ func (am *MockAccountManager) GetAccountIDForPeerKey(ctx context.Context, peerKe } // GetAccountByID mocks GetAccountByID of the AccountManager interface -func (am *MockAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*server.Account, error) { +func (am *MockAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) { if am.GetAccountByIDFunc != nil { return am.GetAccountByIDFunc(ctx, accountID, userID) } @@ -817,21 +821,21 @@ func (am *MockAccountManager) GetAccountByID(ctx context.Context, accountID stri } // GetUserByID mocks GetUserByID of the AccountManager interface -func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*server.User, error) { +func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) { if am.GetUserByIDFunc != nil { return am.GetUserByIDFunc(ctx, id) } return nil, status.Errorf(codes.Unimplemented, "method GetUserByID is not implemented") } -func (am *MockAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*server.Settings, error) { +func (am *MockAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) { if am.GetAccountSettingsFunc != nil { return am.GetAccountSettingsFunc(ctx, accountID, userID) } return nil, status.Errorf(codes.Unimplemented, "method GetAccountSettings is not implemented") } -func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string) (*server.Account, error) { +func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string) (*types.Account, 
error) { if am.GetAccountFunc != nil { return am.GetAccountFunc(ctx, accountID) } @@ -839,7 +843,7 @@ func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string) } // GetPeerGroups mocks GetPeerGroups of the AccountManager interface -func (am *MockAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*group.Group, error) { +func (am *MockAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*types.Group, error) { if am.GetPeerGroupsFunc != nil { return am.GetPeerGroupsFunc(ctx, accountID, peerID) } diff --git a/management/server/nameserver.go b/management/server/nameserver.go index e7a5387a1..1a01c7a89 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -11,15 +11,16 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" ) const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$` // GetNameServerGroup gets a nameserver group object from account and nameserver group IDs func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -32,7 +33,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account return nil, status.NewAdminPermissionError() } - return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID) + return am.Store.GetNameServerGroupByID(ctx, store.LockingStrengthShare, accountID, nsGroupID) } // CreateNameServerGroup creates and saves a new nameserver group @@ -40,7 +41,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -64,21 +65,21 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco var updateAccountPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validateNameServerGroup(ctx, transaction, accountID, newNSGroup); err != nil { return err } - updateAccountPeers, err = anyGroupHasPeers(ctx, transaction, accountID, newNSGroup.Groups) + updateAccountPeers, err = anyGroupHasPeersOrResources(ctx, transaction, accountID, newNSGroup.Groups) if err != nil { return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, newNSGroup) + return transaction.SaveNameServerGroup(ctx, store.LockingStrengthUpdate, newNSGroup) }) if err != nil { return nil, err @@ -87,7 +88,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco am.StoreEvent(ctx, userID, 
newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return newNSGroup.Copy(), nil @@ -102,7 +103,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun return status.Errorf(status.InvalidArgument, "nameserver group provided is nil") } - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -113,8 +114,8 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun var updateAccountPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupToSave.ID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthShare, accountID, nsGroupToSave.ID) if err != nil { return err } @@ -129,11 +130,11 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, nsGroupToSave) + return transaction.SaveNameServerGroup(ctx, store.LockingStrengthUpdate, nsGroupToSave) }) if err != nil { return err @@ -142,7 +143,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil @@ -153,7 +154,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -165,22 +166,22 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco var nsGroup *nbdns.NameServerGroup var updateAccountPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - nsGroup, err = transaction.GetNameServerGroupByID(ctx, LockingStrengthUpdate, accountID, nsGroupID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + nsGroup, err = transaction.GetNameServerGroupByID(ctx, store.LockingStrengthUpdate, accountID, nsGroupID) if err != nil { return err } - updateAccountPeers, err = anyGroupHasPeers(ctx, transaction, accountID, nsGroup.Groups) + updateAccountPeers, err = anyGroupHasPeersOrResources(ctx, transaction, accountID, nsGroup.Groups) if err != nil { return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.DeleteNameServerGroup(ctx, LockingStrengthUpdate, accountID, nsGroupID) + return transaction.DeleteNameServerGroup(ctx, store.LockingStrengthUpdate, accountID, nsGroupID) }) 
if err != nil { return err @@ -189,7 +190,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil @@ -197,7 +198,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco // ListNameServerGroups returns a list of nameserver groups from account func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -210,10 +211,10 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou return nil, status.NewAdminPermissionError() } - return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) + return am.Store.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID) } -func validateNameServerGroup(ctx context.Context, transaction Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error { +func validateNameServerGroup(ctx context.Context, transaction store.Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error { err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled) if err != nil { return err @@ -224,7 +225,7 @@ func validateNameServerGroup(ctx context.Context, transaction Store, accountID s return err } - nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) + nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID) if err != nil { return err } @@ -234,7 +235,7 @@ func validateNameServerGroup(ctx context.Context, transaction Store, accountID s return err } - groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, nameserverGroup.Groups) + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, nameserverGroup.Groups) if err != nil { return err } @@ -243,12 +244,12 @@ func validateNameServerGroup(ctx context.Context, transaction Store, accountID s } // areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers. 
-func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction Store, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) { +func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction store.Store, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) { if !newNSGroup.Enabled && !oldNSGroup.Enabled { return false, nil } - hasPeers, err := anyGroupHasPeers(ctx, transaction, newNSGroup.AccountID, newNSGroup.Groups) + hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, newNSGroup.AccountID, newNSGroup.Groups) if err != nil { return false, err } @@ -257,7 +258,7 @@ func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction Store return true, nil } - return anyGroupHasPeers(ctx, transaction, oldNSGroup.AccountID, oldNSGroup.Groups) + return anyGroupHasPeersOrResources(ctx, transaction, oldNSGroup.AccountID, oldNSGroup.Groups) } func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error { @@ -305,7 +306,7 @@ func validateNSList(list []nbdns.NameServer) error { return nil } -func validateGroups(list []string, groups map[string]*nbgroup.Group) error { +func validateGroups(list []string, groups map[string]*types.Group) error { if len(list) == 0 { return status.Errorf(status.InvalidArgument, "the list of group IDs should not be empty") } diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 846dbf023..0743db513 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -11,9 +11,10 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" ) const ( @@ -772,10 +773,10 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics) } -func createNSStore(t *testing.T) (Store, error) { +func createNSStore(t *testing.T) (store.Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", dataDir) if err != nil { return nil, err } @@ -784,7 +785,7 @@ func createNSStore(t *testing.T) (Store, error) { return store, nil } -func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) { +func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account, error) { t.Helper() peer1 := &nbpeer.Peer{ Key: nsGroupPeer1Key, @@ -842,12 +843,12 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup - newGroup1 := &nbgroup.Group{ + newGroup1 := &types.Group{ ID: group1ID, Name: group1ID, } - newGroup2 := &nbgroup.Group{ + newGroup2 := &types.Group{ ID: group2ID, Name: group2ID, } @@ -944,7 +945,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) { var newNameServerGroupA *nbdns.NameServerGroup var newNameServerGroupB *nbdns.NameServerGroup - err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err := 
manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", diff --git a/management/server/networks/manager.go b/management/server/networks/manager.go new file mode 100644 index 000000000..51205f1e9 --- /dev/null +++ b/management/server/networks/manager.go @@ -0,0 +1,214 @@ +package networks + +import ( + "context" + "fmt" + + "github.com/rs/xid" + + s "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/networks/resources" + "github.com/netbirdio/netbird/management/server/networks/routers" + "github.com/netbirdio/netbird/management/server/networks/types" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" +) + +type Manager interface { + GetAllNetworks(ctx context.Context, accountID, userID string) ([]*types.Network, error) + CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) + GetNetwork(ctx context.Context, accountID, userID, networkID string) (*types.Network, error) + UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) + DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error +} + +type managerImpl struct { + store store.Store + accountManager s.AccountManager + permissionsManager permissions.Manager + resourcesManager resources.Manager + routersManager routers.Manager +} + +type mockManager struct { +} + +func NewManager(store store.Store, permissionsManager permissions.Manager, resourceManager resources.Manager, routersManager routers.Manager, accountManager s.AccountManager) Manager { + return &managerImpl{ + store: store, + permissionsManager: permissionsManager, + resourcesManager: resourceManager, + routersManager: routersManager, + accountManager: accountManager, + } +} + +func (m *managerImpl) GetAllNetworks(ctx context.Context, accountID, userID string) ([]*types.Network, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + return m.store.GetAccountNetworks(ctx, store.LockingStrengthShare, accountID) +} + +func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, network.AccountID, userID, permissions.Networks, permissions.Write) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + network.ID = xid.New().String() + + unlock := m.store.AcquireWriteLockByUID(ctx, network.AccountID) + defer unlock() + + err = m.store.SaveNetwork(ctx, store.LockingStrengthUpdate, network) + if err != nil { + return nil, fmt.Errorf("failed to save network: %w", err) + } + + m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkCreated, network.EventMeta()) + + return network, nil +} + +func (m *managerImpl) GetNetwork(ctx context.Context, accountID, userID, networkID string) (*types.Network, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read) + if err != nil { + 
return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + return m.store.GetNetworkByID(ctx, store.LockingStrengthShare, accountID, networkID) +} + +func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, network.AccountID, userID, permissions.Networks, permissions.Write) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + unlock := m.store.AcquireWriteLockByUID(ctx, network.AccountID) + defer unlock() + + _, err = m.store.GetNetworkByID(ctx, store.LockingStrengthUpdate, network.AccountID, network.ID) + if err != nil { + return nil, fmt.Errorf("failed to get network: %w", err) + } + + m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkUpdated, network.EventMeta()) + + return network, m.store.SaveNetwork(ctx, store.LockingStrengthUpdate, network) +} + +func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Write) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !ok { + return status.NewPermissionDeniedError() + } + + network, err := m.store.GetNetworkByID(ctx, store.LockingStrengthUpdate, accountID, networkID) + if err != nil { + return fmt.Errorf("failed to get network: %w", err) + } + + unlock := m.store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + var eventsToStore []func() + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + resources, err := transaction.GetNetworkResourcesByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID) + if err != nil { + return fmt.Errorf("failed to get resources in network: %w", err) + } + + for _, resource := range resources { + event, err := m.resourcesManager.DeleteResourceInTransaction(ctx, transaction, accountID, userID, networkID, resource.ID) + if err != nil { + return fmt.Errorf("failed to delete resource: %w", err) + } + eventsToStore = append(eventsToStore, event...) 
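+ // DeleteResourceInTransaction returns the per-resource activity events as closures; they are collected here and only invoked after the transaction commits (see the event loop after ExecuteInTransaction below).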
+ } + + routers, err := transaction.GetNetworkRoutersByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID) + if err != nil { + return fmt.Errorf("failed to get routers in network: %w", err) + } + + for _, router := range routers { + event, err := m.routersManager.DeleteRouterInTransaction(ctx, transaction, accountID, userID, networkID, router.ID) + if err != nil { + return fmt.Errorf("failed to delete router: %w", err) + } + eventsToStore = append(eventsToStore, event) + } + + err = transaction.DeleteNetwork(ctx, store.LockingStrengthUpdate, accountID, networkID) + if err != nil { + return fmt.Errorf("failed to delete network: %w", err) + } + + err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID) + if err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + eventsToStore = append(eventsToStore, func() { + m.accountManager.StoreEvent(ctx, userID, networkID, accountID, activity.NetworkDeleted, network.EventMeta()) + }) + + return nil + }) + if err != nil { + return fmt.Errorf("failed to delete network: %w", err) + } + + for _, event := range eventsToStore { + event() + } + + go m.accountManager.UpdateAccountPeers(ctx, accountID) + + return nil +} + +func NewManagerMock() Manager { + return &mockManager{} +} + +func (m *mockManager) GetAllNetworks(ctx context.Context, accountID, userID string) ([]*types.Network, error) { + return []*types.Network{}, nil +} + +func (m *mockManager) CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) { + return network, nil +} + +func (m *mockManager) GetNetwork(ctx context.Context, accountID, userID, networkID string) (*types.Network, error) { + return &types.Network{}, nil +} + +func (m *mockManager) UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) { + return network, nil +} + +func (m *mockManager) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error { + return nil +} diff --git a/management/server/networks/manager_test.go b/management/server/networks/manager_test.go new file mode 100644 index 000000000..edd830c25 --- /dev/null +++ b/management/server/networks/manager_test.go @@ -0,0 +1,254 @@ +package networks + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/server/groups" + "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/networks/resources" + "github.com/netbirdio/netbird/management/server/networks/routers" + "github.com/netbirdio/netbird/management/server/networks/types" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/store" +) + +func Test_GetAllNetworksReturnsNetworks(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} + permissionsManager := permissions.NewManagerMock() + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + + networks, err := manager.GetAllNetworks(ctx, accountID, userID) + 
require.NoError(t, err) + require.Len(t, networks, 1) + require.Equal(t, "testNetworkId", networks[0].ID) +} + +func Test_GetAllNetworksReturnsPermissionDenied(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "invalidUser" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} + permissionsManager := permissions.NewManagerMock() + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + + networks, err := manager.GetAllNetworks(ctx, accountID, userID) + require.Error(t, err) + require.Nil(t, networks) +} + +func Test_GetNetworkReturnsNetwork(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + networkID := "testNetworkId" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} + permissionsManager := permissions.NewManagerMock() + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + + networks, err := manager.GetNetwork(ctx, accountID, userID, networkID) + require.NoError(t, err) + require.Equal(t, "testNetworkId", networks.ID) +} + +func Test_GetNetworkReturnsPermissionDenied(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "invalidUser" + networkID := "testNetworkId" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} + permissionsManager := permissions.NewManagerMock() + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + + network, err := manager.GetNetwork(ctx, accountID, userID, networkID) + require.Error(t, err) + require.Nil(t, network) +} + +func Test_CreateNetworkSuccessfully(t *testing.T) { + ctx := context.Background() + userID := "allowedUser" + network := &types.Network{ + AccountID: "testAccountId", + Name: "new-network", + } + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} + permissionsManager := permissions.NewManagerMock() + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + + createdNetwork, err := manager.CreateNetwork(ctx, userID, network) + require.NoError(t, err) + require.Equal(t, network.Name, createdNetwork.Name) +} + +func Test_CreateNetworkFailsWithPermissionDenied(t *testing.T) { + ctx := context.Background() 
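+ // Assumption (not shown in this diff): permissions.NewManagerMock authorizes "allowedUser" and rejects "invalidUser", which is what the permission-denied tests in this file rely on.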
+ userID := "invalidUser" + network := &types.Network{ + AccountID: "testAccountId", + Name: "new-network", + } + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} + permissionsManager := permissions.NewManagerMock() + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + + createdNetwork, err := manager.CreateNetwork(ctx, userID, network) + require.Error(t, err) + require.Nil(t, createdNetwork) +} + +func Test_DeleteNetworkSuccessfully(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + networkID := "testNetworkId" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} + permissionsManager := permissions.NewManagerMock() + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + + err = manager.DeleteNetwork(ctx, accountID, userID, networkID) + require.NoError(t, err) +} + +func Test_DeleteNetworkFailsWithPermissionDenied(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "invalidUser" + networkID := "testNetworkId" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} + permissionsManager := permissions.NewManagerMock() + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + + err = manager.DeleteNetwork(ctx, accountID, userID, networkID) + require.Error(t, err) +} + +func Test_UpdateNetworkSuccessfully(t *testing.T) { + ctx := context.Background() + userID := "allowedUser" + network := &types.Network{ + AccountID: "testAccountId", + ID: "testNetworkId", + Name: "new-network", + } + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} + permissionsManager := permissions.NewManagerMock() + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + + updatedNetwork, err := manager.UpdateNetwork(ctx, userID, network) + require.NoError(t, err) + require.Equal(t, network.Name, updatedNetwork.Name) +} + +func Test_UpdateNetworkFailsWithPermissionDenied(t *testing.T) { + ctx := context.Background() + userID := "invalidUser" + network := &types.Network{ + AccountID: "testAccountId", + ID: "testNetworkId", + Name: "new-network", + } + + s, cleanUp, err := 
store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + + am := mock_server.MockAccountManager{} + permissionsManager := permissions.NewManagerMock() + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + + updatedNetwork, err := manager.UpdateNetwork(ctx, userID, network) + require.Error(t, err) + require.Nil(t, updatedNetwork) +} diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go new file mode 100644 index 000000000..725d15496 --- /dev/null +++ b/management/server/networks/resources/manager.go @@ -0,0 +1,422 @@ +package resources + +import ( + "context" + "errors" + "fmt" + + s "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/groups" + "github.com/netbirdio/netbird/management/server/networks/resources/types" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" + nbtypes "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" +) + +type Manager interface { + GetAllResourcesInNetwork(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkResource, error) + GetAllResourcesInAccount(ctx context.Context, accountID, userID string) ([]*types.NetworkResource, error) + GetAllResourceIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) + CreateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) + GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error) + UpdateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) + DeleteResource(ctx context.Context, accountID, userID, networkID, resourceID string) error + DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, resourceID string) ([]func(), error) +} + +type managerImpl struct { + store store.Store + permissionsManager permissions.Manager + groupsManager groups.Manager + accountManager s.AccountManager +} + +type mockManager struct { +} + +func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager s.AccountManager) Manager { + return &managerImpl{ + store: store, + permissionsManager: permissionsManager, + groupsManager: groupsManager, + accountManager: accountManager, + } +} + +func (m *managerImpl) GetAllResourcesInNetwork(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkResource, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + return m.store.GetNetworkResourcesByNetID(ctx, store.LockingStrengthShare, accountID, networkID) +} + +func (m *managerImpl) GetAllResourcesInAccount(ctx context.Context, accountID, userID string) 
([]*types.NetworkResource, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + return m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthShare, accountID) +} + +func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + resources, err := m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, fmt.Errorf("failed to get network resources: %w", err) + } + + resourceMap := make(map[string][]string) + for _, resource := range resources { + resourceMap[resource.NetworkID] = append(resourceMap[resource.NetworkID], resource.ID) + } + + return resourceMap, nil +} + +func (m *managerImpl) CreateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, resource.AccountID, userID, permissions.Networks, permissions.Write) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + resource, err = types.NewNetworkResource(resource.AccountID, resource.NetworkID, resource.Name, resource.Description, resource.Address, resource.GroupIDs, resource.Enabled) + if err != nil { + return nil, fmt.Errorf("failed to create new network resource: %w", err) + } + + unlock := m.store.AcquireWriteLockByUID(ctx, resource.AccountID) + defer unlock() + + var eventsToStore []func() + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + _, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name) + if err == nil { + return errors.New("resource already exists") + } + + network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID) + if err != nil { + return fmt.Errorf("failed to get network: %w", err) + } + + err = transaction.SaveNetworkResource(ctx, store.LockingStrengthUpdate, resource) + if err != nil { + return fmt.Errorf("failed to save network resource: %w", err) + } + + event := func() { + m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceCreated, resource.EventMeta(network)) + } + eventsToStore = append(eventsToStore, event) + + res := nbtypes.Resource{ + ID: resource.ID, + Type: resource.Type.String(), + } + for _, groupID := range resource.GroupIDs { + event, err := m.groupsManager.AddResourceToGroupInTransaction(ctx, transaction, resource.AccountID, userID, groupID, &res) + if err != nil { + return fmt.Errorf("failed to add resource to group: %w", err) + } + eventsToStore = append(eventsToStore, event) + } + + err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, resource.AccountID) + if err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + return nil + }) + if err != nil { + return nil, fmt.Errorf("failed to create network resource: %w", err) + } + + for _, event := 
range eventsToStore { + event() + } + + go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID) + + return resource, nil +} + +func (m *managerImpl) GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resourceID) + if err != nil { + return nil, fmt.Errorf("failed to get network resource: %w", err) + } + + if resource.NetworkID != networkID { + return nil, errors.New("resource not part of network") + } + + return resource, nil +} + +func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, resource.AccountID, userID, permissions.Networks, permissions.Write) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + resourceType, domain, prefix, err := types.GetResourceType(resource.Address) + if err != nil { + return nil, fmt.Errorf("failed to get resource type: %w", err) + } + + resource.Type = resourceType + resource.Domain = domain + resource.Prefix = prefix + + unlock := m.store.AcquireWriteLockByUID(ctx, resource.AccountID) + defer unlock() + + var eventsToStore []func() + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID) + if err != nil { + return fmt.Errorf("failed to get network: %w", err) + } + + if network.ID != resource.NetworkID { + return status.NewResourceNotPartOfNetworkError(resource.ID, resource.NetworkID) + } + + _, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, resource.AccountID, resource.ID) + if err != nil { + return fmt.Errorf("failed to get network resource: %w", err) + } + + oldResource, err := transaction.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name) + if err == nil && oldResource.ID != resource.ID { + return errors.New("new resource name already exists") + } + + oldResource, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, resource.AccountID, resource.ID) + if err != nil { + return fmt.Errorf("failed to get network resource: %w", err) + } + + err = transaction.SaveNetworkResource(ctx, store.LockingStrengthUpdate, resource) + if err != nil { + return fmt.Errorf("failed to save network resource: %w", err) + } + + events, err := m.updateResourceGroups(ctx, transaction, userID, resource, oldResource) + if err != nil { + return fmt.Errorf("failed to update resource groups: %w", err) + } + + eventsToStore = append(eventsToStore, events...) 
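+ // The group-membership events collected above and the resource-updated event queued below are emitted only after ExecuteInTransaction returns without error.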
+ eventsToStore = append(eventsToStore, func() { + m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceUpdated, resource.EventMeta(network)) + }) + + err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, resource.AccountID) + if err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + return nil + }) + if err != nil { + return nil, fmt.Errorf("failed to update network resource: %w", err) + } + + for _, event := range eventsToStore { + event() + } + + go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID) + + return resource, nil +} + +func (m *managerImpl) updateResourceGroups(ctx context.Context, transaction store.Store, userID string, newResource, oldResource *types.NetworkResource) ([]func(), error) { + res := nbtypes.Resource{ + ID: newResource.ID, + Type: newResource.Type.String(), + } + + oldResourceGroups, err := m.groupsManager.GetResourceGroupsInTransaction(ctx, transaction, store.LockingStrengthUpdate, oldResource.AccountID, oldResource.ID) + if err != nil { + return nil, fmt.Errorf("failed to get resource groups: %w", err) + } + + oldGroupsIds := make([]string, 0) + for _, group := range oldResourceGroups { + oldGroupsIds = append(oldGroupsIds, group.ID) + } + + var eventsToStore []func() + groupsToAdd := util.Difference(newResource.GroupIDs, oldGroupsIds) + for _, groupID := range groupsToAdd { + events, err := m.groupsManager.AddResourceToGroupInTransaction(ctx, transaction, newResource.AccountID, userID, groupID, &res) + if err != nil { + return nil, fmt.Errorf("failed to add resource to group: %w", err) + } + eventsToStore = append(eventsToStore, events) + } + + groupsToRemove := util.Difference(oldGroupsIds, newResource.GroupIDs) + for _, groupID := range groupsToRemove { + events, err := m.groupsManager.RemoveResourceFromGroupInTransaction(ctx, transaction, newResource.AccountID, userID, groupID, res.ID) + if err != nil { + return nil, fmt.Errorf("failed to remove resource from group: %w", err) + } + eventsToStore = append(eventsToStore, events) + } + + return eventsToStore, nil +} + +func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, networkID, resourceID string) error { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Write) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !ok { + return status.NewPermissionDeniedError() + } + + unlock := m.store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + var events []func() + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + events, err = m.DeleteResourceInTransaction(ctx, transaction, accountID, userID, networkID, resourceID) + if err != nil { + return fmt.Errorf("failed to delete resource: %w", err) + } + + err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID) + if err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + return nil + }) + if err != nil { + return fmt.Errorf("failed to delete network resource: %w", err) + } + + for _, event := range events { + event() + } + + go m.accountManager.UpdateAccountPeers(ctx, accountID) + + return nil +} + +func (m *managerImpl) DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, resourceID string) ([]func(), error) { + resource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthUpdate, accountID, 
resourceID) + if err != nil { + return nil, fmt.Errorf("failed to get network resource: %w", err) + } + + network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, accountID, networkID) + if err != nil { + return nil, fmt.Errorf("failed to get network: %w", err) + } + + if resource.NetworkID != networkID { + return nil, errors.New("resource not part of network") + } + + groups, err := m.groupsManager.GetResourceGroupsInTransaction(ctx, transaction, store.LockingStrengthUpdate, accountID, resourceID) + if err != nil { + return nil, fmt.Errorf("failed to get resource groups: %w", err) + } + + var eventsToStore []func() + + for _, group := range groups { + event, err := m.groupsManager.RemoveResourceFromGroupInTransaction(ctx, transaction, accountID, userID, group.ID, resourceID) + if err != nil { + return nil, fmt.Errorf("failed to remove resource from group: %w", err) + } + eventsToStore = append(eventsToStore, event) + } + + err = transaction.DeleteNetworkResource(ctx, store.LockingStrengthUpdate, accountID, resourceID) + if err != nil { + return nil, fmt.Errorf("failed to delete network resource: %w", err) + } + + eventsToStore = append(eventsToStore, func() { + m.accountManager.StoreEvent(ctx, userID, resourceID, accountID, activity.NetworkResourceDeleted, resource.EventMeta(network)) + }) + + return eventsToStore, nil +} + +func NewManagerMock() Manager { + return &mockManager{} +} + +func (m *mockManager) GetAllResourcesInNetwork(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkResource, error) { + return []*types.NetworkResource{}, nil +} + +func (m *mockManager) GetAllResourcesInAccount(ctx context.Context, accountID, userID string) ([]*types.NetworkResource, error) { + return []*types.NetworkResource{}, nil +} + +func (m *mockManager) GetAllResourceIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) { + return map[string][]string{}, nil +} + +func (m *mockManager) CreateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) { + return &types.NetworkResource{}, nil +} + +func (m *mockManager) GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error) { + return &types.NetworkResource{}, nil +} + +func (m *mockManager) UpdateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) { + return &types.NetworkResource{}, nil +} + +func (m *mockManager) DeleteResource(ctx context.Context, accountID, userID, networkID, resourceID string) error { + return nil +} + +func (m *mockManager) DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, resourceID string) ([]func(), error) { + return []func(){}, nil +} diff --git a/management/server/networks/resources/manager_test.go b/management/server/networks/resources/manager_test.go new file mode 100644 index 000000000..993cd65df --- /dev/null +++ b/management/server/networks/resources/manager_test.go @@ -0,0 +1,411 @@ +package resources + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/server/groups" + "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/networks/resources/types" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/status" + 
"github.com/netbirdio/netbird/management/server/store" +) + +func Test_GetAllResourcesInNetworkReturnsResources(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + networkID := "testNetworkId" + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID) + require.NoError(t, err) + require.Len(t, resources, 2) +} + +func Test_GetAllResourcesInNetworkReturnsPermissionDenied(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "invalidUser" + networkID := "testNetworkId" + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID) + require.Error(t, err) + require.Equal(t, status.NewPermissionDeniedError(), err) + require.Nil(t, resources) +} +func Test_GetAllResourcesInAccountReturnsResources(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID) + require.NoError(t, err) + require.Len(t, resources, 2) +} + +func Test_GetAllResourcesInAccountReturnsPermissionDenied(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "invalidUser" + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID) + require.Error(t, err) + require.Equal(t, status.NewPermissionDeniedError(), err) + require.Nil(t, resources) +} + +func Test_GetResourceInNetworkReturnsResources(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + networkID := "testNetworkId" + resourceID := "testResourceId" + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, 
permissionsManager, groupsManager, &am) + + resource, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID) + require.NoError(t, err) + require.Equal(t, resourceID, resource.ID) +} + +func Test_GetResourceInNetworkReturnsPermissionDenied(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "invalidUser" + networkID := "testNetworkId" + resourceID := "testResourceId" + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + resources, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID) + require.Error(t, err) + require.Equal(t, status.NewPermissionDeniedError(), err) + require.Nil(t, resources) +} + +func Test_CreateResourceSuccessfully(t *testing.T) { + ctx := context.Background() + userID := "allowedUser" + resource := &types.NetworkResource{ + AccountID: "testAccountId", + NetworkID: "testNetworkId", + Name: "newResourceId", + Description: "description", + Address: "192.168.1.1", + } + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + createdResource, err := manager.CreateResource(ctx, userID, resource) + require.NoError(t, err) + require.Equal(t, resource.Name, createdResource.Name) +} + +func Test_CreateResourceFailsWithPermissionDenied(t *testing.T) { + ctx := context.Background() + userID := "invalidUser" + resource := &types.NetworkResource{ + AccountID: "testAccountId", + NetworkID: "testNetworkId", + Name: "testResourceId", + Description: "description", + Address: "192.168.1.1", + } + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + createdResource, err := manager.CreateResource(ctx, userID, resource) + require.Error(t, err) + require.Equal(t, status.NewPermissionDeniedError(), err) + require.Nil(t, createdResource) +} + +func Test_CreateResourceFailsWithInvalidAddress(t *testing.T) { + ctx := context.Background() + userID := "allowedUser" + resource := &types.NetworkResource{ + AccountID: "testAccountId", + NetworkID: "testNetworkId", + Name: "testResourceId", + Description: "description", + Address: "invalid-address", + } + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + createdResource, err := manager.CreateResource(ctx, userID, resource) + require.Error(t, err) + require.Nil(t, 
createdResource) +} + +func Test_CreateResourceFailsWithUsedName(t *testing.T) { + ctx := context.Background() + userID := "allowedUser" + resource := &types.NetworkResource{ + AccountID: "testAccountId", + NetworkID: "testNetworkId", + Name: "testResourceId", + Description: "description", + Address: "invalid-address", + } + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + createdResource, err := manager.CreateResource(ctx, userID, resource) + require.Error(t, err) + require.Nil(t, createdResource) +} + +func Test_UpdateResourceSuccessfully(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + networkID := "testNetworkId" + resourceID := "testResourceId" + resource := &types.NetworkResource{ + AccountID: accountID, + NetworkID: networkID, + Name: "someNewName", + ID: resourceID, + Description: "new-description", + Address: "1.2.3.0/24", + } + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + updatedResource, err := manager.UpdateResource(ctx, userID, resource) + require.NoError(t, err) + require.NotNil(t, updatedResource) + require.Equal(t, "new-description", updatedResource.Description) + require.Equal(t, "1.2.3.0/24", updatedResource.Address) + require.Equal(t, types.NetworkResourceType("subnet"), updatedResource.Type) +} + +func Test_UpdateResourceFailsWithResourceNotFound(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + networkID := "testNetworkId" + resourceID := "otherResourceId" + resource := &types.NetworkResource{ + AccountID: accountID, + NetworkID: networkID, + Name: resourceID, + Description: "new-description", + Address: "1.2.3.0/24", + } + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + updatedResource, err := manager.UpdateResource(ctx, userID, resource) + require.Error(t, err) + require.Nil(t, updatedResource) +} + +func Test_UpdateResourceFailsWithNameInUse(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + networkID := "testNetworkId" + resourceID := "testResourceId" + resource := &types.NetworkResource{ + AccountID: accountID, + NetworkID: networkID, + ID: resourceID, + Name: "used-name", + Description: "new-description", + Address: "1.2.3.0/24", + } + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + 
groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + updatedResource, err := manager.UpdateResource(ctx, userID, resource) + require.Error(t, err) + require.Nil(t, updatedResource) +} + +func Test_UpdateResourceFailsWithPermissionDenied(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "invalidUser" + networkID := "testNetworkId" + resourceID := "testResourceId" + resource := &types.NetworkResource{ + AccountID: accountID, + NetworkID: networkID, + Name: resourceID, + Description: "new-description", + Address: "1.2.3.0/24", + } + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + updatedResource, err := manager.UpdateResource(ctx, userID, resource) + require.Error(t, err) + require.Nil(t, updatedResource) +} + +func Test_DeleteResourceSuccessfully(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + networkID := "testNetworkId" + resourceID := "testResourceId" + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID) + require.NoError(t, err) +} + +func Test_DeleteResourceFailsWithPermissionDenied(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "invalidUser" + networkID := "testNetworkId" + resourceID := "testResourceId" + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID) + require.Error(t, err) +} diff --git a/management/server/networks/resources/types/resource.go b/management/server/networks/resources/types/resource.go new file mode 100644 index 000000000..0df6727c3 --- /dev/null +++ b/management/server/networks/resources/types/resource.go @@ -0,0 +1,175 @@ +package types + +import ( + "errors" + "fmt" + "net/netip" + "regexp" + + "github.com/rs/xid" + + nbDomain "github.com/netbirdio/netbird/management/domain" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/route" + + "github.com/netbirdio/netbird/management/server/http/api" +) + +type NetworkResourceType string + +const ( + host NetworkResourceType = "host" + subnet NetworkResourceType = "subnet" + domain NetworkResourceType = "domain" +) + +func (p NetworkResourceType) String() string { + return 
string(p) +} + +type NetworkResource struct { + ID string `gorm:"index"` + NetworkID string `gorm:"index"` + AccountID string `gorm:"index"` + Name string + Description string + Type NetworkResourceType + Address string `gorm:"-"` + GroupIDs []string `gorm:"-"` + Domain string + Prefix netip.Prefix `gorm:"serializer:json"` + Enabled bool +} + +func NewNetworkResource(accountID, networkID, name, description, address string, groupIDs []string, enabled bool) (*NetworkResource, error) { + resourceType, domain, prefix, err := GetResourceType(address) + if err != nil { + return nil, fmt.Errorf("invalid address: %w", err) + } + + return &NetworkResource{ + ID: xid.New().String(), + AccountID: accountID, + NetworkID: networkID, + Name: name, + Description: description, + Type: resourceType, + Address: address, + Domain: domain, + Prefix: prefix, + GroupIDs: groupIDs, + Enabled: enabled, + }, nil +} + +func (n *NetworkResource) ToAPIResponse(groups []api.GroupMinimum) *api.NetworkResource { + addr := n.Prefix.String() + if n.Type == domain { + addr = n.Domain + } + + return &api.NetworkResource{ + Id: n.ID, + Name: n.Name, + Description: &n.Description, + Type: api.NetworkResourceType(n.Type.String()), + Address: addr, + Groups: groups, + Enabled: n.Enabled, + } +} + +func (n *NetworkResource) FromAPIRequest(req *api.NetworkResourceRequest) { + n.Name = req.Name + + if req.Description != nil { + n.Description = *req.Description + } + n.Address = req.Address + n.GroupIDs = req.Groups + n.Enabled = req.Enabled +} + +func (n *NetworkResource) Copy() *NetworkResource { + return &NetworkResource{ + ID: n.ID, + AccountID: n.AccountID, + NetworkID: n.NetworkID, + Name: n.Name, + Description: n.Description, + Type: n.Type, + Address: n.Address, + Domain: n.Domain, + Prefix: n.Prefix, + GroupIDs: n.GroupIDs, + Enabled: n.Enabled, + } +} + +func (n *NetworkResource) ToRoute(peer *nbpeer.Peer, router *routerTypes.NetworkRouter) *route.Route { + r := &route.Route{ + ID: route.ID(fmt.Sprintf("%s:%s", n.ID, peer.ID)), + AccountID: n.AccountID, + KeepRoute: true, + NetID: route.NetID(n.Name), + Description: n.Description, + Peer: peer.Key, + PeerID: peer.ID, + PeerGroups: nil, + Masquerade: router.Masquerade, + Metric: router.Metric, + Enabled: n.Enabled, + Groups: nil, + AccessControlGroups: nil, + } + + if n.Type == host || n.Type == subnet { + r.Network = n.Prefix + + r.NetworkType = route.IPv4Network + if n.Prefix.Addr().Is6() { + r.NetworkType = route.IPv6Network + } + } + + if n.Type == domain { + domainList, err := nbDomain.FromStringList([]string{n.Domain}) + if err != nil { + return nil + } + r.Domains = domainList + r.NetworkType = route.DomainNetwork + + // add default placeholder for domain network + r.Network = netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32) + } + + return r +} + +func (n *NetworkResource) EventMeta(network *networkTypes.Network) map[string]any { + return map[string]any{"name": n.Name, "type": n.Type, "network_name": network.Name, "network_id": network.ID} +} + +// GetResourceType returns the type of the resource based on the address +func GetResourceType(address string) (NetworkResourceType, string, netip.Prefix, error) { + if prefix, err := netip.ParsePrefix(address); err == nil { + if prefix.Bits() == 32 || prefix.Bits() == 128 { + return host, "", prefix, nil + } + return subnet, "", prefix, nil + } + + if ip, err := netip.ParseAddr(address); err == nil { + return host, "", netip.PrefixFrom(ip, ip.BitLen()), nil + } + + domainRegex := 
regexp.MustCompile(`^(\*\.)?([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}$`) + if domainRegex.MatchString(address) { + return domain, address, netip.Prefix{}, nil + } + + return "", "", netip.Prefix{}, errors.New("not a valid host, subnet, or domain") +} diff --git a/management/server/networks/resources/types/resource_test.go b/management/server/networks/resources/types/resource_test.go new file mode 100644 index 000000000..6af384cce --- /dev/null +++ b/management/server/networks/resources/types/resource_test.go @@ -0,0 +1,53 @@ +package types + +import ( + "net/netip" + "testing" +) + +func TestGetResourceType(t *testing.T) { + tests := []struct { + input string + expectedType NetworkResourceType + expectedErr bool + expectedDomain string + expectedPrefix netip.Prefix + }{ + // Valid host IPs + {"1.1.1.1", host, false, "", netip.MustParsePrefix("1.1.1.1/32")}, + {"1.1.1.1/32", host, false, "", netip.MustParsePrefix("1.1.1.1/32")}, + // Valid subnets + {"192.168.1.0/24", subnet, false, "", netip.MustParsePrefix("192.168.1.0/24")}, + {"10.0.0.0/16", subnet, false, "", netip.MustParsePrefix("10.0.0.0/16")}, + // Valid domains + {"example.com", domain, false, "example.com", netip.Prefix{}}, + {"*.example.com", domain, false, "*.example.com", netip.Prefix{}}, + {"sub.example.com", domain, false, "sub.example.com", netip.Prefix{}}, + // Invalid inputs + {"invalid", "", true, "", netip.Prefix{}}, + {"1.1.1.1/abc", "", true, "", netip.Prefix{}}, + {"1234", "", true, "", netip.Prefix{}}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result, domain, prefix, err := GetResourceType(tt.input) + + if result != tt.expectedType { + t.Errorf("Expected type %v, got %v", tt.expectedType, result) + } + + if tt.expectedErr && err == nil { + t.Errorf("Expected error, got nil") + } + + if prefix != tt.expectedPrefix { + t.Errorf("Expected address %v, got %v", tt.expectedPrefix, prefix) + } + + if domain != tt.expectedDomain { + t.Errorf("Expected domain %v, got %v", tt.expectedDomain, domain) + } + }) + } +} diff --git a/management/server/networks/routers/manager.go b/management/server/networks/routers/manager.go new file mode 100644 index 000000000..3b32810a2 --- /dev/null +++ b/management/server/networks/routers/manager.go @@ -0,0 +1,289 @@ +package routers + +import ( + "context" + "errors" + "fmt" + + "github.com/rs/xid" + + s "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" +) + +type Manager interface { + GetAllRoutersInNetwork(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkRouter, error) + GetAllRoutersInAccount(ctx context.Context, accountID, userID string) (map[string][]*types.NetworkRouter, error) + CreateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) + GetRouter(ctx context.Context, accountID, userID, networkID, routerID string) (*types.NetworkRouter, error) + UpdateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) + DeleteRouter(ctx context.Context, accountID, userID, networkID, routerID string) error + DeleteRouterInTransaction(ctx context.Context, transaction 
store.Store, accountID, userID, networkID, routerID string) (func(), error) +} + +type managerImpl struct { + store store.Store + permissionsManager permissions.Manager + accountManager s.AccountManager +} + +type mockManager struct { +} + +func NewManager(store store.Store, permissionsManager permissions.Manager, accountManager s.AccountManager) Manager { + return &managerImpl{ + store: store, + permissionsManager: permissionsManager, + accountManager: accountManager, + } +} + +func (m *managerImpl) GetAllRoutersInNetwork(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkRouter, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + return m.store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthShare, accountID, networkID) +} + +func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, userID string) (map[string][]*types.NetworkRouter, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + routers, err := m.store.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, fmt.Errorf("failed to get network routers: %w", err) + } + + routersMap := make(map[string][]*types.NetworkRouter) + for _, router := range routers { + routersMap[router.NetworkID] = append(routersMap[router.NetworkID], router) + } + + return routersMap, nil +} + +func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, router.AccountID, userID, permissions.Networks, permissions.Write) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + unlock := m.store.AcquireWriteLockByUID(ctx, router.AccountID) + defer unlock() + + var network *networkTypes.Network + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthShare, router.AccountID, router.NetworkID) + if err != nil { + return fmt.Errorf("failed to get network: %w", err) + } + + if network.ID != router.NetworkID { + return status.NewNetworkNotFoundError(router.NetworkID) + } + + router.ID = xid.New().String() + + err = transaction.SaveNetworkRouter(ctx, store.LockingStrengthUpdate, router) + if err != nil { + return fmt.Errorf("failed to create network router: %w", err) + } + + err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, router.AccountID) + if err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + return nil + }) + if err != nil { + return nil, err + } + + m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterCreated, router.EventMeta(network)) + + go m.accountManager.UpdateAccountPeers(ctx, router.AccountID) + + return router, nil +} + +func (m *managerImpl) GetRouter(ctx context.Context, accountID, userID, networkID, routerID string) (*types.NetworkRouter, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, 
accountID, userID, permissions.Networks, permissions.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + router, err := m.store.GetNetworkRouterByID(ctx, store.LockingStrengthShare, accountID, routerID) + if err != nil { + return nil, fmt.Errorf("failed to get network router: %w", err) + } + + if router.NetworkID != networkID { + return nil, errors.New("router not part of network") + } + + return router, nil +} + +func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, router.AccountID, userID, permissions.Networks, permissions.Write) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + unlock := m.store.AcquireWriteLockByUID(ctx, router.AccountID) + defer unlock() + + var network *networkTypes.Network + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthShare, router.AccountID, router.NetworkID) + if err != nil { + return fmt.Errorf("failed to get network: %w", err) + } + + if network.ID != router.NetworkID { + return status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID) + } + + err = transaction.SaveNetworkRouter(ctx, store.LockingStrengthUpdate, router) + if err != nil { + return fmt.Errorf("failed to update network router: %w", err) + } + + err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, router.AccountID) + if err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + return nil + }) + if err != nil { + return nil, err + } + + m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterUpdated, router.EventMeta(network)) + + go m.accountManager.UpdateAccountPeers(ctx, router.AccountID) + + return router, nil +} + +func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, networkID, routerID string) error { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Write) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !ok { + return status.NewPermissionDeniedError() + } + + unlock := m.store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + var event func() + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + event, err = m.DeleteRouterInTransaction(ctx, transaction, accountID, userID, networkID, routerID) + if err != nil { + return fmt.Errorf("failed to delete network router: %w", err) + } + + err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID) + if err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + return nil + }) + if err != nil { + return err + } + + event() + + go m.accountManager.UpdateAccountPeers(ctx, accountID) + + return nil +} + +func (m *managerImpl) DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (func(), error) { + network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthShare, accountID, networkID) + if err != nil { + return nil, fmt.Errorf("failed to get network: %w", err) + } + + router, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthUpdate, 
accountID, routerID) + if err != nil { + return nil, fmt.Errorf("failed to get network router: %w", err) + } + + if router.NetworkID != networkID { + return nil, status.NewRouterNotPartOfNetworkError(routerID, networkID) + } + + err = transaction.DeleteNetworkRouter(ctx, store.LockingStrengthUpdate, accountID, routerID) + if err != nil { + return nil, fmt.Errorf("failed to delete network router: %w", err) + } + + event := func() { + m.accountManager.StoreEvent(ctx, userID, routerID, accountID, activity.NetworkRouterDeleted, router.EventMeta(network)) + } + + return event, nil +} + +func NewManagerMock() Manager { + return &mockManager{} +} + +func (m *mockManager) GetAllRoutersInNetwork(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkRouter, error) { + return []*types.NetworkRouter{}, nil +} + +func (m *mockManager) GetAllRoutersInAccount(ctx context.Context, accountID, userID string) (map[string][]*types.NetworkRouter, error) { + return map[string][]*types.NetworkRouter{}, nil +} + +func (m *mockManager) CreateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) { + return router, nil +} + +func (m *mockManager) GetRouter(ctx context.Context, accountID, userID, networkID, routerID string) (*types.NetworkRouter, error) { + return &types.NetworkRouter{}, nil +} + +func (m *mockManager) UpdateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) { + return router, nil +} + +func (m *mockManager) DeleteRouter(ctx context.Context, accountID, userID, networkID, routerID string) error { + return nil +} + +func (m *mockManager) DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (func(), error) { + return func() {}, nil +} diff --git a/management/server/networks/routers/manager_test.go b/management/server/networks/routers/manager_test.go new file mode 100644 index 000000000..47f5ad7e3 --- /dev/null +++ b/management/server/networks/routers/manager_test.go @@ -0,0 +1,234 @@ +package routers + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/networks/routers/types" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" +) + +func Test_GetAllRoutersInNetworkReturnsRouters(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + networkID := "testNetworkId" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) + + routers, err := manager.GetAllRoutersInNetwork(ctx, accountID, userID, networkID) + require.NoError(t, err) + require.Len(t, routers, 1) + require.Equal(t, "testRouterId", routers[0].ID) +} + +func Test_GetAllRoutersInNetworkReturnsPermissionDenied(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "invalidUser" + networkID := "testNetworkId" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + 
t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) + + routers, err := manager.GetAllRoutersInNetwork(ctx, accountID, userID, networkID) + require.Error(t, err) + require.Equal(t, status.NewPermissionDeniedError(), err) + require.Nil(t, routers) +} + +func Test_GetRouterReturnsRouter(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + networkID := "testNetworkId" + resourceID := "testRouterId" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) + + router, err := manager.GetRouter(ctx, accountID, userID, networkID, resourceID) + require.NoError(t, err) + require.Equal(t, "testRouterId", router.ID) +} + +func Test_GetRouterReturnsPermissionDenied(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "invalidUser" + networkID := "testNetworkId" + resourceID := "testRouterId" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) + + router, err := manager.GetRouter(ctx, accountID, userID, networkID, resourceID) + require.Error(t, err) + require.Equal(t, status.NewPermissionDeniedError(), err) + require.Nil(t, router) +} + +func Test_CreateRouterSuccessfully(t *testing.T) { + ctx := context.Background() + userID := "allowedUser" + router, err := types.NewNetworkRouter("testAccountId", "testNetworkId", "testPeerId", []string{}, false, 9999, true) + if err != nil { + require.NoError(t, err) + } + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) + + createdRouter, err := manager.CreateRouter(ctx, userID, router) + require.NoError(t, err) + require.NotEqual(t, "", router.ID) + require.Equal(t, router.NetworkID, createdRouter.NetworkID) + require.Equal(t, router.Peer, createdRouter.Peer) + require.Equal(t, router.Metric, createdRouter.Metric) + require.Equal(t, router.Masquerade, createdRouter.Masquerade) +} + +func Test_CreateRouterFailsWithPermissionDenied(t *testing.T) { + ctx := context.Background() + userID := "invalidUser" + router, err := types.NewNetworkRouter("testAccountId", "testNetworkId", "testPeerId", []string{}, false, 9999, true) + if err != nil { + require.NoError(t, err) + } + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) + + createdRouter, err := manager.CreateRouter(ctx, userID, router) + require.Error(t, err) + require.Equal(t, status.NewPermissionDeniedError(), err) + require.Nil(t, createdRouter) +} + +func 
Test_DeleteRouterSuccessfully(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + networkID := "testNetworkId" + routerID := "testRouterId" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) + + err = manager.DeleteRouter(ctx, accountID, userID, networkID, routerID) + require.NoError(t, err) +} + +func Test_DeleteRouterFailsWithPermissionDenied(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "invalidUser" + networkID := "testNetworkId" + routerID := "testRouterId" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) + + err = manager.DeleteRouter(ctx, accountID, userID, networkID, routerID) + require.Error(t, err) + require.Equal(t, status.NewPermissionDeniedError(), err) +} + +func Test_UpdateRouterSuccessfully(t *testing.T) { + ctx := context.Background() + userID := "allowedUser" + router, err := types.NewNetworkRouter("testAccountId", "testNetworkId", "testPeerId", []string{}, false, 1, true) + if err != nil { + require.NoError(t, err) + } + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) + + updatedRouter, err := manager.UpdateRouter(ctx, userID, router) + require.NoError(t, err) + require.Equal(t, router.Metric, updatedRouter.Metric) +} + +func Test_UpdateRouterFailsWithPermissionDenied(t *testing.T) { + ctx := context.Background() + userID := "invalidUser" + router, err := types.NewNetworkRouter("testAccountId", "testNetworkId", "testPeerId", []string{}, false, 1, true) + if err != nil { + require.NoError(t, err) + } + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) + + updatedRouter, err := manager.UpdateRouter(ctx, userID, router) + require.Error(t, err) + require.Equal(t, status.NewPermissionDeniedError(), err) + require.Nil(t, updatedRouter) +} diff --git a/management/server/networks/routers/types/router.go b/management/server/networks/routers/types/router.go new file mode 100644 index 000000000..5158ebb12 --- /dev/null +++ b/management/server/networks/routers/types/router.go @@ -0,0 +1,80 @@ +package types + +import ( + "errors" + + "github.com/rs/xid" + + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/networks/types" +) + +type NetworkRouter struct { + ID string `gorm:"index"` + NetworkID string `gorm:"index"` + AccountID string `gorm:"index"` + Peer string + PeerGroups []string `gorm:"serializer:json"` + Masquerade bool + Metric int + Enabled bool +} + +func 
NewNetworkRouter(accountID string, networkID string, peer string, peerGroups []string, masquerade bool, metric int, enabled bool) (*NetworkRouter, error) { + if peer != "" && len(peerGroups) > 0 { + return nil, errors.New("peer and peerGroups cannot be set at the same time") + } + + return &NetworkRouter{ + ID: xid.New().String(), + AccountID: accountID, + NetworkID: networkID, + Peer: peer, + PeerGroups: peerGroups, + Masquerade: masquerade, + Metric: metric, + Enabled: enabled, + }, nil +} + +func (n *NetworkRouter) ToAPIResponse() *api.NetworkRouter { + return &api.NetworkRouter{ + Id: n.ID, + Peer: &n.Peer, + PeerGroups: &n.PeerGroups, + Masquerade: n.Masquerade, + Metric: n.Metric, + Enabled: n.Enabled, + } +} + +func (n *NetworkRouter) FromAPIRequest(req *api.NetworkRouterRequest) { + if req.Peer != nil { + n.Peer = *req.Peer + } + + if req.PeerGroups != nil { + n.PeerGroups = *req.PeerGroups + } + + n.Masquerade = req.Masquerade + n.Metric = req.Metric + n.Enabled = req.Enabled +} + +func (n *NetworkRouter) Copy() *NetworkRouter { + return &NetworkRouter{ + ID: n.ID, + NetworkID: n.NetworkID, + AccountID: n.AccountID, + Peer: n.Peer, + PeerGroups: n.PeerGroups, + Masquerade: n.Masquerade, + Metric: n.Metric, + Enabled: n.Enabled, + } +} + +func (n *NetworkRouter) EventMeta(network *types.Network) map[string]any { + return map[string]any{"network_name": network.Name, "network_id": network.ID, "peer": n.Peer, "peer_groups": n.PeerGroups} +} diff --git a/management/server/networks/routers/types/router_test.go b/management/server/networks/routers/types/router_test.go new file mode 100644 index 000000000..5801e3bfa --- /dev/null +++ b/management/server/networks/routers/types/router_test.go @@ -0,0 +1,109 @@ +package types + +import "testing" + +func TestNewNetworkRouter(t *testing.T) { + tests := []struct { + name string + accountID string + networkID string + peer string + peerGroups []string + masquerade bool + metric int + enabled bool + expectedError bool + }{ + // Valid cases + { + name: "Valid with peer only", + networkID: "network-1", + accountID: "account-1", + peer: "peer-1", + peerGroups: nil, + masquerade: true, + metric: 100, + enabled: true, + expectedError: false, + }, + { + name: "Valid with peerGroups only", + networkID: "network-2", + accountID: "account-2", + peer: "", + peerGroups: []string{"group-1", "group-2"}, + masquerade: false, + metric: 200, + enabled: false, + expectedError: false, + }, + { + name: "Valid with no peer or peerGroups", + networkID: "network-3", + accountID: "account-3", + peer: "", + peerGroups: nil, + masquerade: true, + metric: 300, + enabled: true, + expectedError: false, + }, + + // Invalid cases + { + name: "Invalid with both peer and peerGroups", + networkID: "network-4", + accountID: "account-4", + peer: "peer-2", + peerGroups: []string{"group-3"}, + masquerade: false, + metric: 400, + enabled: false, + expectedError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + router, err := NewNetworkRouter(tt.accountID, tt.networkID, tt.peer, tt.peerGroups, tt.masquerade, tt.metric, tt.enabled) + + if tt.expectedError && err == nil { + t.Fatalf("Expected an error, got nil") + } + + if tt.expectedError == false { + if router == nil { + t.Fatalf("Expected a NetworkRouter object, got nil") + } + + if router.AccountID != tt.accountID { + t.Errorf("Expected AccountID %s, got %s", tt.accountID, router.AccountID) + } + + if router.NetworkID != tt.networkID { + t.Errorf("Expected NetworkID %s, got %s", tt.networkID, 
router.NetworkID) + } + + if router.Peer != tt.peer { + t.Errorf("Expected Peer %s, got %s", tt.peer, router.Peer) + } + + if len(router.PeerGroups) != len(tt.peerGroups) { + t.Errorf("Expected PeerGroups %v, got %v", tt.peerGroups, router.PeerGroups) + } + + if router.Masquerade != tt.masquerade { + t.Errorf("Expected Masquerade %v, got %v", tt.masquerade, router.Masquerade) + } + + if router.Metric != tt.metric { + t.Errorf("Expected Metric %d, got %d", tt.metric, router.Metric) + } + + if router.Enabled != tt.enabled { + t.Errorf("Expected Enabled %v, got %v", tt.enabled, router.Enabled) + } + } + }) + } +} diff --git a/management/server/networks/types/network.go b/management/server/networks/types/network.go new file mode 100644 index 000000000..a4ba7b821 --- /dev/null +++ b/management/server/networks/types/network.go @@ -0,0 +1,56 @@ +package types + +import ( + "github.com/rs/xid" + + "github.com/netbirdio/netbird/management/server/http/api" +) + +type Network struct { + ID string `gorm:"index"` + AccountID string `gorm:"index"` + Name string + Description string +} + +func NewNetwork(accountId, name, description string) *Network { + return &Network{ + ID: xid.New().String(), + AccountID: accountId, + Name: name, + Description: description, + } +} + +func (n *Network) ToAPIResponse(routerIDs []string, resourceIDs []string, routingPeersCount int, policyIDs []string) *api.Network { + return &api.Network{ + Id: n.ID, + Name: n.Name, + Description: &n.Description, + Routers: routerIDs, + Resources: resourceIDs, + RoutingPeersCount: routingPeersCount, + Policies: policyIDs, + } +} + +func (n *Network) FromAPIRequest(req *api.NetworkRequest) { + n.Name = req.Name + if req.Description != nil { + n.Description = *req.Description + } +} + +// Copy returns a copy of a posture checks. +func (n *Network) Copy() *Network { + return &Network{ + ID: n.ID, + AccountID: n.AccountID, + Name: n.Name, + Description: n.Description, + } +} + +func (n *Network) EventMeta() map[string]any { + return map[string]any{"name": n.Name} +} diff --git a/management/server/peer.go b/management/server/peer.go index 8e368ec4e..3ddb0a22d 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -12,13 +12,14 @@ import ( "time" "github.com/netbirdio/netbird/management/server/geolocation" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/rs/xid" log "github.com/sirupsen/logrus" "golang.org/x/exp/maps" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" @@ -56,7 +57,7 @@ type PeerLogin struct { // GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if // the current user is not an admin. 
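As a reading aid for the Network type added above in management/server/networks/types/network.go, a minimal illustrative sketch (not part of the patch) of the intended lifecycle: construct a network, apply an API request, and render an API response. The function name and literal values are placeholders; only NewNetwork, FromAPIRequest, and ToAPIResponse come from this change.

package types

import (
	"fmt"

	"github.com/netbirdio/netbird/management/server/http/api"
)

// sketchNetworkLifecycle is illustrative only: create, update from an API request,
// then build the API representation with the related object IDs.
func sketchNetworkLifecycle() {
	n := NewNetwork("account-1", "office", "office network")

	desc := "updated description"
	n.FromAPIRequest(&api.NetworkRequest{Name: "office-eu", Description: &desc})

	resp := n.ToAPIResponse([]string{"router-1"}, []string{"resource-1"}, 1, []string{"policy-1"})
	fmt.Println(resp.Id, resp.Name)
}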
func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -65,7 +66,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID return nil, status.NewUserNotPartOfAccountError() } - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, err } @@ -74,7 +75,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID return []*nbpeer.Peer{}, nil } - accountPeers, err := am.Store.GetAccountPeers(ctx, LockingStrengthShare, accountID) + accountPeers, err := am.Store.GetAccountPeers(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, err } @@ -100,14 +101,14 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID return nil, err } - approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID) + approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(accountID, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) if err != nil { return nil, err } // fetch all the peers that have access to the user's peers for _, peer := range peers { - aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap) + aclPeers, _ := account.GetPeerConnectionResources(ctx, peer.ID, approvedPeersMap) for _, p := range aclPeers { peersMap[p.ID] = p } @@ -119,17 +120,17 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID // MarkPeerConnected marks peer as connected (true) or disconnected (false) func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error { var peer *nbpeer.Peer - var settings *Settings + var settings *types.Settings var expired bool var err error - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - peer, err = transaction.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, peerPubKey) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, peerPubKey) if err != nil { return err } - settings, err = transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID) + settings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return err } @@ -154,13 +155,13 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK if expired { // we need to update other peers because when peer login expires all other peers are notified to disconnect from // the expired one. Here we notify them that connection is now allowed again. 
- am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil } -func updatePeerStatusAndLocation(ctx context.Context, geo *geolocation.Geolocation, transaction Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) { +func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) { oldStatus := peer.Status.Copy() newStatus := oldStatus newStatus.LastSeen = time.Now().UTC() @@ -180,14 +181,16 @@ func updatePeerStatusAndLocation(ctx context.Context, geo *geolocation.Geolocati peer.Location.CountryCode = location.Country.ISOCode peer.Location.CityName = location.City.Names.En peer.Location.GeoNameID = location.City.GeonameID - err = transaction.SavePeerLocation(ctx, LockingStrengthUpdate, accountID, peer) + err = transaction.SavePeerLocation(ctx, store.LockingStrengthUpdate, accountID, peer) if err != nil { log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err) } } } - err := transaction.SavePeerStatus(ctx, LockingStrengthUpdate, accountID, peer.ID, *newStatus) + log.WithContext(ctx).Tracef("saving peer status for peer %s is connected: %t", peer.ID, connected) + + err := transaction.SavePeerStatus(ctx, store.LockingStrengthUpdate, accountID, peer.ID, *newStatus) if err != nil { return false, err } @@ -200,7 +203,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -210,7 +213,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } var peer *nbpeer.Peer - var settings *Settings + var settings *types.Settings var peerGroupList []string var requiresPeerUpdates bool var peerLabelChanged bool @@ -218,13 +221,13 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user var loginExpirationChanged bool var inactivityExpirationChanged bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - peer, err = transaction.GetPeerByID(ctx, LockingStrengthUpdate, accountID, update.ID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, update.ID) if err != nil { return err } - settings, err = transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID) + settings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return err } @@ -245,7 +248,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user return err } - newLabel, err := getPeerHostLabel(update.Name, existingLabels) + newLabel, err := types.GetPeerHostLabel(update.Name, existingLabels) if err != nil { return err } @@ -276,7 +279,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user inactivityExpirationChanged = true } - return transaction.SavePeer(ctx, LockingStrengthUpdate, accountID, peer) + return transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer) }) if err != nil { return nil, err @@ -319,7 +322,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } if peerLabelChanged || 
requiresPeerUpdates { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) + } else if sshChanged { + am.UpdateAccountPeer(ctx, accountID, peer.ID) } return peer, nil @@ -330,7 +335,16 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, LockingStrengthShare, peerID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) + if err != nil { + return err + } + + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } + + peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthShare, peerID) if err != nil { return err } @@ -343,8 +357,8 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer var updateAccountPeers bool var eventsToStore []func() - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - peer, err = transaction.GetPeerByID(ctx, LockingStrengthUpdate, accountID, peerID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID) if err != nil { return err } @@ -354,7 +368,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } @@ -367,14 +381,14 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer } if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil } // GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result) -func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID string) (*NetworkMap, error) { +func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) { account, err := am.Store.GetAccountByPeerID(ctx, peerID) if err != nil { return nil, err @@ -390,16 +404,16 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin groups[groupID] = group.Peers } - validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) + validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) if err != nil { return nil, err } customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) - return account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, nil), nil + return account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil), nil } // GetPeerNetwork returns the Network for a given peer -func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID string) (*Network, error) { +func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error) { account, err := am.Store.GetAccountByPeerID(ctx, peerID) if err != nil { return nil, err @@ -415,7 +429,7 @@ func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID stri // to it. 
We also add the User ID to the peer metadata to identify registrant. If no userID provided, then fail with status.PermissionDenied // Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused). // The peer property is just a placeholder for the Peer properties to pass further -func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { +func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { if setupKey == "" && userID == "" { // no auth method provided => reject access return nil, nil, nil, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login") @@ -429,7 +443,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s addedByUser := false if len(userID) > 0 { addedByUser = true - accountID, err = am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, userID) + accountID, err = am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userID) } else { accountID, err = am.Store.GetAccountIDBySetupKey(ctx, encodedHashedKey) } @@ -449,7 +463,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s // and the peer disconnects with a timeout and tries to register again. // We just check if this machine has been registered before and reject the second registration. // The connecting peer should be able to recover with a retry. - _, err = am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, peer.Key) + _, err = am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, peer.Key) if err == nil { return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered") } @@ -462,13 +476,13 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s var newPeer *nbpeer.Peer var updateAccountPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { var setupKeyID string var setupKeyName string var ephemeral bool var groupsToAdd []string if addedByUser { - user, err := transaction.GetUserByUserID(ctx, LockingStrengthUpdate, userID) + user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthUpdate, userID) if err != nil { return fmt.Errorf("failed to get user groups: %w", err) } @@ -477,7 +491,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s opEvent.Activity = activity.PeerAddedByUser } else { // Validate the setup key - sk, err := transaction.GetSetupKeyBySecret(ctx, LockingStrengthUpdate, encodedHashedKey) + sk, err := transaction.GetSetupKeyBySecret(ctx, store.LockingStrengthUpdate, encodedHashedKey) if err != nil { return fmt.Errorf("failed to get setup key: %w", err) } @@ -494,7 +508,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s setupKeyName = sk.Name } - if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" { + if (strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad") && userID != "" { if am.idpManager != nil { userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID}) if err == nil && 
userdata != nil {
@@ -526,7 +540,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
 		Status:                 &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
 		SSHEnabled:             false,
 		SSHKey:                 peer.SSHKey,
-		LastLogin:              registrationTime,
+		LastLogin:              &registrationTime,
 		CreatedAt:              registrationTime,
 		LoginExpirationEnabled: addedByUser,
 		Ephemeral:              ephemeral,
@@ -550,38 +564,38 @@
 			}
 		}
 
-		settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
+		settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
 		if err != nil {
 			return fmt.Errorf("failed to get account settings: %w", err)
 		}
 
 		newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra)
 
-		err = transaction.AddPeerToAllGroup(ctx, accountID, newPeer.ID)
+		err = transaction.AddPeerToAllGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID)
 		if err != nil {
 			return fmt.Errorf("failed adding peer to All group: %w", err)
 		}
 
 		if len(groupsToAdd) > 0 {
 			for _, g := range groupsToAdd {
-				err = transaction.AddPeerToGroup(ctx, accountID, newPeer.ID, g)
+				err = transaction.AddPeerToGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID, g)
 				if err != nil {
 					return err
 				}
 			}
 		}
 
-		err = transaction.AddPeerToAccount(ctx, newPeer)
+		err = transaction.AddPeerToAccount(ctx, store.LockingStrengthUpdate, newPeer)
 		if err != nil {
 			return fmt.Errorf("failed to add peer to account: %w", err)
 		}
 
-		err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID)
+		err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID)
 		if err != nil {
 			return fmt.Errorf("failed to increment network serial: %w", err)
 		}
 
 		if addedByUser {
-			err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.LastLogin)
+			err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin())
 			if err != nil {
 				return fmt.Errorf("failed to update user last login: %w", err)
 			}
@@ -615,24 +629,24 @@
 	unlock = nil
 
 	if updateAccountPeers {
-		am.updateAccountPeers(ctx, accountID)
+		am.UpdateAccountPeers(ctx, accountID)
 	}
 
 	return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
 }
 
-func getFreeIP(ctx context.Context, transaction Store, accountID string) (net.IP, error) {
-	takenIps, err := transaction.GetTakenIPs(ctx, LockingStrengthShare, accountID)
+func getFreeIP(ctx context.Context, transaction store.Store, accountID string) (net.IP, error) {
+	takenIps, err := transaction.GetTakenIPs(ctx, store.LockingStrengthShare, accountID)
 	if err != nil {
 		return nil, fmt.Errorf("failed to get taken IPs: %w", err)
 	}
 
-	network, err := transaction.GetAccountNetwork(ctx, LockingStrengthUpdate, accountID)
+	network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthUpdate, accountID)
 	if err != nil {
 		return nil, fmt.Errorf("failed getting network: %w", err)
 	}
 
-	nextIp, err := AllocatePeerIP(network.Net, takenIps)
+	nextIp, err := types.AllocatePeerIP(network.Net, takenIps)
 	if err != nil {
 		return nil, fmt.Errorf("failed to allocate new peer ip: %w", err)
 	}
@@ -641,21 +655,22 @@
 }
 
 // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
-func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync 
PeerSync, accountID string) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { +func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { var peer *nbpeer.Peer var peerNotValid bool var isStatusChanged bool var updated bool var err error + var postureChecks []*posture.Checks - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - peer, err = transaction.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, sync.WireGuardPubKey) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, sync.WireGuardPubKey) if err != nil { return status.NewPeerNotRegisteredError() } if peer.UserID != "" { - user, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID) + user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, peer.UserID) if err != nil { return err } @@ -665,7 +680,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac } } - settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID) + settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return err } @@ -686,7 +701,13 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac updated = peer.UpdateMetaIfNew(sync.Meta) if updated { - err = transaction.SavePeer(ctx, LockingStrengthUpdate, accountID, peer) + am.metrics.AccountManagerMetrics().CountPeerMetUpdate() + log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID) + if err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer); err != nil { + return err + } + + postureChecks, err = getPeerPostureChecks(ctx, transaction, accountID, peer.ID) if err != nil { return err } @@ -697,40 +718,43 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac return nil, nil, nil, err } - if isStatusChanged || (updated && sync.UpdateAccountPeers) { - am.updateAccountPeers(ctx, accountID) + if isStatusChanged || sync.UpdateAccountPeers || (updated && len(postureChecks) > 0) { + am.UpdateAccountPeers(ctx, accountID) } return am.getValidatedPeerWithMap(ctx, peerNotValid, accountID, peer) } -// LoginPeer logs in or registers a peer. -// If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so. -func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { - accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, login.WireGuardPubKey) - if err != nil { - if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { - // we couldn't find this peer by its public key which can mean that peer hasn't been registered yet. - // Try registering it. 
- newPeer := &nbpeer.Peer{ - Key: login.WireGuardPubKey, - Meta: login.Meta, - SSHKey: login.SSHKey, - Location: nbpeer.Location{ConnectionIP: login.ConnectionIP}, - } - - return am.AddPeer(ctx, login.SetupKey, login.UserID, newPeer) +func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login PeerLogin, err error) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { + if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { + // we couldn't find this peer by its public key which can mean that peer hasn't been registered yet. + // Try registering it. + newPeer := &nbpeer.Peer{ + Key: login.WireGuardPubKey, + Meta: login.Meta, + SSHKey: login.SSHKey, + Location: nbpeer.Location{ConnectionIP: login.ConnectionIP}, } - log.WithContext(ctx).Errorf("failed while logging in peer %s: %v", login.WireGuardPubKey, err) - return nil, nil, nil, status.Errorf(status.Internal, "failed while logging in peer") + return am.AddPeer(ctx, login.SetupKey, login.UserID, newPeer) + } + + log.WithContext(ctx).Errorf("failed while logging in peer %s: %v", login.WireGuardPubKey, err) + return nil, nil, nil, status.Errorf(status.Internal, "failed while logging in peer") +} + +// LoginPeer logs in or registers a peer. +// If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so. +func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { + accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, login.WireGuardPubKey) + if err != nil { + return am.handlePeerLoginNotFound(ctx, login, err) } // when the client sends a login request with a JWT which is used to get the user ID, // it means that the client has already checked if it needs login and had been through the SSO flow // so, we can skip this check and directly proceed with the login if login.UserID == "" { - log.Info("Peer needs login") err = am.checkIFPeerNeedsLoginWithoutLock(ctx, accountID, login) if err != nil { return nil, nil, nil, err @@ -750,14 +774,16 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) var updateRemotePeers bool var isRequiresApproval bool var isStatusChanged bool + var isPeerUpdated bool + var postureChecks []*posture.Checks - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - peer, err = transaction.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, login.WireGuardPubKey) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, login.WireGuardPubKey) if err != nil { return err } - settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID) + settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return err } @@ -791,9 +817,15 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) return err } - updated := peer.UpdateMetaIfNew(login.Meta) - if updated { + isPeerUpdated = peer.UpdateMetaIfNew(login.Meta) + if isPeerUpdated { + am.metrics.AccountManagerMetrics().CountPeerMetUpdate() shouldStorePeer = true + + postureChecks, err = getPeerPostureChecks(ctx, transaction, accountID, peer.ID) + if err != nil { + return err + } } if peer.SSHKey != login.SSHKey { @@ -802,7 +834,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) } if shouldStorePeer { - 
if err = transaction.SavePeer(ctx, LockingStrengthUpdate, accountID, peer); err != nil { + if err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer); err != nil { return err } } @@ -816,20 +848,80 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) unlockPeer() unlockPeer = nil - if updateRemotePeers || isStatusChanged { - am.updateAccountPeers(ctx, accountID) + if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) { + am.UpdateAccountPeers(ctx, accountID) } return am.getValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer) } +// getPeerPostureChecks returns the posture checks for the peer. +func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) { + policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + if len(policies) == 0 { + return nil, nil + } + + var peerPostureChecksIDs []string + + for _, policy := range policies { + if !policy.Enabled || len(policy.SourcePostureChecks) == 0 { + continue + } + + postureChecksIDs, err := processPeerPostureChecks(ctx, transaction, policy, accountID, peerID) + if err != nil { + return nil, err + } + + peerPostureChecksIDs = append(peerPostureChecksIDs, postureChecksIDs...) + } + + peerPostureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthShare, accountID, peerPostureChecksIDs) + if err != nil { + return nil, err + } + + return maps.Values(peerPostureChecks), nil +} + +// processPeerPostureChecks checks if the peer is in the source group of the policy and returns the posture checks. +func processPeerPostureChecks(ctx context.Context, transaction store.Store, policy *types.Policy, accountID, peerID string) ([]string, error) { + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + sourceGroups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, rule.Sources) + if err != nil { + return nil, err + } + + for _, sourceGroup := range rule.Sources { + group, ok := sourceGroups[sourceGroup] + if !ok { + return nil, fmt.Errorf("failed to check peer in policy source group") + } + + if slices.Contains(group.Peers, peerID) { + return policy.SourcePostureChecks, nil + } + } + } + return nil, nil +} + // checkIFPeerNeedsLoginWithoutLock checks if the peer needs login without acquiring the account lock. The check validate if the peer was not added via SSO // and if the peer login is expired. // The NetBird client doesn't have a way to check if the peer needs login besides sending a login request // with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired // and before starting the engine, we do the checks without an account lock to avoid piling up requests. 
func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login PeerLogin) error { - peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, login.WireGuardPubKey) + peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, login.WireGuardPubKey) if err != nil { return err } @@ -840,7 +932,7 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co return nil } - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return err } @@ -852,39 +944,39 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co return nil } -func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { +func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { if isRequiresApproval { - network, err := am.Store.GetAccountNetwork(ctx, LockingStrengthShare, accountID) + network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, nil, nil, err } - emptyMap := &NetworkMap{ + emptyMap := &types.NetworkMap{ Network: network.Copy(), } return peer, emptyMap, nil, nil } - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, nil, nil, err } - approvedPeersMap, err := am.GetValidatedPeers(ctx, account.Id) + approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) if err != nil { return nil, nil, nil, err } - postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, peer.ID) + postureChecks, err := am.getPeerPostureChecks(account, peer.ID) if err != nil { return nil, nil, nil, err } customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) - return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil + return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics()), postureChecks, nil } -func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transaction Store, user *User, peer *nbpeer.Peer) error { +func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transaction store.Store, user *types.User, peer *nbpeer.Peer) error { err := checkAuth(ctx, user.Id, peer) if err != nil { return err @@ -892,12 +984,12 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transact // If peer was expired before and if it reached this point, it is re-authenticated. // UserID is present, meaning that JWT validation passed successfully in the API layer. 
peer = peer.UpdateLastLogin() - err = transaction.SavePeer(ctx, LockingStrengthUpdate, peer.AccountID, peer) + err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, peer.AccountID, peer) if err != nil { return err } - err = transaction.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.LastLogin) + err = transaction.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.GetLastLogin()) if err != nil { return err } @@ -906,7 +998,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transact return nil } -func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, user *User) error { +func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, user *types.User) error { if peer.AddedWithSSOLogin() { if user.IsBlocked() { return status.Errorf(status.PermissionDenied, "user is blocked") @@ -927,7 +1019,7 @@ func checkAuth(ctx context.Context, loginUserID string, peer *nbpeer.Peer) error return nil } -func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings) bool { +func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *types.Settings) bool { expired, expiresIn := peer.LoginExpired(settings.PeerLoginExpiration) expired = settings.PeerLoginExpirationEnabled && expired if expired || peer.Status.LoginExpired { @@ -939,7 +1031,7 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings // GetPeer for a given accountID, peerID and userID error if not found. func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -948,7 +1040,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, return nil, status.NewUserNotPartOfAccountError() } - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, err } @@ -957,7 +1049,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, return nil, status.Errorf(status.Internal, "user %s has no access to his own peer %s under account %s", userID, peerID, accountID) } - peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID) + peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID) if err != nil { return nil, err } @@ -969,12 +1061,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, // it is also possible that user doesn't own the peer but some of his peers have access to it, // this is a valid case, show the peer as well. 
- userPeers, err := am.Store.GetUserPeers(ctx, LockingStrengthShare, accountID, userID) - if err != nil { - return nil, err - } - - approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID) + userPeers, err := am.Store.GetUserPeers(ctx, store.LockingStrengthShare, accountID, userID) if err != nil { return nil, err } @@ -984,8 +1071,13 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, return nil, err } + approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(accountID, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + if err != nil { + return nil, err + } + for _, p := range userPeers { - aclPeers, _ := account.getPeerConnectionResources(ctx, p.ID, approvedPeersMap) + aclPeers, _ := account.GetPeerConnectionResources(ctx, p.ID, approvedPeersMap) for _, aclPeer := range aclPeers { if aclPeer.ID == peerID { return peer, nil @@ -996,9 +1088,15 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, return nil, status.Errorf(status.Internal, "user %s has no access to peer %s under account %s", userID, peerID, accountID) } -// updateAccountPeers updates all peers that belong to an account. +// UpdateAccountPeers updates all peers that belong to an account. // Should be called when changes have to be synced to peers. -func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, accountID string) { +func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err) + return + } + start := time.Now() defer func() { if am.metrics != nil { @@ -1006,17 +1104,9 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account } }() - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) if err != nil { - log.WithContext(ctx).Errorf("failed to send out updates to peers. 
failed to get account: %v", err) - return - } - - peers := account.GetPeers() - - approvedPeersMap, err := am.GetValidatedPeers(ctx, account.Id) - if err != nil { - log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to validate peer: %v", err) + log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get validate peers: %v", err) return } @@ -1025,8 +1115,10 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account dnsCache := &DNSConfigCache{} customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() - for _, peer := range peers { + for _, peer := range account.Peers { if !am.peersUpdateManager.HasChannel(peer.ID) { log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID) continue @@ -1038,14 +1130,14 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account defer wg.Done() defer func() { <-semaphore }() - postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, p.ID) + postureChecks, err := am.getPeerPostureChecks(account, p.ID) if err != nil { log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", peer.ID, err) return } - remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) - update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache) + remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) + update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled) am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) }(peer) } @@ -1053,24 +1145,66 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account wg.Wait() } +// UpdateAccountPeer updates a single peer that belongs to an account. +// Should be called when changes need to be synced to a specific peer only. +func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) { + if !am.peersUpdateManager.HasChannel(peerId) { + log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peerId) + return + } + + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountId) + if err != nil { + log.WithContext(ctx).Errorf("failed to send out updates to peer %s. 
failed to get account: %v", peerId, err) + return + } + + peer := account.GetPeer(peerId) + if peer == nil { + log.WithContext(ctx).Tracef("peer %s doesn't exists in account %s", peerId, accountId) + return + } + + approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + if err != nil { + log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to validate peers: %v", peerId, err) + return + } + + dnsCache := &DNSConfigCache{} + customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + postureChecks, err := am.getPeerPostureChecks(account, peerId) + if err != nil { + log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to get posture checks: %v", peerId, err) + return + } + + remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) + update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled) + am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) +} + // getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. // If there is no peer that expires this function returns false and a duration of 0. // This function only considers peers that haven't been expired yet and that are connected. func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) { - peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, LockingStrengthShare, accountID) + peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to get peers with expiration: %v", err) - return 0, false + return peerSchedulerRetryInterval, true } if len(peersWithExpiry) == 0 { return 0, false } - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to get account settings: %v", err) - return 0, false + return peerSchedulerRetryInterval, true } var nextExpiry *time.Duration @@ -1101,20 +1235,20 @@ func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, acco // If there is no peer that expires this function returns false and a duration of 0. // This function only considers peers that haven't been expired yet and that are not connected. 
func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) { - peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, LockingStrengthShare, accountID) + peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to get peers with inactivity: %v", err) - return 0, false + return peerSchedulerRetryInterval, true } if len(peersWithInactivity) == 0 { return 0, false } - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to get account settings: %v", err) - return 0, false + return peerSchedulerRetryInterval, true } var nextExpiry *time.Duration @@ -1142,12 +1276,12 @@ func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Conte // getExpiredPeers returns peers that have been expired. func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) { - peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, LockingStrengthShare, accountID) + peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, err } - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, err } @@ -1165,12 +1299,12 @@ func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID // getInactivePeers returns peers that have been expired by inactivity func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) { - peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, LockingStrengthShare, accountID) + peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, err } - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, err } @@ -1187,30 +1321,13 @@ func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID } // GetPeerGroups returns groups that the peer is part of. -func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) { - return getPeerGroups(ctx, am.Store, accountID, peerID) -} - -// getPeerGroups returns the IDs of the groups that the peer is part of. -func getPeerGroups(ctx context.Context, transaction Store, accountID, peerID string) ([]*nbgroup.Group, error) { - groups, err := transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) - if err != nil { - return nil, err - } - - peerGroups := make([]*nbgroup.Group, 0) - for _, group := range groups { - if slices.Contains(group.Peers, peerID) { - peerGroups = append(peerGroups, group) - } - } - - return peerGroups, nil +func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*types.Group, error) { + return am.Store.GetPeerGroups(ctx, store.LockingStrengthShare, accountID, peerID) } // getPeerGroupIDs returns the IDs of the groups that the peer is part of. 
-func getPeerGroupIDs(ctx context.Context, transaction Store, accountID string, peerID string) ([]string, error) { - groups, err := getPeerGroups(ctx, transaction, accountID, peerID) +func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID string, peerID string) ([]string, error) { + groups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthShare, accountID, peerID) if err != nil { return nil, err } @@ -1223,13 +1340,13 @@ func getPeerGroupIDs(ctx context.Context, transaction Store, accountID string, p return groupIDs, err } -func getPeerDNSLabels(ctx context.Context, transaction Store, accountID string) (lookupMap, error) { - dnsLabels, err := transaction.GetPeerLabelsInAccount(ctx, LockingStrengthShare, accountID) +func getPeerDNSLabels(ctx context.Context, transaction store.Store, accountID string) (types.LookupMap, error) { + dnsLabels, err := transaction.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, err } - existingLabels := make(lookupMap) + existingLabels := make(types.LookupMap) for _, label := range dnsLabels { existingLabels[label] = struct{}{} } @@ -1238,7 +1355,7 @@ func getPeerDNSLabels(ctx context.Context, transaction Store, accountID string) // IsPeerInActiveGroup checks if the given peer is part of a group that is used // in an active DNS, route, or ACL configuration. -func isPeerInActiveGroup(ctx context.Context, transaction Store, accountID, peerID string) (bool, error) { +func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID, peerID string) (bool, error) { peerGroupIDs, err := getPeerGroupIDs(ctx, transaction, accountID, peerID) if err != nil { return false, err @@ -1248,7 +1365,7 @@ func isPeerInActiveGroup(ctx context.Context, transaction Store, accountID, peer // deletePeers deletes all specified peers and sends updates to the remote peers. // Returns a slice of functions to save events after successful peer deletion. 
-func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) { +func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) { var peerDeletedEvents []func() for _, peer := range peers { @@ -1256,12 +1373,12 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction Sto return nil, err } - network, err := transaction.GetAccountNetwork(ctx, LockingStrengthShare, accountID) + network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, err } - if err = transaction.DeletePeer(ctx, LockingStrengthUpdate, accountID, peer.ID); err != nil { + if err = transaction.DeletePeer(ctx, store.LockingStrengthUpdate, accountID, peer.ID); err != nil { return nil, err } @@ -1277,7 +1394,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction Sto FirewallRulesIsEmpty: true, }, }, - NetworkMap: &NetworkMap{}, + NetworkMap: &types.NetworkMap{}, }) am.peersUpdateManager.CloseChannel(ctx, peer.ID) peerDeletedEvents = append(peerDeletedEvents, func() { diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index 146af8861..199c7c89d 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -6,6 +6,8 @@ import ( "slices" "sort" "time" + + "github.com/netbirdio/netbird/management/server/util" ) // Peer represents a machine connected to the network. @@ -40,7 +42,7 @@ type Peer struct { InactivityExpirationEnabled bool // LastLogin the time when peer performed last login operation - LastLogin time.Time + LastLogin *time.Time // CreatedAt records the time the peer was created CreatedAt time.Time // Indicate ephemeral peer attribute @@ -222,6 +224,15 @@ func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) bool { return true } +// GetLastLogin returns the last login time of the peer. 
+func (p *Peer) GetLastLogin() time.Time { + if p.LastLogin != nil { + return *p.LastLogin + } + return time.Time{} +} + // MarkLoginExpired marks peer's status expired or not func (p *Peer) MarkLoginExpired(expired bool) { newStatus := p.Status.Copy() @@ -258,7 +269,7 @@ func (p *Peer) LoginExpired(expiresIn time.Duration) (bool, time.Duration) { if !p.AddedWithSSOLogin() || !p.LoginExpirationEnabled { return false, 0 } - expiresAt := p.LastLogin.Add(expiresIn) + expiresAt := p.GetLastLogin().Add(expiresIn) now := time.Now() timeLeft := expiresAt.Sub(now) return timeLeft <= 0, timeLeft @@ -291,7 +302,7 @@ func (p *PeerStatus) Copy() *PeerStatus { // UpdateLastLogin and set login expired false func (p *Peer) UpdateLastLogin() *Peer { - p.LastLogin = time.Now().UTC() + p.LastLogin = util.ToPtr(time.Now().UTC()) newStatus := p.Status.Copy() newStatus.LoginExpired = false p.Status = newStatus diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 0e30a3762..0ecd635ba 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -20,14 +20,21 @@ import ( "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/management/server/util" + + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" nbroute "github.com/netbirdio/netbird/route" ) @@ -37,13 +44,13 @@ func TestPeer_LoginExpired(t *testing.T) { expirationEnabled bool lastLogin time.Time expected bool - accountSettings *Settings + accountSettings *types.Settings }{ { name: "Peer Login Expiration Disabled. 
Peer Login Should Not Expire", expirationEnabled: false, lastLogin: time.Now().UTC().Add(-25 * time.Hour), - accountSettings: &Settings{ + accountSettings: &types.Settings{ PeerLoginExpirationEnabled: true, PeerLoginExpiration: time.Hour, }, @@ -53,7 +60,7 @@ func TestPeer_LoginExpired(t *testing.T) { name: "Peer Login Should Expire", expirationEnabled: true, lastLogin: time.Now().UTC().Add(-25 * time.Hour), - accountSettings: &Settings{ + accountSettings: &types.Settings{ PeerLoginExpirationEnabled: true, PeerLoginExpiration: time.Hour, }, @@ -63,7 +70,7 @@ func TestPeer_LoginExpired(t *testing.T) { name: "Peer Login Should Not Expire", expirationEnabled: true, lastLogin: time.Now().UTC(), - accountSettings: &Settings{ + accountSettings: &types.Settings{ PeerLoginExpirationEnabled: true, PeerLoginExpiration: time.Hour, }, @@ -75,7 +82,7 @@ func TestPeer_LoginExpired(t *testing.T) { t.Run(c.name, func(t *testing.T) { peer := &nbpeer.Peer{ LoginExpirationEnabled: c.expirationEnabled, - LastLogin: c.lastLogin, + LastLogin: util.ToPtr(c.lastLogin), UserID: userID, } @@ -92,14 +99,14 @@ func TestPeer_SessionExpired(t *testing.T) { lastLogin time.Time connected bool expected bool - accountSettings *Settings + accountSettings *types.Settings }{ { name: "Peer Inactivity Expiration Disabled. Peer Inactivity Should Not Expire", expirationEnabled: false, connected: false, lastLogin: time.Now().UTC().Add(-1 * time.Second), - accountSettings: &Settings{ + accountSettings: &types.Settings{ PeerInactivityExpirationEnabled: true, PeerInactivityExpiration: time.Hour, }, @@ -110,7 +117,7 @@ func TestPeer_SessionExpired(t *testing.T) { expirationEnabled: true, connected: false, lastLogin: time.Now().UTC().Add(-1 * time.Second), - accountSettings: &Settings{ + accountSettings: &types.Settings{ PeerInactivityExpirationEnabled: true, PeerInactivityExpiration: time.Second, }, @@ -121,7 +128,7 @@ func TestPeer_SessionExpired(t *testing.T) { expirationEnabled: true, connected: true, lastLogin: time.Now().UTC(), - accountSettings: &Settings{ + accountSettings: &types.Settings{ PeerInactivityExpirationEnabled: true, PeerInactivityExpiration: time.Second, }, @@ -136,7 +143,7 @@ func TestPeer_SessionExpired(t *testing.T) { } peer := &nbpeer.Peer{ InactivityExpirationEnabled: c.expirationEnabled, - LastLogin: c.lastLogin, + LastLogin: util.ToPtr(c.lastLogin), Status: peerStatus, UserID: userID, } @@ -161,7 +168,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) { t.Fatal(err) } - setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, userId, false) if err != nil { t.Fatal("error creating setup key") return @@ -233,9 +240,9 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { t.Fatal(err) } - var setupKey *SetupKey + var setupKey *types.SetupKey for _, key := range account.SetupKeys { - if key.Type == SetupKeyReusable { + if key.Type == types.SetupKeyReusable { setupKey = key } } @@ -281,8 +288,8 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { return } var ( - group1 nbgroup.Group - group2 nbgroup.Group + group1 types.Group + group2 types.Group ) group1.ID = xid.New().String() @@ -303,16 +310,16 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { return } - policy := &Policy{ + policy := &types.Policy{ Name: "test", Enabled: true, - Rules: 
[]*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{group1.ID}, Destinations: []string{group2.ID}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, } @@ -410,7 +417,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) { t.Fatal(err) } - setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, userId, false) if err != nil { t.Fatal("error creating setup key") return @@ -469,9 +476,9 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { adminUser := "account_creator" someUser := "some_user" account := newAccountWithId(context.Background(), accountID, adminUser, "") - account.Users[someUser] = &User{ + account.Users[someUser] = &types.User{ Id: someUser, - Role: UserRoleUser, + Role: types.UserRoleUser, } account.Settings.RegularUsersViewBlocked = false @@ -482,7 +489,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { } // two peers one added by a regular user and one with a setup key - setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, adminUser, false) if err != nil { t.Fatal("error creating setup key") return @@ -567,77 +574,77 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { func TestDefaultAccountManager_GetPeers(t *testing.T) { testCases := []struct { name string - role UserRole + role types.UserRole limitedViewSettings bool isServiceUser bool expectedPeerCount int }{ { name: "Regular user, no limited view settings, not a service user", - role: UserRoleUser, + role: types.UserRoleUser, limitedViewSettings: false, isServiceUser: false, expectedPeerCount: 1, }, { name: "Service user, no limited view settings", - role: UserRoleUser, + role: types.UserRoleUser, limitedViewSettings: false, isServiceUser: true, expectedPeerCount: 2, }, { name: "Regular user, limited view settings", - role: UserRoleUser, + role: types.UserRoleUser, limitedViewSettings: true, isServiceUser: false, expectedPeerCount: 0, }, { name: "Service user, limited view settings", - role: UserRoleUser, + role: types.UserRoleUser, limitedViewSettings: true, isServiceUser: true, expectedPeerCount: 2, }, { name: "Admin, no limited view settings, not a service user", - role: UserRoleAdmin, + role: types.UserRoleAdmin, limitedViewSettings: false, isServiceUser: false, expectedPeerCount: 2, }, { name: "Admin service user, no limited view settings", - role: UserRoleAdmin, + role: types.UserRoleAdmin, limitedViewSettings: false, isServiceUser: true, expectedPeerCount: 2, }, { name: "Admin, limited view settings", - role: UserRoleAdmin, + role: types.UserRoleAdmin, limitedViewSettings: true, isServiceUser: false, expectedPeerCount: 2, }, { name: "Admin Service user, limited view settings", - role: UserRoleAdmin, + role: types.UserRoleAdmin, limitedViewSettings: true, isServiceUser: true, expectedPeerCount: 2, }, { name: "Owner, no limited view settings", - role: UserRoleOwner, + role: types.UserRoleOwner, limitedViewSettings: true, isServiceUser: false, expectedPeerCount: 2, }, { name: "Owner, limited view settings", - role: UserRoleOwner, + role: types.UserRoleOwner, limitedViewSettings: 
true, isServiceUser: false, expectedPeerCount: 2, @@ -656,12 +663,12 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { adminUser := "account_creator" someUser := "some_user" account := newAccountWithId(context.Background(), accountID, adminUser, "") - account.Users[someUser] = &User{ + account.Users[someUser] = &types.User{ Id: someUser, Role: testCase.role, IsServiceUser: testCase.isServiceUser, } - account.Policies = []*Policy{} + account.Policies = []*types.Policy{} account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings err = manager.Store.SaveAccount(context.Background(), account) @@ -726,9 +733,9 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou regularUser := "regular_user" account := newAccountWithId(context.Background(), accountID, adminUser, "") - account.Users[regularUser] = &User{ + account.Users[regularUser] = &types.User{ Id: regularUser, - Role: UserRoleUser, + Role: types.UserRoleUser, } // Create peers @@ -739,17 +746,17 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou DNSLabel: fmt.Sprintf("peer-%d", i), Key: peerKey.PublicKey().String(), IP: net.ParseIP(fmt.Sprintf("100.64.%d.%d", i/256, i%256)), - Status: &nbpeer.PeerStatus{}, + Status: &nbpeer.PeerStatus{LastSeen: time.Now().UTC(), Connected: true}, UserID: regularUser, } account.Peers[peer.ID] = peer } // Create groups and policies - account.Policies = make([]*Policy, 0, groups) + account.Policies = make([]*types.Policy, 0, groups) for i := 0; i < groups; i++ { groupID := fmt.Sprintf("group-%d", i) - group := &nbgroup.Group{ + group := &types.Group{ ID: groupID, Name: fmt.Sprintf("Group %d", i), } @@ -757,14 +764,95 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou peerIndex := i*(peers/groups) + j group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex)) } + + // Create network, router and resource for this group + network := &networkTypes.Network{ + ID: fmt.Sprintf("network-%d", i), + AccountID: account.Id, + Name: fmt.Sprintf("Network for Group %d", i), + } + account.Networks = append(account.Networks, network) + + ips := account.GetTakenIPs() + peerIP, err := types.AllocatePeerIP(account.Network.Net, ips) + if err != nil { + return nil, "", "", err + } + + peerKey, _ := wgtypes.GeneratePrivateKey() + peer := &nbpeer.Peer{ + ID: fmt.Sprintf("peer-nr-%d", len(account.Peers)+1), + DNSLabel: fmt.Sprintf("peer-nr-%d", len(account.Peers)+1), + Key: peerKey.PublicKey().String(), + IP: peerIP, + Status: &nbpeer.PeerStatus{LastSeen: time.Now().UTC(), Connected: true}, + UserID: regularUser, + Meta: nbpeer.PeerSystemMeta{ + Hostname: fmt.Sprintf("peer-nr-%d", len(account.Peers)+1), + GoOS: "linux", + Kernel: "Linux", + Core: "21.04", + Platform: "x86_64", + OS: "Ubuntu", + WtVersion: "development", + UIVersion: "development", + }, + } + account.Peers[peer.ID] = peer + + group.Peers = append(group.Peers, peer.ID) account.Groups[groupID] = group + router := &routerTypes.NetworkRouter{ + ID: fmt.Sprintf("network-router-%d", i), + NetworkID: network.ID, + AccountID: account.Id, + Peer: peer.ID, + PeerGroups: []string{}, + Masquerade: false, + Metric: 9999, + } + account.NetworkRouters = append(account.NetworkRouters, router) + + resource := &resourceTypes.NetworkResource{ + ID: fmt.Sprintf("network-resource-%d", i), + NetworkID: network.ID, + AccountID: account.Id, + Name: fmt.Sprintf("Network resource for Group %d", i), + Type: "host", + Address: "192.0.2.0/32", + } + 
account.NetworkResources = append(account.NetworkResources, resource) + + // Create a policy for this network resource + nrPolicy := &types.Policy{ + ID: fmt.Sprintf("policy-nr-%d", i), + Name: fmt.Sprintf("Policy for network resource %d", i), + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: fmt.Sprintf("rule-nr-%d", i), + Name: fmt.Sprintf("Rule for network resource %d", i), + Enabled: true, + Sources: []string{groupID}, + Destinations: []string{}, + DestinationResource: types.Resource{ + ID: resource.ID, + }, + Bidirectional: true, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, + }, + }, + } + account.Policies = append(account.Policies, nrPolicy) + // Create a policy for this group - policy := &Policy{ + policy := &types.Policy{ ID: fmt.Sprintf("policy-%d", i), Name: fmt.Sprintf("Policy for Group %d", i), Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: fmt.Sprintf("rule-%d", i), Name: fmt.Sprintf("Rule for Group %d", i), @@ -772,8 +860,8 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou Sources: []string{groupID}, Destinations: []string{groupID}, Bidirectional: true, - Protocol: PolicyRuleProtocolALL, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, }, }, } @@ -833,19 +921,24 @@ func BenchmarkGetPeers(b *testing.B) { }) } } - func BenchmarkUpdateAccountPeers(b *testing.B) { benchCases := []struct { name string peers int groups int + // We need different expectations for CI/CD and local runs because of the different performance characteristics + minMsPerOpLocal float64 + maxMsPerOpLocal float64 + minMsPerOpCICD float64 + maxMsPerOpCICD float64 }{ - {"Small", 50, 5}, - {"Medium", 500, 10}, - {"Large", 5000, 20}, - {"Small single", 50, 1}, - {"Medium single", 500, 1}, - {"Large 5", 5000, 5}, + {"Small", 50, 5, 90, 120, 90, 120}, + {"Medium", 500, 100, 110, 150, 120, 260}, + {"Large", 5000, 200, 800, 1700, 2500, 5000}, + {"Small single", 50, 10, 90, 120, 90, 120}, + {"Medium single", 500, 10, 110, 170, 120, 200}, + {"Large 5", 5000, 15, 1300, 2100, 4900, 7000}, + {"Extra Large", 2000, 2000, 1300, 2400, 3000, 6400}, } log.SetOutput(io.Discard) @@ -877,12 +970,27 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { start := time.Now() for i := 0; i < b.N; i++ { - manager.updateAccountPeers(ctx, account.Id) + manager.UpdateAccountPeers(ctx, account.Id) } duration := time.Since(start) - b.ReportMetric(float64(duration.Nanoseconds())/float64(b.N)/1e6, "ms/op") - b.ReportMetric(0, "ns/op") + msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 + b.ReportMetric(msPerOp, "ms/op") + + minExpected := bc.minMsPerOpLocal + maxExpected := bc.maxMsPerOpLocal + if os.Getenv("CI") == "true" { + minExpected = bc.minMsPerOpCICD + maxExpected = bc.maxMsPerOpCICD + } + + if msPerOp < minExpected { + b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) + } + + if msPerOp > (maxExpected * 1.1) { + b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected) + } }) } } @@ -920,8 +1028,8 @@ func TestToSyncResponse(t *testing.T) { Payload: "turn-user", Signature: "turn-pass", } - networkMap := &NetworkMap{ - Network: &Network{Net: *ipnet, Serial: 1000}, + networkMap := &types.NetworkMap{ + Network: &types.Network{Net: *ipnet, Serial: 1000}, Peers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.2"), Key: "peer2-key", DNSLabel: "peer2", 
SSHEnabled: true, SSHKey: "peer2-ssh-key"}}, OfflinePeers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.3"), Key: "peer3-key", DNSLabel: "peer3", SSHEnabled: true, SSHKey: "peer3-ssh-key"}}, Routes: []*nbroute.Route{ @@ -968,8 +1076,8 @@ func TestToSyncResponse(t *testing.T) { }, CustomZones: []nbdns.CustomZone{{Domain: "example.com", Records: []nbdns.SimpleRecord{{Name: "example.com", Type: 1, Class: "IN", TTL: 60, RData: "100.64.0.1"}}}}, }, - FirewallRules: []*FirewallRule{ - {PeerIP: "192.168.1.2", Direction: firewallRuleDirectionIN, Action: string(PolicyTrafficActionAccept), Protocol: string(PolicyRuleProtocolTCP), Port: "80"}, + FirewallRules: []*types.FirewallRule{ + {PeerIP: "192.168.1.2", Direction: types.FirewallRuleDirectionIN, Action: string(types.PolicyTrafficActionAccept), Protocol: string(types.PolicyRuleProtocolTCP), Port: "80"}, }, } dnsName := "example.com" @@ -984,7 +1092,7 @@ func TestToSyncResponse(t *testing.T) { } dnsCache := &DNSConfigCache{} - response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache) + response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, true) assert.NotNil(t, response) // assert peer config @@ -1069,7 +1177,7 @@ func Test_RegisterPeerByUser(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) if err != nil { t.Fatal(err) } @@ -1080,13 +1188,13 @@ func Test_RegisterPeerByUser(t *testing.T) { metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) assert.NoError(t, err) - am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics) + am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingUserID := "edafee4e-63fb-11ec-90d6-0242ac120003" - _, err = store.GetAccount(context.Background(), existingAccountID) + _, err = s.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) newPeer := &nbpeer.Peer{ @@ -1103,18 +1211,18 @@ func Test_RegisterPeerByUser(t *testing.T) { UserID: existingUserID, Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()}, SSHEnabled: false, - LastLogin: time.Now(), + LastLogin: util.ToPtr(time.Now()), } addedPeer, _, _, err := am.AddPeer(context.Background(), "", existingUserID, newPeer) require.NoError(t, err) - peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, addedPeer.Key) + peer, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, addedPeer.Key) require.NoError(t, err) assert.Equal(t, peer.AccountID, existingAccountID) assert.Equal(t, peer.UserID, existingUserID) - account, err := store.GetAccount(context.Background(), existingAccountID) + account, err := s.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) assert.Contains(t, account.Peers, addedPeer.ID) assert.Equal(t, peer.Meta.Hostname, newPeer.Meta.Hostname) @@ -1125,7 +1233,7 @@ func Test_RegisterPeerByUser(t *testing.T) { lastLogin, err := 
time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z") assert.NoError(t, err) - assert.NotEqual(t, lastLogin, account.Users[existingUserID].LastLogin) + assert.NotEqual(t, lastLogin, account.Users[existingUserID].GetLastLogin()) } func Test_RegisterPeerBySetupKey(t *testing.T) { @@ -1133,7 +1241,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) if err != nil { t.Fatal(err) } @@ -1144,13 +1252,13 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) assert.NoError(t, err) - am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics) + am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingSetupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" - _, err = store.GetAccount(context.Background(), existingAccountID) + _, err = s.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) newPeer := &nbpeer.Peer{ @@ -1173,11 +1281,11 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { require.NoError(t, err) - peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key) + peer, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, newPeer.Key) require.NoError(t, err) assert.Equal(t, peer.AccountID, existingAccountID) - account, err := store.GetAccount(context.Background(), existingAccountID) + account, err := s.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) assert.Contains(t, account.Peers, addedPeer.ID) assert.Contains(t, account.Groups["cfefqs706sqkneg59g2g"].Peers, addedPeer.ID) @@ -1200,7 +1308,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) if err != nil { t.Fatal(err) } @@ -1211,13 +1319,13 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) assert.NoError(t, err) - am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics) + am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" faultyKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBC" - _, err = store.GetAccount(context.Background(), existingAccountID) + _, err = s.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) newPeer := &nbpeer.Peer{ @@ -1239,10 +1347,10 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { _, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer) 
require.Error(t, err) - _, err = store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key) + _, err = s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, newPeer.Key) require.Error(t, err) - account, err := store.GetAccount(context.Background(), existingAccountID) + account, err := s.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) assert.NotContains(t, account.Peers, newPeer.ID) assert.NotContains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, newPeer.ID) @@ -1255,7 +1363,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { hashedKey := sha256.Sum256([]byte(faultyKey)) encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) - assert.Equal(t, lastUsed, account.SetupKeys[encodedHashedKey].LastUsed.UTC()) + assert.Equal(t, lastUsed, account.SetupKeys[encodedHashedKey].GetLastUsed().UTC()) assert.Equal(t, 0, account.SetupKeys[encodedHashedKey].UsedTimes) } @@ -1265,7 +1373,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID) require.NoError(t, err) - err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err = manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -1285,26 +1393,26 @@ func TestPeerAccountPeersUpdate(t *testing.T) { require.NoError(t, err) // create a user with auto groups - _, err = manager.SaveOrAddUsers(context.Background(), account.Id, userID, []*User{ + _, err = manager.SaveOrAddUsers(context.Background(), account.Id, userID, []*types.User{ { Id: "regularUser1", AccountID: account.Id, - Role: UserRoleAdmin, - Issued: UserIssuedAPI, + Role: types.UserRoleAdmin, + Issued: types.UserIssuedAPI, AutoGroups: []string{"groupA"}, }, { Id: "regularUser2", AccountID: account.Id, - Role: UserRoleAdmin, - Issued: UserIssuedAPI, + Role: types.UserRoleAdmin, + Issued: types.UserIssuedAPI, AutoGroups: []string{"groupB"}, }, { Id: "regularUser3", AccountID: account.Id, - Role: UserRoleAdmin, - Issued: UserIssuedAPI, + Role: types.UserRoleAdmin, + Issued: types.UserIssuedAPI, AutoGroups: []string{"groupC"}, }, }, true) @@ -1445,16 +1553,16 @@ func TestPeerAccountPeersUpdate(t *testing.T) { // Adding peer to group linked with policy should update account peers and send peer update t.Run("adding peer to group linked with policy", func(t *testing.T) { - _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ AccountID: account.Id, Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) diff --git a/management/server/permissions/manager.go b/management/server/permissions/manager.go new file mode 100644 index 000000000..320aad027 --- /dev/null +++ b/management/server/permissions/manager.go @@ -0,0 +1,102 @@ +package permissions + +import ( + "context" + "errors" + "fmt" + + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/users" +) + +type Module string + +const ( + Networks Module = "networks" + Peers Module = "peers" + Groups Module = "groups" +) + +type Operation string + +const ( + Read Operation = 
"read" + Write Operation = "write" +) + +type Manager interface { + ValidateUserPermissions(ctx context.Context, accountID, userID string, module Module, operation Operation) (bool, error) +} + +type managerImpl struct { + userManager users.Manager + settingsManager settings.Manager +} + +type managerMock struct { +} + +func NewManager(userManager users.Manager, settingsManager settings.Manager) Manager { + return &managerImpl{ + userManager: userManager, + settingsManager: settingsManager, + } +} + +func (m *managerImpl) ValidateUserPermissions(ctx context.Context, accountID, userID string, module Module, operation Operation) (bool, error) { + user, err := m.userManager.GetUser(ctx, userID) + if err != nil { + return false, err + } + + if user == nil { + return false, errors.New("user not found") + } + + if user.AccountID != accountID { + return false, errors.New("user does not belong to account") + } + + switch user.Role { + case types.UserRoleAdmin, types.UserRoleOwner: + return true, nil + case types.UserRoleUser: + return m.validateRegularUserPermissions(ctx, accountID, userID, module, operation) + case types.UserRoleBillingAdmin: + return false, nil + default: + return false, errors.New("invalid role") + } +} + +func (m *managerImpl) validateRegularUserPermissions(ctx context.Context, accountID, userID string, module Module, operation Operation) (bool, error) { + settings, err := m.settingsManager.GetSettings(ctx, accountID, userID) + if err != nil { + return false, fmt.Errorf("failed to get settings: %w", err) + } + if settings.RegularUsersViewBlocked { + return false, nil + } + + if operation == Write { + return false, nil + } + + if module == Peers { + return true, nil + } + + return false, nil +} + +func NewManagerMock() Manager { + return &managerMock{} +} + +func (m *managerMock) ValidateUserPermissions(ctx context.Context, accountID, userID string, module Module, operation Operation) (bool, error) { + if userID == "allowedUser" { + return true, nil + } + return false, nil +} diff --git a/management/server/policy.go b/management/server/policy.go index 6dcb96316..45b3e93e6 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -3,344 +3,21 @@ package server import ( "context" _ "embed" - "strconv" - "strings" + + "github.com/rs/xid" "github.com/netbirdio/netbird/management/proto" - "github.com/rs/xid" - log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" - nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" ) -// PolicyUpdateOperationType operation type -type PolicyUpdateOperationType int - -// PolicyTrafficActionType action type for the firewall -type PolicyTrafficActionType string - -// PolicyRuleProtocolType type of traffic -type PolicyRuleProtocolType string - -// PolicyRuleDirection direction of traffic -type PolicyRuleDirection string - -const ( - // PolicyTrafficActionAccept indicates that the traffic is accepted - PolicyTrafficActionAccept = PolicyTrafficActionType("accept") - // PolicyTrafficActionDrop indicates that the traffic is dropped - PolicyTrafficActionDrop = PolicyTrafficActionType("drop") -) - -const ( - // PolicyRuleProtocolALL type of traffic - PolicyRuleProtocolALL = PolicyRuleProtocolType("all") - // 
PolicyRuleProtocolTCP type of traffic - PolicyRuleProtocolTCP = PolicyRuleProtocolType("tcp") - // PolicyRuleProtocolUDP type of traffic - PolicyRuleProtocolUDP = PolicyRuleProtocolType("udp") - // PolicyRuleProtocolICMP type of traffic - PolicyRuleProtocolICMP = PolicyRuleProtocolType("icmp") -) - -const ( - // PolicyRuleFlowDirect allows traffic from source to destination - PolicyRuleFlowDirect = PolicyRuleDirection("direct") - // PolicyRuleFlowBidirect allows traffic to both directions - PolicyRuleFlowBidirect = PolicyRuleDirection("bidirect") -) - -const ( - // DefaultRuleName is a name for the Default rule that is created for every account - DefaultRuleName = "Default" - // DefaultRuleDescription is a description for the Default rule that is created for every account - DefaultRuleDescription = "This is a default rule that allows connections between all the resources" - // DefaultPolicyName is a name for the Default policy that is created for every account - DefaultPolicyName = "Default" - // DefaultPolicyDescription is a description for the Default policy that is created for every account - DefaultPolicyDescription = "This is a default policy that allows connections between all the resources" -) - -const ( - firewallRuleDirectionIN = 0 - firewallRuleDirectionOUT = 1 -) - -// PolicyUpdateOperation operation object with type and values to be applied -type PolicyUpdateOperation struct { - Type PolicyUpdateOperationType - Values []string -} - -// RulePortRange represents a range of ports for a firewall rule. -type RulePortRange struct { - Start uint16 - End uint16 -} - -// PolicyRule is the metadata of the policy -type PolicyRule struct { - // ID of the policy rule - ID string `gorm:"primaryKey"` - - // PolicyID is a reference to Policy that this object belongs - PolicyID string `json:"-" gorm:"index"` - - // Name of the rule visible in the UI - Name string - - // Description of the rule visible in the UI - Description string - - // Enabled status of rule in the system - Enabled bool - - // Action policy accept or drops packets - Action PolicyTrafficActionType - - // Destinations policy destination groups - Destinations []string `gorm:"serializer:json"` - - // Sources policy source groups - Sources []string `gorm:"serializer:json"` - - // Bidirectional define if the rule is applicable in both directions, sources, and destinations - Bidirectional bool - - // Protocol type of the traffic - Protocol PolicyRuleProtocolType - - // Ports or it ranges list - Ports []string `gorm:"serializer:json"` - - // PortRanges a list of port ranges. 
- PortRanges []RulePortRange `gorm:"serializer:json"` -} - -// Copy returns a copy of a policy rule -func (pm *PolicyRule) Copy() *PolicyRule { - rule := &PolicyRule{ - ID: pm.ID, - PolicyID: pm.PolicyID, - Name: pm.Name, - Description: pm.Description, - Enabled: pm.Enabled, - Action: pm.Action, - Destinations: make([]string, len(pm.Destinations)), - Sources: make([]string, len(pm.Sources)), - Bidirectional: pm.Bidirectional, - Protocol: pm.Protocol, - Ports: make([]string, len(pm.Ports)), - PortRanges: make([]RulePortRange, len(pm.PortRanges)), - } - copy(rule.Destinations, pm.Destinations) - copy(rule.Sources, pm.Sources) - copy(rule.Ports, pm.Ports) - copy(rule.PortRanges, pm.PortRanges) - return rule -} - -// Policy of the Rego query -type Policy struct { - // ID of the policy' - ID string `gorm:"primaryKey"` - - // AccountID is a reference to Account that this object belongs - AccountID string `json:"-" gorm:"index"` - - // Name of the Policy - Name string - - // Description of the policy visible in the UI - Description string - - // Enabled status of the policy - Enabled bool - - // Rules of the policy - Rules []*PolicyRule `gorm:"foreignKey:PolicyID;references:id;constraint:OnDelete:CASCADE;"` - - // SourcePostureChecks are ID references to Posture checks for policy source groups - SourcePostureChecks []string `gorm:"serializer:json"` -} - -// Copy returns a copy of the policy. -func (p *Policy) Copy() *Policy { - c := &Policy{ - ID: p.ID, - AccountID: p.AccountID, - Name: p.Name, - Description: p.Description, - Enabled: p.Enabled, - Rules: make([]*PolicyRule, len(p.Rules)), - SourcePostureChecks: make([]string, len(p.SourcePostureChecks)), - } - for i, r := range p.Rules { - c.Rules[i] = r.Copy() - } - copy(c.SourcePostureChecks, p.SourcePostureChecks) - return c -} - -// EventMeta returns activity event meta related to this policy -func (p *Policy) EventMeta() map[string]any { - return map[string]any{"name": p.Name} -} - -// UpgradeAndFix different version of policies to latest version -func (p *Policy) UpgradeAndFix() { - for _, r := range p.Rules { - // start migrate from version v0.20.3 - if r.Protocol == "" { - r.Protocol = PolicyRuleProtocolALL - } - if r.Protocol == PolicyRuleProtocolALL && !r.Bidirectional { - r.Bidirectional = true - } - // -- v0.20.4 - } -} - -// ruleGroups returns a list of all groups referenced in the policy's rules, -// including sources and destinations. -func (p *Policy) ruleGroups() []string { - groups := make([]string, 0) - for _, rule := range p.Rules { - groups = append(groups, rule.Sources...) - groups = append(groups, rule.Destinations...) - } - - return groups -} - -// FirewallRule is a rule of the firewall. -type FirewallRule struct { - // PeerIP of the peer - PeerIP string - - // Direction of the traffic - Direction int - - // Action of the traffic - Action string - - // Protocol of the traffic - Protocol string - - // Port of the traffic - Port string -} - -// getPeerConnectionResources for a given peer -// -// This function returns the list of peers and firewall rules that are applicable to a given peer. 
-func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) { - generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx) - for _, policy := range a.Policies { - if !policy.Enabled { - continue - } - - for _, rule := range policy.Rules { - if !rule.Enabled { - continue - } - - sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) - destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap) - - if rule.Bidirectional { - if peerInSources { - generateResources(rule, destinationPeers, firewallRuleDirectionIN) - } - if peerInDestinations { - generateResources(rule, sourcePeers, firewallRuleDirectionOUT) - } - } - - if peerInSources { - generateResources(rule, destinationPeers, firewallRuleDirectionOUT) - } - - if peerInDestinations { - generateResources(rule, sourcePeers, firewallRuleDirectionIN) - } - } - } - - return getAccumulatedResources() -} - -// connResourcesGenerator returns generator and accumulator function which returns the result of generator calls -// -// The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer. -// It safe to call the generator function multiple times for same peer and different rules no duplicates will be -// generated. The accumulator function returns the result of all the generator calls. -func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) { - rulesExists := make(map[string]struct{}) - peersExists := make(map[string]struct{}) - rules := make([]*FirewallRule, 0) - peers := make([]*nbpeer.Peer, 0) - - all, err := a.GetGroupAll() - if err != nil { - log.WithContext(ctx).Errorf("failed to get group all: %v", err) - all = &nbgroup.Group{} - } - - return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) { - isAll := (len(all.Peers) - 1) == len(groupPeers) - for _, peer := range groupPeers { - if peer == nil { - continue - } - - if _, ok := peersExists[peer.ID]; !ok { - peers = append(peers, peer) - peersExists[peer.ID] = struct{}{} - } - - fr := FirewallRule{ - PeerIP: peer.IP.String(), - Direction: direction, - Action: string(rule.Action), - Protocol: string(rule.Protocol), - } - - if isAll { - fr.PeerIP = "0.0.0.0" - } - - ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) + - fr.Protocol + fr.Action + strings.Join(rule.Ports, ",") - if _, ok := rulesExists[ruleID]; ok { - continue - } - rulesExists[ruleID] = struct{}{} - - if len(rule.Ports) == 0 { - rules = append(rules, &fr) - continue - } - - for _, port := range rule.Ports { - pr := fr // clone rule and add set new port - pr.Port = port - rules = append(rules, &pr) - } - } - }, func() ([]*nbpeer.Peer, []*FirewallRule) { - return peers, rules - } -} - // GetPolicy from the store -func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) +func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -353,15 +30,15 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, 
accountID, polic return nil, status.NewAdminPermissionError() } - return am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID) + return am.Store.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policyID) } // SavePolicy in the store -func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error) { +func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -378,7 +55,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user var updateAccountPeers bool var action = activity.PolicyAdded - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validatePolicy(ctx, transaction, accountID, policy); err != nil { return err } @@ -388,7 +65,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } @@ -398,7 +75,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user saveFunc = transaction.SavePolicy } - return saveFunc(ctx, LockingStrengthUpdate, policy) + return saveFunc(ctx, store.LockingStrengthUpdate, policy) }) if err != nil { return nil, err @@ -407,7 +84,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return policy, nil @@ -418,7 +95,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -431,11 +108,11 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po return status.NewAdminPermissionError() } - var policy *Policy + var policy *types.Policy var updateAccountPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - policy, err = transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + policy, err = transaction.GetPolicyByID(ctx, store.LockingStrengthUpdate, accountID, policyID) if err != nil { return err } @@ -445,11 +122,11 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.DeletePolicy(ctx, LockingStrengthUpdate, accountID, policyID) + return transaction.DeletePolicy(ctx, store.LockingStrengthUpdate, accountID, 
policyID) }) if err != nil { return err @@ -458,15 +135,15 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta()) if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil } // ListPolicies from the store. -func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) +func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error) { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -479,13 +156,13 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us return nil, status.NewAdminPermissionError() } - return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID) + return am.Store.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) } // arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers. -func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, accountID string, policy *Policy, isUpdate bool) (bool, error) { +func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) { if isUpdate { - existingPolicy, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID) + existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policy.ID) if err != nil { return false, err } @@ -494,7 +171,7 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, account return false, nil } - hasPeers, err := anyGroupHasPeers(ctx, transaction, policy.AccountID, existingPolicy.ruleGroups()) + hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups()) if err != nil { return false, err } @@ -502,17 +179,15 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, account if hasPeers { return true, nil } - - return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.ruleGroups()) } - return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.ruleGroups()) + return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups()) } // validatePolicy validates the policy and its rules. 
-func validatePolicy(ctx context.Context, transaction Store, accountID string, policy *Policy) error { +func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error { if policy.ID != "" { - _, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID) + _, err := transaction.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policy.ID) if err != nil { return err } @@ -521,12 +196,12 @@ func validatePolicy(ctx context.Context, transaction Store, accountID string, po policy.AccountID = accountID } - groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, policy.ruleGroups()) + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, policy.RuleGroups()) if err != nil { return err } - postureChecks, err := transaction.GetPostureChecksByIDs(ctx, LockingStrengthShare, accountID, policy.SourcePostureChecks) + postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthShare, accountID, policy.SourcePostureChecks) if err != nil { return err } @@ -534,7 +209,7 @@ func validatePolicy(ctx context.Context, transaction Store, accountID string, po for i, rule := range policy.Rules { ruleCopy := rule.Copy() if ruleCopy.ID == "" { - ruleCopy.ID = xid.New().String() + ruleCopy.ID = policy.ID // TODO: when policy can contain multiple rules, need refactor ruleCopy.PolicyID = policy.ID } @@ -550,84 +225,6 @@ func validatePolicy(ctx context.Context, transaction Store, accountID string, po return nil } -// getAllPeersFromGroups for given peer ID and list of groups -// -// Returns a list of peers from specified groups that pass specified posture checks -// and a boolean indicating if the supplied peer ID exists within these groups. 
-// -// Important: Posture checks are applicable only to source group peers, -// for destination group peers, call this method with an empty list of sourcePostureChecksIDs -func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) { - peerInGroups := false - filteredPeers := make([]*nbpeer.Peer, 0, len(groups)) - for _, g := range groups { - group, ok := a.Groups[g] - if !ok { - continue - } - - for _, p := range group.Peers { - peer, ok := a.Peers[p] - if !ok || peer == nil { - continue - } - - // validate the peer based on policy posture checks applied - isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID) - if !isValid { - continue - } - - if _, ok := validatedPeersMap[peer.ID]; !ok { - continue - } - - if peer.ID == peerID { - peerInGroups = true - continue - } - - filteredPeers = append(filteredPeers, peer) - } - } - return filteredPeers, peerInGroups -} - -// validatePostureChecksOnPeer validates the posture checks on a peer -func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePostureChecksID []string, peerID string) bool { - peer, ok := a.Peers[peerID] - if !ok && peer == nil { - return false - } - - for _, postureChecksID := range sourcePostureChecksID { - postureChecks := a.getPostureChecks(postureChecksID) - if postureChecks == nil { - continue - } - - for _, check := range postureChecks.GetChecks() { - isValid, err := check.Check(ctx, *peer) - if err != nil { - log.WithContext(ctx).Debugf("an error occurred check %s: on peer: %s :%s", check.Name(), peer.ID, err.Error()) - } - if !isValid { - return false - } - } - } - return true -} - -func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks { - for _, postureChecks := range a.PostureChecks { - if postureChecks.ID == postureChecksID { - return postureChecks - } - } - return nil -} - // getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list. func getValidPostureCheckIDs(postureChecks map[string]*posture.Checks, postureChecksIds []string) []string { validIDs := make([]string, 0, len(postureChecksIds)) @@ -641,7 +238,7 @@ func getValidPostureCheckIDs(postureChecks map[string]*posture.Checks, postureCh } // getValidGroupIDs filters and returns only the valid group IDs from the provided list. -func getValidGroupIDs(groups map[string]*nbgroup.Group, groupIDs []string) []string { +func getValidGroupIDs(groups map[string]*types.Group, groupIDs []string) []string { validIDs := make([]string, 0, len(groupIDs)) for _, id := range groupIDs { if _, exists := groups[id]; exists { @@ -653,7 +250,7 @@ func getValidGroupIDs(groups map[string]*nbgroup.Group, groupIDs []string) []str } // toProtocolFirewallRules converts the firewall rules to the protocol firewall rules. 
-func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { +func toProtocolFirewallRules(rules []*types.FirewallRule) []*proto.FirewallRule { result := make([]*proto.FirewallRule, len(rules)) for i := range rules { rule := rules[i] diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 62d80f46e..73fc6edba 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -10,13 +10,13 @@ import ( "github.com/stretchr/testify/assert" "golang.org/x/exp/slices" - nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" ) func TestAccount_getPeersByPolicy(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: map[string]*nbpeer.Peer{ "peerA": { ID: "peerA", @@ -59,7 +59,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) { Status: &nbpeer.PeerStatus{}, }, }, - Groups: map[string]*nbgroup.Group{ + Groups: map[string]*types.Group{ "GroupAll": { ID: "GroupAll", Name: "All", @@ -74,6 +74,19 @@ func TestAccount_getPeersByPolicy(t *testing.T) { "peerH", }, }, + "GroupWorkstations": { + ID: "GroupWorkstations", + Name: "GroupWorkstations", + Peers: []string{ + "peerB", + "peerA", + "peerD", + "peerE", + "peerF", + "peerG", + "peerH", + }, + }, "GroupSwarm": { ID: "GroupSwarm", Name: "swarm", @@ -87,21 +100,21 @@ func TestAccount_getPeersByPolicy(t *testing.T) { }, }, }, - Policies: []*Policy{ + Policies: []*types.Policy{ { ID: "RuleDefault", Name: "Default", Description: "This is a default rule that allows connections between all the resources", Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "RuleDefault", Name: "Default", Description: "This is a default rule that allows connections between all the resources", Bidirectional: true, Enabled: true, - Protocol: PolicyRuleProtocolALL, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, Sources: []string{ "GroupAll", }, @@ -116,18 +129,18 @@ func TestAccount_getPeersByPolicy(t *testing.T) { Name: "Swarm", Description: "No description", Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "RuleSwarm", Name: "Swarm", Description: "No description", Bidirectional: true, Enabled: true, - Protocol: PolicyRuleProtocolALL, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, Sources: []string{ "GroupSwarm", - "GroupAll", + "GroupWorkstations", }, Destinations: []string{ "GroupSwarm", @@ -145,75 +158,62 @@ func TestAccount_getPeersByPolicy(t *testing.T) { t.Run("check that all peers get map", func(t *testing.T) { for _, p := range account.Peers { - peers, firewallRules := account.getPeerConnectionResources(context.Background(), p.ID, validatedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p.ID, validatedPeers) assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present") assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present") } }) t.Run("check first peer map details", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", validatedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", validatedPeers) 
assert.Len(t, peers, 7) assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerD"]) assert.Contains(t, peers, account.Peers["peerE"]) assert.Contains(t, peers, account.Peers["peerF"]) + assert.Contains(t, peers, account.Peers["peerG"]) + assert.Contains(t, peers, account.Peers["peerH"]) - epectedFirewallRules := []*FirewallRule{ + epectedFirewallRules := []*types.FirewallRule{ { PeerIP: "0.0.0.0", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "0.0.0.0", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.14.88", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.14.88", - Direction: firewallRuleDirectionOUT, - Action: "accept", - Protocol: "all", - Port: "", - }, - { - PeerIP: "100.65.254.139", - Direction: firewallRuleDirectionOUT, - Action: "accept", - Protocol: "all", - Port: "", - }, - { - PeerIP: "100.65.254.139", - Direction: firewallRuleDirectionIN, - Action: "accept", - Protocol: "all", - Port: "", - }, - - { - PeerIP: "100.65.62.5", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.62.5", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "all", + Port: "", + }, + { + PeerIP: "100.65.62.5", + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", @@ -221,14 +221,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) { { PeerIP: "100.65.32.206", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.32.206", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", @@ -236,14 +236,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) { { PeerIP: "100.65.250.202", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.250.202", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", @@ -251,14 +251,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) { { PeerIP: "100.65.13.186", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.13.186", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", @@ -266,30 +266,36 @@ func TestAccount_getPeersByPolicy(t *testing.T) { { PeerIP: "100.65.29.55", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.29.55", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", }, } assert.Len(t, firewallRules, len(epectedFirewallRules)) - slices.SortFunc(epectedFirewallRules, sortFunc()) - slices.SortFunc(firewallRules, sortFunc()) - for i := range firewallRules { - assert.Equal(t, epectedFirewallRules[i], firewallRules[i]) + + for _, 
rule := range firewallRules { + contains := false + for _, expectedRule := range epectedFirewallRules { + if rule.IsEqual(expectedRule) { + contains = true + break + } + } + assert.True(t, contains, "rule not found in expected rules %#v", rule) } }) } func TestAccount_getPeersByPolicyDirect(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: map[string]*nbpeer.Peer{ "peerA": { ID: "peerA", @@ -307,7 +313,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { Status: &nbpeer.PeerStatus{}, }, }, - Groups: map[string]*nbgroup.Group{ + Groups: map[string]*types.Group{ "GroupAll": { ID: "GroupAll", Name: "All", @@ -332,21 +338,21 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }, }, }, - Policies: []*Policy{ + Policies: []*types.Policy{ { ID: "RuleDefault", Name: "Default", Description: "This is a default rule that allows connections between all the resources", Enabled: false, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "RuleDefault", Name: "Default", Description: "This is a default rule that allows connections between all the resources", Bidirectional: true, Enabled: false, - Protocol: PolicyRuleProtocolALL, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, Sources: []string{ "GroupAll", }, @@ -361,15 +367,15 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { Name: "Swarm", Description: "No description", Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "RuleSwarm", Name: "Swarm", Description: "No description", Bidirectional: true, Enabled: true, - Protocol: PolicyRuleProtocolALL, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, Sources: []string{ "GroupSwarm", }, @@ -388,20 +394,20 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { } t.Run("check first peer map", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) assert.Contains(t, peers, account.Peers["peerC"]) - epectedFirewallRules := []*FirewallRule{ + epectedFirewallRules := []*types.FirewallRule{ { PeerIP: "100.65.254.139", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.254.139", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", @@ -416,20 +422,20 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }) t.Run("check second peer map", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) assert.Contains(t, peers, account.Peers["peerB"]) - epectedFirewallRules := []*FirewallRule{ + epectedFirewallRules := []*types.FirewallRule{ { PeerIP: "100.65.80.39", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.80.39", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", @@ -446,13 +452,13 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { account.Policies[1].Rules[0].Bidirectional = false 
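// A self-contained sketch of the order-insensitive comparison the updated test uses above:
// instead of sorting both slices and asserting element-wise equality, every produced rule must
// match some expected rule via IsEqual, combined with the length assertion on both slices.
// rule is a trimmed stand-in for types.FirewallRule.
package main

import "fmt"

type rule struct {
	PeerIP    string
	Direction int
	Action    string
	Protocol  string
	Port      string
}

func (r rule) IsEqual(other rule) bool { return r == other }

func containsAll(produced, expected []rule) bool {
	for _, p := range produced {
		found := false
		for _, e := range expected {
			if p.IsEqual(e) {
				found = true
				break
			}
		}
		if !found {
			return false
		}
	}
	return true
}

func main() {
	expected := []rule{{PeerIP: "0.0.0.0", Action: "accept", Protocol: "all"}}
	produced := []rule{{PeerIP: "0.0.0.0", Action: "accept", Protocol: "all"}}
	fmt.Println(containsAll(produced, expected)) // true
}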
t.Run("check first peer map directional only", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) assert.Contains(t, peers, account.Peers["peerC"]) - epectedFirewallRules := []*FirewallRule{ + epectedFirewallRules := []*types.FirewallRule{ { PeerIP: "100.65.254.139", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", @@ -467,13 +473,13 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }) t.Run("check second peer map directional only", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) assert.Contains(t, peers, account.Peers["peerB"]) - epectedFirewallRules := []*FirewallRule{ + epectedFirewallRules := []*types.FirewallRule{ { PeerIP: "100.65.80.39", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", @@ -489,7 +495,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { } func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: map[string]*nbpeer.Peer{ "peerA": { ID: "peerA", @@ -582,7 +588,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }, }, }, - Groups: map[string]*nbgroup.Group{ + Groups: map[string]*types.Group{ "GroupAll": { ID: "GroupAll", Name: "All", @@ -630,17 +636,17 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }, } - account.Policies = append(account.Policies, &Policy{ + account.Policies = append(account.Policies, &types.Policy{ ID: "PolicyPostureChecks", Name: "", Description: "This is the policy with posture checks applied", Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "RuleSwarm", Name: "Swarm", Enabled: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, Destinations: []string{ "GroupSwarm", }, @@ -648,7 +654,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { "GroupAll", }, Bidirectional: false, - Protocol: PolicyRuleProtocolTCP, + Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"80"}, }, }, @@ -664,7 +670,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { t.Run("verify peer's network map with default group peer list", func(t *testing.T) { // peerB doesn't fulfill the NB posture check but is included in the destination group Swarm, // will establish a connection with all source peers satisfying the NB posture check. 
- peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -674,13 +680,13 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerC satisfy the NB posture check, should establish connection to all destination group peer's // We expect a single permissive firewall rule which all outgoing connections - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, 1) - expectedFirewallRules := []*FirewallRule{ + expectedFirewallRules := []*types.FirewallRule{ { PeerIP: "0.0.0.0", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "tcp", Port: "80", @@ -690,7 +696,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -700,7 +706,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -715,19 +721,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's // no connection should be established to any peer of destination group - peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) assert.Len(t, peers, 0) assert.Len(t, firewallRules, 0) // peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's // no connection should be established to any peer of destination group - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers) assert.Len(t, peers, 0) assert.Len(t, firewallRules, 0) // peerC satisfy the NB posture check, should establish connection to all destination group peer's // We expect a single permissive firewall rule which all outgoing connections - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers) + peers, 
firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers)) @@ -742,14 +748,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers) assert.Len(t, peers, 3) assert.Len(t, firewallRules, 3) assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerD"]) - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerA", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerA", approvedPeers) assert.Len(t, peers, 5) // assert peers from Group Swarm assert.Contains(t, peers, account.Peers["peerD"]) @@ -760,45 +766,45 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // assert peers from Group All assert.Contains(t, peers, account.Peers["peerC"]) - expectedFirewallRules := []*FirewallRule{ + expectedFirewallRules := []*types.FirewallRule{ { PeerIP: "100.65.62.5", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "tcp", Port: "80", }, { PeerIP: "100.65.32.206", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "tcp", Port: "80", }, { PeerIP: "100.65.13.186", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "tcp", Port: "80", }, { PeerIP: "100.65.29.55", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "tcp", Port: "80", }, { PeerIP: "100.65.254.139", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "tcp", Port: "80", }, { PeerIP: "100.65.62.5", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "tcp", Port: "80", @@ -809,8 +815,8 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }) } -func sortFunc() func(a *FirewallRule, b *FirewallRule) int { - return func(a, b *FirewallRule) int { +func sortFunc() func(a *types.FirewallRule, b *types.FirewallRule) int { + return func(a, b *types.FirewallRule) int { // Concatenate PeerIP and Direction as string for comparison aStr := a.PeerIP + fmt.Sprintf("%d", a.Direction) bStr := b.PeerIP + fmt.Sprintf("%d", b.Direction) @@ -829,7 +835,7 @@ func sortFunc() func(a *FirewallRule, b *FirewallRule) int { func TestPolicyAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -858,9 +864,9 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) }) - var policyWithGroupRulesNoPeers *Policy - var policyWithDestinationPeersOnly *Policy - var 
policyWithSourceAndDestinationPeers *Policy + var policyWithGroupRulesNoPeers *types.Policy + var policyWithDestinationPeersOnly *types.Policy + var policyWithSourceAndDestinationPeers *types.Policy // Saving policy with rule groups with no peers should not update account's peers and not send peer update t.Run("saving policy with rule groups with no peers", func(t *testing.T) { @@ -870,16 +876,16 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { close(done) }() - policyWithGroupRulesNoPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + policyWithGroupRulesNoPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ AccountID: account.Id, Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupB"}, Destinations: []string{"groupC"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) @@ -901,17 +907,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { close(done) }() - _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ AccountID: account.Id, Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupB"}, - Protocol: PolicyRuleProtocolTCP, + Protocol: types.PolicyRuleProtocolTCP, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) @@ -933,17 +939,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { close(done) }() - policyWithDestinationPeersOnly, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + policyWithDestinationPeersOnly, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ AccountID: account.Id, Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupC"}, Destinations: []string{"groupD"}, Bidirectional: true, - Protocol: PolicyRuleProtocolTCP, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolTCP, + Action: types.PolicyTrafficActionAccept, }, }, }) @@ -965,16 +971,16 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { close(done) }() - policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ AccountID: account.Id, Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupD"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index d7b5a79a2..1690f8e33 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -2,19 +2,22 @@ package server import ( "context" + "errors" "fmt" "slices" + "github.com/rs/xid" + "golang.org/x/exp/maps" + "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" - "github.com/rs/xid" - log "github.com/sirupsen/logrus" - "golang.org/x/exp/maps" + "github.com/netbirdio/netbird/management/server/store" + 
"github.com/netbirdio/netbird/management/server/types" ) func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -27,12 +30,15 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID return nil, status.NewAdminPermissionError() } - return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID) + return am.Store.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecksID) } // SavePostureChecks saves a posture check. func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -49,7 +55,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI var isUpdate = postureChecks.ID != "" var action = activity.PostureCheckCreated - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validatePostureChecks(ctx, transaction, accountID, postureChecks); err != nil { return err } @@ -60,7 +66,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } @@ -68,7 +74,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI } postureChecks.AccountID = accountID - return transaction.SavePostureChecks(ctx, LockingStrengthUpdate, postureChecks) + return transaction.SavePostureChecks(ctx, store.LockingStrengthUpdate, postureChecks) }) if err != nil { return nil, err @@ -77,7 +83,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return postureChecks, nil @@ -85,7 +91,10 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI // DeletePostureChecks deletes a posture check by ID. 
func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -100,8 +109,8 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun var postureChecks *posture.Checks - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - postureChecks, err = transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + postureChecks, err = transaction.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecksID) if err != nil { return err } @@ -110,11 +119,11 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, accountID, postureChecksID) + return transaction.DeletePostureChecks(ctx, store.LockingStrengthUpdate, accountID, postureChecksID) }) if err != nil { return err @@ -127,7 +136,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun // ListPostureChecks returns a list of posture checks. func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -140,57 +149,40 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI return nil, status.NewAdminPermissionError() } - return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) + return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthShare, accountID) } // getPeerPostureChecks returns the posture checks applied for a given peer. 
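// A self-contained sketch of the per-account write serialization that SavePostureChecks and
// DeletePostureChecks now apply before mutating data (AcquireWriteLockByUID returning an unlock
// func that is deferred). The accountLocker below is an illustrative stand-in, not the store's
// actual locking implementation.
package main

import (
	"fmt"
	"sync"
)

type accountLocker struct {
	mu    sync.Mutex
	locks map[string]*sync.Mutex
}

func newAccountLocker() *accountLocker {
	return &accountLocker{locks: map[string]*sync.Mutex{}}
}

// acquire returns an unlock func, mirroring `unlock := Store.AcquireWriteLockByUID(ctx, accountID); defer unlock()`.
func (l *accountLocker) acquire(accountID string) func() {
	l.mu.Lock()
	m, ok := l.locks[accountID]
	if !ok {
		m = &sync.Mutex{}
		l.locks[accountID] = m
	}
	l.mu.Unlock()
	m.Lock()
	return m.Unlock
}

func main() {
	locker := newAccountLocker()
	unlock := locker.acquire("account-1")
	defer unlock()
	// validation, the store transaction, and the network-serial bump would run here
	fmt.Println("writes for account-1 are serialized")
}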
-func (am *DefaultAccountManager) getPeerPostureChecks(ctx context.Context, accountID string, peerID string) ([]*posture.Checks, error) { +func (am *DefaultAccountManager) getPeerPostureChecks(account *types.Account, peerID string) ([]*posture.Checks, error) { peerPostureChecks := make(map[string]*posture.Checks) - err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - postureChecks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) - if err != nil { - return err + if len(account.PostureChecks) == 0 { + return nil, nil + } + + for _, policy := range account.Policies { + if !policy.Enabled || len(policy.SourcePostureChecks) == 0 { + continue } - if len(postureChecks) == 0 { - return nil + if err := addPolicyPostureChecks(account, peerID, policy, peerPostureChecks); err != nil { + return nil, err } - - policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) - if err != nil { - return err - } - - for _, policy := range policies { - if !policy.Enabled { - continue - } - - if err = addPolicyPostureChecks(ctx, transaction, accountID, peerID, policy, peerPostureChecks); err != nil { - return err - } - } - - return nil - }) - if err != nil { - return nil, err } return maps.Values(peerPostureChecks), nil } // arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers. -func arePostureCheckChangesAffectPeers(ctx context.Context, transaction Store, accountID, postureCheckID string) (bool, error) { - policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) +func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) { + policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) if err != nil { return false, err } for _, policy := range policies { if slices.Contains(policy.SourcePostureChecks, postureCheckID) { - hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, policy.ruleGroups()) + hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, accountID, policy.RuleGroups()) if err != nil { return false, err } @@ -205,21 +197,21 @@ func arePostureCheckChangesAffectPeers(ctx context.Context, transaction Store, a } // validatePostureChecks validates the posture checks. -func validatePostureChecks(ctx context.Context, transaction Store, accountID string, postureChecks *posture.Checks) error { +func validatePostureChecks(ctx context.Context, transaction store.Store, accountID string, postureChecks *posture.Checks) error { if err := postureChecks.Validate(); err != nil { return status.Errorf(status.InvalidArgument, err.Error()) //nolint } // If the posture check already has an ID, verify its existence in the store. if postureChecks.ID != "" { - if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecks.ID); err != nil { + if _, err := transaction.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecks.ID); err != nil { return err } return nil } // For new posture checks, ensure no duplicates by name. 
- checks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) + checks, err := transaction.GetAccountPostureChecks(ctx, store.LockingStrengthShare, accountID) if err != nil { return err } @@ -236,8 +228,8 @@ func validatePostureChecks(ctx context.Context, transaction Store, accountID str } // addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups. -func addPolicyPostureChecks(ctx context.Context, transaction Store, accountID, peerID string, policy *Policy, peerPostureChecks map[string]*posture.Checks) error { - isInGroup, err := isPeerInPolicySourceGroups(ctx, transaction, accountID, peerID, policy) +func addPolicyPostureChecks(account *types.Account, peerID string, policy *types.Policy, peerPostureChecks map[string]*posture.Checks) error { + isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy) if err != nil { return err } @@ -247,9 +239,9 @@ func addPolicyPostureChecks(ctx context.Context, transaction Store, accountID, p } for _, sourcePostureCheckID := range policy.SourcePostureChecks { - postureCheck, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, sourcePostureCheckID) - if err != nil { - return err + postureCheck := account.GetPostureChecks(sourcePostureCheckID) + if postureCheck == nil { + return errors.New("failed to add policy posture checks: posture checks not found") } peerPostureChecks[sourcePostureCheckID] = postureCheck } @@ -258,17 +250,16 @@ func addPolicyPostureChecks(ctx context.Context, transaction Store, accountID, p } // isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups. -func isPeerInPolicySourceGroups(ctx context.Context, transaction Store, accountID, peerID string, policy *Policy) (bool, error) { +func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *types.Policy) (bool, error) { for _, rule := range policy.Rules { if !rule.Enabled { continue } for _, sourceGroup := range rule.Sources { - group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, sourceGroup) - if err != nil { - log.WithContext(ctx).Debugf("failed to check peer in policy source group: %v", err) - return false, fmt.Errorf("failed to check peer in policy source group: %w", err) + group := account.GetGroup(sourceGroup) + if group == nil { + return false, fmt.Errorf("failed to check peer in policy source group: group not found") } if slices.Contains(group.Peers, peerID) { @@ -281,8 +272,8 @@ func isPeerInPolicySourceGroups(ctx context.Context, transaction Store, accountI } // isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy. 
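// A self-contained sketch of the in-memory membership check that replaces the per-group store
// lookup in isPeerInPolicySourceGroups: group membership is read straight from the loaded
// account, and a missing group is an error rather than a store round-trip. group/account are
// trimmed stand-ins for the types package.
package main

import (
	"fmt"
	"slices"
)

type group struct{ Peers []string }

type account struct{ Groups map[string]*group }

func (a *account) GetGroup(id string) *group { return a.Groups[id] }

func peerInSourceGroups(a *account, peerID string, sources []string) (bool, error) {
	for _, src := range sources {
		g := a.GetGroup(src)
		if g == nil {
			return false, fmt.Errorf("failed to check peer in policy source group: group not found")
		}
		if slices.Contains(g.Peers, peerID) {
			return true, nil
		}
	}
	return false, nil
}

func main() {
	a := &account{Groups: map[string]*group{"dev": {Peers: []string{"peerA"}}}}
	ok, err := peerInSourceGroups(a, "peerA", []string{"dev"})
	fmt.Println(ok, err) // true <nil>
}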
-func isPostureCheckLinkedToPolicy(ctx context.Context, transaction Store, postureChecksID, accountID string) error { - policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) +func isPostureCheckLinkedToPolicy(ctx context.Context, transaction store.Store, postureChecksID, accountID string) error { + policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) if err != nil { return err } diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index 93e5741cf..bad162f05 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -8,7 +8,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/posture" ) @@ -92,17 +93,17 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { }) } -func initTestPostureChecksAccount(am *DefaultAccountManager) (*Account, error) { +func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, error) { accountID := "testingAccount" domain := "example.com" - admin := &User{ + admin := &types.User{ Id: adminUserID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, } - user := &User{ + user := &types.User{ Id: regularUserID, - Role: UserRoleUser, + Role: types.UserRoleUser, } account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain) @@ -120,7 +121,7 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*Account, error) { func TestPostureCheckAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroups(context.Background(), account.Id, userID, []*group.Group{ + err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -209,15 +210,15 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { } }) - policy := &Policy{ + policy := &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, SourcePostureChecks: []string{postureCheckB.ID}, @@ -312,15 +313,15 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { // Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update t.Run("updating linked posture check to policy with no peers", func(t *testing.T) { - _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupB"}, Destinations: []string{"groupC"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, SourcePostureChecks: []string{postureCheckB.ID}, @@ -356,15 +357,15 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID) }) - _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err = manager.SavePolicy(context.Background(), account.Id, userID, 
&types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupB"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, SourcePostureChecks: []string{postureCheckB.ID}, @@ -395,15 +396,15 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { // Updating linked client posture check to policy where source has peers but destination does not, // should trigger account peers update and send peer update t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) { - _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupB"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, SourcePostureChecks: []string{postureCheckB.ID}, @@ -443,18 +444,18 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { account, err := initTestPostureChecksAccount(manager) require.NoError(t, err, "failed to init testing account") - groupA := &group.Group{ + groupA := &types.Group{ ID: "groupA", AccountID: account.Id, Peers: []string{"peer1"}, } - groupB := &group.Group{ + groupB := &types.Group{ ID: "groupB", AccountID: account.Id, Peers: []string{}, } - err = manager.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{groupA, groupB}) + err = manager.Store.SaveGroups(context.Background(), store.LockingStrengthUpdate, []*types.Group{groupA, groupB}) require.NoError(t, err, "failed to save groups") postureCheckA := &posture.Checks{ @@ -477,9 +478,9 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckB) require.NoError(t, err, "failed to save postureCheckB") - policy := &Policy{ + policy := &types.Policy{ AccountID: account.Id, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, @@ -534,7 +535,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) { groupA.Peers = []string{} - err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, groupA) + err = manager.Store.SaveGroup(context.Background(), store.LockingStrengthUpdate, groupA) require.NoError(t, err, "failed to save groups") result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) diff --git a/management/server/resource.go b/management/server/resource.go new file mode 100644 index 000000000..77a5612b3 --- /dev/null +++ b/management/server/resource.go @@ -0,0 +1,21 @@ +package server + +type ResourceType string + +const ( + // nolint + hostType ResourceType = "Host" + //nolint + subnetType ResourceType = "Subnet" + // nolint + domainType ResourceType = "Domain" +) + +func (p ResourceType) String() string { + return string(p) +} + +type Resource struct { + Type ResourceType + ID string +} diff --git a/management/server/route.go b/management/server/route.go index ecb562645..b6b44fbbd 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -4,15 +4,12 @@ import ( "context" "fmt" 
"net/netip" - "slices" - "strconv" - "strings" "unicode/utf8" "github.com/rs/xid" - log "github.com/sirupsen/logrus" - nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/proto" @@ -21,33 +18,9 @@ import ( "github.com/netbirdio/netbird/route" ) -// RouteFirewallRule a firewall rule applicable for a routed network. -type RouteFirewallRule struct { - // SourceRanges IP ranges of the routing peers. - SourceRanges []string - - // Action of the traffic when the rule is applicable - Action string - - // Destination a network prefix for the routed traffic - Destination string - - // Protocol of the traffic - Protocol string - - // Port of the traffic - Port uint16 - - // PortRange represents the range of ports for a firewall rule - PortRange RulePortRange - - // isDynamic indicates whether the rule is for DNS routing - IsDynamic bool -} - // GetRoute gets a route object from account and route IDs func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -56,11 +29,11 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes") } - return am.Store.GetRouteByID(ctx, LockingStrengthShare, string(routeID), accountID) + return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, string(routeID), accountID) } // checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups. 
-func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error { +func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *types.Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error { // routes can have both peer and peer_groups routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains) @@ -238,7 +211,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri } if am.isRouteChangeAffectPeers(account, &newRoute) { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) @@ -324,7 +297,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI } if am.isRouteChangeAffectPeers(account, oldRoute) || am.isRouteChangeAffectPeers(account, routeToSave) { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) @@ -356,7 +329,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) if am.isRouteChangeAffectPeers(account, routy) { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil @@ -364,7 +337,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri // ListRoutes returns a list of routes from account func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -373,7 +346,7 @@ func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, user return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes") } - return am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID) + return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID) } func toProtocolRoute(route *route.Route) *proto.Route { @@ -391,7 +364,7 @@ func toProtocolRoute(route *route.Route) *proto.Route { } func toProtocolRoutes(routes []*route.Route) []*proto.Route { - protoRoutes := make([]*proto.Route, 0) + protoRoutes := make([]*proto.Route, 0, len(routes)) for _, r := range routes { protoRoutes = append(protoRoutes, toProtocolRoute(r)) } @@ -404,187 +377,7 @@ func getPlaceholderIP() netip.Prefix { return netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32) } -// getPeerRoutesFirewallRules gets the routes firewall rules associated with a routing peer ID for the account. -func (a *Account) getPeerRoutesFirewallRules(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule { - routesFirewallRules := make([]*RouteFirewallRule, 0, len(a.Routes)) - - enabledRoutes, _ := a.getRoutingPeerRoutes(ctx, peerID) - for _, route := range enabledRoutes { - // If no access control groups are specified, accept all traffic. 
- if len(route.AccessControlGroups) == 0 { - defaultPermit := getDefaultPermit(route) - routesFirewallRules = append(routesFirewallRules, defaultPermit...) - continue - } - - policies := getAllRoutePoliciesFromGroups(a, route.AccessControlGroups) - for _, policy := range policies { - if !policy.Enabled { - continue - } - - for _, rule := range policy.Rules { - if !rule.Enabled { - continue - } - - distributionGroupPeers, _ := a.getAllPeersFromGroups(ctx, route.Groups, peerID, nil, validatedPeersMap) - rules := generateRouteFirewallRules(ctx, route, rule, distributionGroupPeers, firewallRuleDirectionIN) - routesFirewallRules = append(routesFirewallRules, rules...) - } - } - } - - return routesFirewallRules -} - -func getDefaultPermit(route *route.Route) []*RouteFirewallRule { - var rules []*RouteFirewallRule - - sources := []string{"0.0.0.0/0"} - if route.Network.Addr().Is6() { - sources = []string{"::/0"} - } - rule := RouteFirewallRule{ - SourceRanges: sources, - Action: string(PolicyTrafficActionAccept), - Destination: route.Network.String(), - Protocol: string(PolicyRuleProtocolALL), - IsDynamic: route.IsDynamic(), - } - - rules = append(rules, &rule) - - // dynamic routes always contain an IPv4 placeholder as destination, hence we must add IPv6 rules additionally - if route.IsDynamic() { - ruleV6 := rule - ruleV6.SourceRanges = []string{"::/0"} - rules = append(rules, &ruleV6) - } - - return rules -} - -// getAllRoutePoliciesFromGroups retrieves route policies associated with the specified access control groups -// and returns a list of policies that have rules with destinations matching the specified groups. -func getAllRoutePoliciesFromGroups(account *Account, accessControlGroups []string) []*Policy { - routePolicies := make([]*Policy, 0) - for _, groupID := range accessControlGroups { - group, ok := account.Groups[groupID] - if !ok { - continue - } - - for _, policy := range account.Policies { - for _, rule := range policy.Rules { - exist := slices.ContainsFunc(rule.Destinations, func(groupID string) bool { - return groupID == group.ID - }) - if exist { - routePolicies = append(routePolicies, policy) - continue - } - } - } - } - - return routePolicies -} - -// generateRouteFirewallRules generates a list of firewall rules for a given route. -func generateRouteFirewallRules(ctx context.Context, route *route.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule { - rulesExists := make(map[string]struct{}) - rules := make([]*RouteFirewallRule, 0) - - sourceRanges := make([]string, 0, len(groupPeers)) - for _, peer := range groupPeers { - if peer == nil { - continue - } - sourceRanges = append(sourceRanges, fmt.Sprintf(AllowedIPsFormat, peer.IP)) - } - - baseRule := RouteFirewallRule{ - SourceRanges: sourceRanges, - Action: string(rule.Action), - Destination: route.Network.String(), - Protocol: string(rule.Protocol), - IsDynamic: route.IsDynamic(), - } - - // generate rule for port range - if len(rule.Ports) == 0 { - rules = append(rules, generateRulesWithPortRanges(baseRule, rule, rulesExists)...) - } else { - rules = append(rules, generateRulesWithPorts(ctx, baseRule, rule, rulesExists)...) - - } - - // TODO: generate IPv6 rules for dynamic routes - - return rules -} - -// generateRuleIDBase generates the base rule ID for checking duplicates. 
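// A self-contained sketch of the removed getDefaultPermit logic above (which presumably moves
// with the route firewall code into the types package): a route without access-control groups
// gets a single allow-all rule keyed to its address family, and dynamic (DNS) routes get an
// extra IPv6 allow-all because their placeholder destination is IPv4. routeRule is a trimmed
// stand-in for types.RouteFirewallRule.
package main

import (
	"fmt"
	"net/netip"
)

type routeRule struct {
	SourceRanges []string
	Action       string
	Destination  string
	Protocol     string
	IsDynamic    bool
}

func defaultPermit(network netip.Prefix, dynamic bool) []*routeRule {
	sources := []string{"0.0.0.0/0"}
	if network.Addr().Is6() {
		sources = []string{"::/0"}
	}
	rule := routeRule{
		SourceRanges: sources,
		Action:       "accept",
		Destination:  network.String(),
		Protocol:     "all",
		IsDynamic:    dynamic,
	}
	rules := []*routeRule{&rule}
	if dynamic {
		v6 := rule
		v6.SourceRanges = []string{"::/0"}
		rules = append(rules, &v6)
	}
	return rules
}

func main() {
	for _, r := range defaultPermit(netip.MustParsePrefix("10.10.0.0/16"), false) {
		fmt.Printf("%+v\n", *r)
	}
}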
-func generateRuleIDBase(rule *PolicyRule, baseRule RouteFirewallRule) string { - return rule.ID + strings.Join(baseRule.SourceRanges, ",") + strconv.Itoa(firewallRuleDirectionIN) + baseRule.Protocol + baseRule.Action -} - -// generateRulesForPeer generates rules for a given peer based on ports and port ranges. -func generateRulesWithPortRanges(baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule { - rules := make([]*RouteFirewallRule, 0) - - ruleIDBase := generateRuleIDBase(rule, baseRule) - if len(rule.Ports) == 0 { - if len(rule.PortRanges) == 0 { - if _, ok := rulesExists[ruleIDBase]; !ok { - rulesExists[ruleIDBase] = struct{}{} - rules = append(rules, &baseRule) - } - } else { - for _, portRange := range rule.PortRanges { - ruleID := fmt.Sprintf("%s%d-%d", ruleIDBase, portRange.Start, portRange.End) - if _, ok := rulesExists[ruleID]; !ok { - rulesExists[ruleID] = struct{}{} - pr := baseRule - pr.PortRange = portRange - rules = append(rules, &pr) - } - } - } - return rules - } - - return rules -} - -// generateRulesWithPorts generates rules when specific ports are provided. -func generateRulesWithPorts(ctx context.Context, baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule { - rules := make([]*RouteFirewallRule, 0) - ruleIDBase := generateRuleIDBase(rule, baseRule) - - for _, port := range rule.Ports { - ruleID := ruleIDBase + port - if _, ok := rulesExists[ruleID]; ok { - continue - } - rulesExists[ruleID] = struct{}{} - - pr := baseRule - p, err := strconv.ParseUint(port, 10, 16) - if err != nil { - log.WithContext(ctx).Errorf("failed to parse port %s for rule: %s", port, rule.ID) - continue - } - - pr.Port = uint16(p) - rules = append(rules, &pr) - } - - return rules -} - -func toProtocolRoutesFirewallRules(rules []*RouteFirewallRule) []*proto.RouteFirewallRule { +func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule { result := make([]*proto.RouteFirewallRule, len(rules)) for i := range rules { rule := rules[i] @@ -603,7 +396,7 @@ func toProtocolRoutesFirewallRules(rules []*RouteFirewallRule) []*proto.RouteFir // getProtoDirection converts the direction to proto.RuleDirection. func getProtoDirection(direction int) proto.RuleDirection { - if direction == firewallRuleDirectionOUT { + if direction == types.FirewallRuleDirectionOUT { return proto.RuleDirection_OUT } return proto.RuleDirection_IN @@ -611,7 +404,7 @@ func getProtoDirection(direction int) proto.RuleDirection { // getProtoAction converts the action to proto.RuleAction. func getProtoAction(action string) proto.RuleAction { - if action == string(PolicyTrafficActionDrop) { + if action == string(types.PolicyTrafficActionDrop) { return proto.RuleAction_DROP } return proto.RuleAction_ACCEPT @@ -619,14 +412,14 @@ func getProtoAction(action string) proto.RuleAction { // getProtoProtocol converts the protocol to proto.RuleProtocol. 
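// A sketch of the per-port rule expansion performed by the removed generateRulesWithPorts helper
// above: each listed port is parsed as a uint16, a rule-ID map suppresses duplicates, and
// unparsable ports are skipped. The types are simplified stand-ins; the real helper also logs
// the parse failure.
package main

import (
	"fmt"
	"strconv"
)

type portRule struct {
	Base string // stands in for the shared rule fields (sources, action, protocol, ...)
	Port uint16
}

func expandPorts(ruleIDBase string, ports []string, base portRule) []*portRule {
	seen := map[string]struct{}{}
	rules := make([]*portRule, 0, len(ports))
	for _, port := range ports {
		ruleID := ruleIDBase + port
		if _, ok := seen[ruleID]; ok {
			continue
		}
		seen[ruleID] = struct{}{}

		p, err := strconv.ParseUint(port, 10, 16)
		if err != nil {
			continue // the removed helper logs and skips ports it cannot parse
		}
		r := base
		r.Port = uint16(p)
		rules = append(rules, &r)
	}
	return rules
}

func main() {
	rules := expandPorts("rule1", []string{"80", "443", "80"}, portRule{Base: "accept all"})
	fmt.Println(len(rules)) // 2 - the duplicate 80 is dropped
}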
func getProtoProtocol(protocol string) proto.RuleProtocol { - switch PolicyRuleProtocolType(protocol) { - case PolicyRuleProtocolALL: + switch types.PolicyRuleProtocolType(protocol) { + case types.PolicyRuleProtocolALL: return proto.RuleProtocol_ALL - case PolicyRuleProtocolTCP: + case types.PolicyRuleProtocolTCP: return proto.RuleProtocol_TCP - case PolicyRuleProtocolUDP: + case types.PolicyRuleProtocolUDP: return proto.RuleProtocol_UDP - case PolicyRuleProtocolICMP: + case types.PolicyRuleProtocolICMP: return proto.RuleProtocol_ICMP default: return proto.RuleProtocol_UNKNOWN @@ -634,7 +427,7 @@ func getProtoProtocol(protocol string) proto.RuleProtocol { } // getProtoPortInfo converts the port info to proto.PortInfo. -func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo { +func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo { var portInfo proto.PortInfo if rule.Port != 0 { portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)} @@ -651,6 +444,6 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo { // isRouteChangeAffectPeers checks if a given route affects peers by determining // if it has a routing peer, distribution, or peer groups that include peers -func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *Account, route *route.Route) bool { +func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *types.Account, route *route.Route) bool { return am.anyGroupHasPeers(account, route.Groups) || am.anyGroupHasPeers(account, route.PeerGroups) || route.Peer != "" } diff --git a/management/server/route_test.go b/management/server/route_test.go index 108f791e0..1c5c56f60 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "net/netip" + "sort" "testing" "time" @@ -12,11 +13,16 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) @@ -1091,9 +1097,9 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { require.NoError(t, err) assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route") - groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, account.Id) + groups, err := am.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, account.Id) require.NoError(t, err) - var groupHA1, groupHA2 *nbgroup.Group + var groupHA1, groupHA2 *types.Group for _, group := range groups { switch group.Name { case routeGroupHA1: @@ -1201,7 +1207,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { require.Len(t, peer2Routes.Routes, 1, "we should receive one route") require.True(t, peer1Routes.Routes[0].IsEqual(peer2Routes.Routes[0]), "routes should be the same for peers in the same group") - newGroup := &nbgroup.Group{ + newGroup := &types.Group{ ID: xid.New().String(), Name: "peer1 
group", Peers: []string{peer1ID}, @@ -1254,10 +1260,10 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics) } -func createRouterStore(t *testing.T) (Store, error) { +func createRouterStore(t *testing.T) (store.Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", dataDir) if err != nil { return nil, err } @@ -1266,7 +1272,7 @@ func createRouterStore(t *testing.T) (Store, error) { return store, nil } -func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) { +func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Account, error) { t.Helper() accountID := "testingAcc" @@ -1278,8 +1284,8 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er return nil, err } - ips := account.getTakenIPs() - peer1IP, err := AllocatePeerIP(account.Network.Net, ips) + ips := account.GetTakenIPs() + peer1IP, err := types.AllocatePeerIP(account.Network.Net, ips) if err != nil { return nil, err } @@ -1301,12 +1307,12 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er WtVersion: "development", UIVersion: "development", }, - Status: &nbpeer.PeerStatus{}, + Status: &nbpeer.PeerStatus{LastSeen: time.Now().UTC(), Connected: true}, } account.Peers[peer1.ID] = peer1 - ips = account.getTakenIPs() - peer2IP, err := AllocatePeerIP(account.Network.Net, ips) + ips = account.GetTakenIPs() + peer2IP, err := types.AllocatePeerIP(account.Network.Net, ips) if err != nil { return nil, err } @@ -1328,12 +1334,12 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er WtVersion: "development", UIVersion: "development", }, - Status: &nbpeer.PeerStatus{}, + Status: &nbpeer.PeerStatus{LastSeen: time.Now().UTC(), Connected: true}, } account.Peers[peer2.ID] = peer2 - ips = account.getTakenIPs() - peer3IP, err := AllocatePeerIP(account.Network.Net, ips) + ips = account.GetTakenIPs() + peer3IP, err := types.AllocatePeerIP(account.Network.Net, ips) if err != nil { return nil, err } @@ -1355,12 +1361,12 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er WtVersion: "development", UIVersion: "development", }, - Status: &nbpeer.PeerStatus{}, + Status: &nbpeer.PeerStatus{LastSeen: time.Now().UTC(), Connected: true}, } account.Peers[peer3.ID] = peer3 - ips = account.getTakenIPs() - peer4IP, err := AllocatePeerIP(account.Network.Net, ips) + ips = account.GetTakenIPs() + peer4IP, err := types.AllocatePeerIP(account.Network.Net, ips) if err != nil { return nil, err } @@ -1382,12 +1388,12 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er WtVersion: "development", UIVersion: "development", }, - Status: &nbpeer.PeerStatus{}, + Status: &nbpeer.PeerStatus{LastSeen: time.Now().UTC(), Connected: true}, } account.Peers[peer4.ID] = peer4 - ips = account.getTakenIPs() - peer5IP, err := AllocatePeerIP(account.Network.Net, ips) + ips = account.GetTakenIPs() + peer5IP, err := types.AllocatePeerIP(account.Network.Net, ips) if err != nil { return nil, err } @@ -1409,7 +1415,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er WtVersion: "development", UIVersion: "development", }, - Status: 
&nbpeer.PeerStatus{}, + Status: &nbpeer.PeerStatus{LastSeen: time.Now().UTC(), Connected: true}, } account.Peers[peer5.ID] = peer5 @@ -1438,7 +1444,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er return nil, err } - newGroup := []*nbgroup.Group{ + newGroup := []*types.Group{ { ID: routeGroup1, Name: routeGroup1, @@ -1486,9 +1492,11 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { peerBIp = "100.65.80.39" peerCIp = "100.65.254.139" peerHIp = "100.65.29.55" + peerJIp = "100.65.29.65" + peerKIp = "100.65.29.66" ) - account := &Account{ + account := &types.Account{ Peers: map[string]*nbpeer.Peer{ "peerA": { ID: "peerA", @@ -1541,8 +1549,18 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { IP: net.ParseIP(peerHIp), Status: &nbpeer.PeerStatus{}, }, + "peerJ": { + ID: "peerJ", + IP: net.ParseIP(peerJIp), + Status: &nbpeer.PeerStatus{}, + }, + "peerK": { + ID: "peerK", + IP: net.ParseIP(peerKIp), + Status: &nbpeer.PeerStatus{}, + }, }, - Groups: map[string]*nbgroup.Group{ + Groups: map[string]*types.Group{ "routingPeer1": { ID: "routingPeer1", Name: "RoutingPeer1", @@ -1567,6 +1585,11 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Name: "Route2", Peers: []string{}, }, + "route4": { + ID: "route4", + Name: "route4", + Peers: []string{}, + }, "finance": { ID: "finance", Name: "Finance", @@ -1584,6 +1607,28 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { "peerB", }, }, + "qa": { + ID: "qa", + Name: "QA", + Peers: []string{ + "peerJ", + "peerK", + }, + }, + "restrictQA": { + ID: "restrictQA", + Name: "restrictQA", + Peers: []string{ + "peerJ", + }, + }, + "unrestrictedQA": { + ID: "unrestrictedQA", + Name: "unrestrictedQA", + Peers: []string{ + "peerK", + }, + }, "contractors": { ID: "contractors", Name: "Contractors", @@ -1631,20 +1676,33 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Groups: []string{"contractors"}, AccessControlGroups: []string{}, }, + "route4": { + ID: "route4", + Network: netip.MustParsePrefix("192.168.10.0/16"), + NetID: "route4", + NetworkType: route.IPv4Network, + PeerGroups: []string{"routingPeer1"}, + Description: "Route4", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{"qa"}, + AccessControlGroups: []string{"route4"}, + }, }, - Policies: []*Policy{ + Policies: []*types.Policy{ { ID: "RuleRoute1", Name: "Route1", Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "RuleRoute1", Name: "ruleRoute1", Bidirectional: true, Enabled: true, - Protocol: PolicyRuleProtocolALL, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, Ports: []string{"80", "320"}, Sources: []string{ "dev", @@ -1659,15 +1717,15 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { ID: "RuleRoute2", Name: "Route2", Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "RuleRoute2", Name: "ruleRoute2", Bidirectional: true, Enabled: true, - Protocol: PolicyRuleProtocolTCP, - Action: PolicyTrafficActionAccept, - PortRanges: []RulePortRange{ + Protocol: types.PolicyRuleProtocolTCP, + Action: types.PolicyTrafficActionAccept, + PortRanges: []types.RulePortRange{ { Start: 80, End: 350, @@ -1685,6 +1743,49 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { }, }, }, + { + ID: "RuleRoute4", + Name: "RuleRoute4", + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: "RuleRoute4", + Name: "RuleRoute4", + Bidirectional: true, + Enabled: true, + Protocol: 
types.PolicyRuleProtocolTCP, + Action: types.PolicyTrafficActionAccept, + Ports: []string{"80"}, + Sources: []string{ + "restrictQA", + }, + Destinations: []string{ + "route4", + }, + }, + }, + }, + { + ID: "RuleRoute5", + Name: "RuleRoute5", + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: "RuleRoute5", + Name: "RuleRoute5", + Bidirectional: true, + Enabled: true, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, + Sources: []string{ + "unrestrictedQA", + }, + Destinations: []string{ + "route4", + }, + }, + }, + }, }, } @@ -1695,28 +1796,28 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { t.Run("check applied policies for the route", func(t *testing.T) { route1 := account.Routes["route1"] - policies := getAllRoutePoliciesFromGroups(account, route1.AccessControlGroups) + policies := types.GetAllRoutePoliciesFromGroups(account, route1.AccessControlGroups) assert.Len(t, policies, 1) route2 := account.Routes["route2"] - policies = getAllRoutePoliciesFromGroups(account, route2.AccessControlGroups) + policies = types.GetAllRoutePoliciesFromGroups(account, route2.AccessControlGroups) assert.Len(t, policies, 1) route3 := account.Routes["route3"] - policies = getAllRoutePoliciesFromGroups(account, route3.AccessControlGroups) + policies = types.GetAllRoutePoliciesFromGroups(account, route3.AccessControlGroups) assert.Len(t, policies, 0) }) t.Run("check peer routes firewall rules", func(t *testing.T) { - routesFirewallRules := account.getPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers) - assert.Len(t, routesFirewallRules, 2) + routesFirewallRules := account.GetPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers) + assert.Len(t, routesFirewallRules, 4) - expectedRoutesFirewallRules := []*RouteFirewallRule{ + expectedRoutesFirewallRules := []*types.RouteFirewallRule{ { SourceRanges: []string{ - fmt.Sprintf(AllowedIPsFormat, peerCIp), - fmt.Sprintf(AllowedIPsFormat, peerHIp), - fmt.Sprintf(AllowedIPsFormat, peerBIp), + fmt.Sprintf(types.AllowedIPsFormat, peerCIp), + fmt.Sprintf(types.AllowedIPsFormat, peerHIp), + fmt.Sprintf(types.AllowedIPsFormat, peerBIp), }, Action: "accept", Destination: "192.168.0.0/16", @@ -1725,9 +1826,9 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { }, { SourceRanges: []string{ - fmt.Sprintf(AllowedIPsFormat, peerCIp), - fmt.Sprintf(AllowedIPsFormat, peerHIp), - fmt.Sprintf(AllowedIPsFormat, peerBIp), + fmt.Sprintf(types.AllowedIPsFormat, peerCIp), + fmt.Sprintf(types.AllowedIPsFormat, peerHIp), + fmt.Sprintf(types.AllowedIPsFormat, peerBIp), }, Action: "accept", Destination: "192.168.0.0/16", @@ -1735,30 +1836,51 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Port: 320, }, } - assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) + additionalFirewallRule := []*types.RouteFirewallRule{ + { + SourceRanges: []string{ + fmt.Sprintf(types.AllowedIPsFormat, peerJIp), + }, + Action: "accept", + Destination: "192.168.10.0/16", + Protocol: "tcp", + Port: 80, + }, + { + SourceRanges: []string{ + fmt.Sprintf(types.AllowedIPsFormat, peerKIp), + }, + Action: "accept", + Destination: "192.168.10.0/16", + Protocol: "all", + }, + } + + assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(append(expectedRoutesFirewallRules, additionalFirewallRule...))) // peerD is also the routing peer for route1, should contain same routes firewall rules as peerA - routesFirewallRules = 
account.getPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers) + routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers) assert.Len(t, routesFirewallRules, 2) - assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) + assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules)) // peerE is a single routing peer for route 2 and route 3 - routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers) + routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers) assert.Len(t, routesFirewallRules, 3) - expectedRoutesFirewallRules = []*RouteFirewallRule{ + expectedRoutesFirewallRules = []*types.RouteFirewallRule{ { SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"}, Action: "accept", Destination: existingNetwork.String(), Protocol: "tcp", - PortRange: RulePortRange{Start: 80, End: 350}, + PortRange: types.RulePortRange{Start: 80, End: 350}, }, { SourceRanges: []string{"0.0.0.0/0"}, Action: "accept", Destination: "192.0.2.0/32", Protocol: "all", + Domains: domain.List{"example.com"}, IsDynamic: true, }, { @@ -1766,18 +1888,27 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Action: "accept", Destination: "192.0.2.0/32", Protocol: "all", + Domains: domain.List{"example.com"}, IsDynamic: true, }, } - assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) + assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules)) // peerC is part of route1 distribution groups but should not receive the routes firewall rules - routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers) + routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers) assert.Len(t, routesFirewallRules, 0) }) } +// orderRuleSourceRanges is a helper function that sorts each rule's SourceRanges to allow order-independent comparison +func orderRuleSourceRanges(ruleList []*types.RouteFirewallRule) []*types.RouteFirewallRule { + for _, rule := range ruleList { + sort.Strings(rule.SourceRanges) + } + return ruleList +} + func TestRouteAccountPeersUpdate(t *testing.T) { manager, err := createRouterManager(t) require.NoError(t, err, "failed to create account manager") @@ -1785,7 +1916,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { account, err := initTestRouteAccount(t, manager) require.NoError(t, err, "failed to init testing account") - err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err = manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -1981,7 +2112,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupB", Name: "GroupB", Peers: []string{peer1ID}, @@ -2021,7 +2152,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupC", Name: "GroupC", Peers: []string{peer1ID}, @@ -2035,3 +2166,581 @@ func TestRouteAccountPeersUpdate(t *testing.T) { } }) } + +func TestAccount_GetPeerNetworkResourceFirewallRules(t
*testing.T) { + var ( + peerBIp = "100.65.80.39" + peerCIp = "100.65.254.139" + peerHIp = "100.65.29.55" + peerJIp = "100.65.29.65" + peerKIp = "100.65.29.66" + peerMIp = "100.65.29.67" + peerOIp = "100.65.29.68" + ) + + account := &types.Account{ + Peers: map[string]*nbpeer.Peer{ + "peerA": { + ID: "peerA", + IP: net.ParseIP("100.65.14.88"), + Key: "peerA", + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + }, + }, + "peerB": { + ID: "peerB", + IP: net.ParseIP(peerBIp), + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{}, + }, + "peerC": { + ID: "peerC", + IP: net.ParseIP(peerCIp), + Status: &nbpeer.PeerStatus{}, + }, + "peerD": { + ID: "peerD", + IP: net.ParseIP("100.65.62.5"), + Key: "peerD", + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + }, + }, + "peerE": { + ID: "peerE", + IP: net.ParseIP("100.65.32.206"), + Key: "peerE", + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + }, + }, + "peerF": { + ID: "peerF", + IP: net.ParseIP("100.65.250.202"), + Status: &nbpeer.PeerStatus{}, + }, + "peerG": { + ID: "peerG", + IP: net.ParseIP("100.65.13.186"), + Status: &nbpeer.PeerStatus{}, + }, + "peerH": { + ID: "peerH", + IP: net.ParseIP(peerHIp), + Status: &nbpeer.PeerStatus{}, + }, + "peerJ": { + ID: "peerJ", + IP: net.ParseIP(peerJIp), + Status: &nbpeer.PeerStatus{}, + }, + "peerK": { + ID: "peerK", + IP: net.ParseIP(peerKIp), + Status: &nbpeer.PeerStatus{}, + }, + "peerL": { + ID: "peerL", + IP: net.ParseIP("100.65.19.186"), + Key: "peerL", + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + }, + }, + "peerM": { + ID: "peerM", + IP: net.ParseIP(peerMIp), + Status: &nbpeer.PeerStatus{}, + }, + "peerN": { + ID: "peerN", + IP: net.ParseIP("100.65.20.18"), + Key: "peerN", + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + }, + }, + "peerO": { + ID: "peerO", + IP: net.ParseIP(peerOIp), + Status: &nbpeer.PeerStatus{}, + }, + }, + Groups: map[string]*types.Group{ + "router1": { + ID: "router1", + Name: "router1", + Peers: []string{ + "peerA", + }, + }, + "router2": { + ID: "router2", + Name: "router2", + Peers: []string{ + "peerD", + }, + }, + "finance": { + ID: "finance", + Name: "Finance", + Peers: []string{ + "peerF", + "peerG", + }, + }, + "dev": { + ID: "dev", + Name: "Dev", + Peers: []string{ + "peerC", + "peerH", + "peerB", + }, + Resources: []types.Resource{ + {ID: "resource2"}, + }, + }, + "qa": { + ID: "qa", + Name: "QA", + Peers: []string{ + "peerJ", + "peerK", + }, + }, + "restrictQA": { + ID: "restrictQA", + Name: "restrictQA", + Peers: []string{ + "peerJ", + }, + Resources: []types.Resource{ + {ID: "resource4"}, + }, + }, + "unrestrictedQA": { + ID: "unrestrictedQA", + Name: "unrestrictedQA", + Peers: []string{ + "peerK", + }, + Resources: []types.Resource{ + {ID: "resource4"}, + }, + }, + "contractors": { + ID: "contractors", + Name: "Contractors", + Peers: []string{}, + }, + "pipeline": { + ID: "pipeline", + Name: "Pipeline", + Peers: []string{"peerM"}, + }, + "metrics": { + ID: "metrics", + Name: "Metrics", + Peers: []string{"peerN", "peerO"}, + Resources: []types.Resource{ + {ID: "resource6"}, + }, + }, + }, + Networks: []*networkTypes.Network{ + { + ID: "network1", + Name: "Finance Network", + }, + { + ID: "network2", + Name: "Devs Network", + }, + { + ID: "network3", + Name: "Contractors Network", + }, + { + ID: "network4", + Name: "QA Network", + }, + { + ID: "network5", + Name: "Pipeline Network", + }, + { + ID: 
"network6", + Name: "Metrics Network", + }, + }, + NetworkRouters: []*routerTypes.NetworkRouter{ + { + ID: "router1", + NetworkID: "network1", + Peer: "peerE", + PeerGroups: nil, + Masquerade: false, + Metric: 9999, + Enabled: true, + }, + { + ID: "router2", + NetworkID: "network2", + PeerGroups: []string{"router1", "router2"}, + Masquerade: false, + Metric: 9999, + Enabled: true, + }, + { + ID: "router3", + NetworkID: "network3", + Peer: "peerE", + PeerGroups: []string{}, + Enabled: true, + }, + { + ID: "router4", + NetworkID: "network4", + PeerGroups: []string{"router1"}, + Masquerade: false, + Metric: 9999, + Enabled: true, + }, + { + ID: "router5", + NetworkID: "network5", + Peer: "peerL", + Masquerade: false, + Metric: 9999, + Enabled: true, + }, + { + ID: "router6", + NetworkID: "network6", + Peer: "peerN", + Masquerade: false, + Metric: 9999, + Enabled: true, + }, + }, + NetworkResources: []*resourceTypes.NetworkResource{ + { + ID: "resource1", + NetworkID: "network1", + Name: "Resource 1", + Type: "subnet", + Prefix: netip.MustParsePrefix("10.10.10.0/24"), + Enabled: true, + }, + { + ID: "resource2", + NetworkID: "network2", + Name: "Resource 2", + Type: "subnet", + Prefix: netip.MustParsePrefix("192.168.0.0/16"), + Enabled: true, + }, + { + ID: "resource3", + NetworkID: "network3", + Name: "Resource 3", + Type: "domain", + Domain: "example.com", + Enabled: true, + }, + { + ID: "resource4", + NetworkID: "network4", + Name: "Resource 4", + Type: "domain", + Domain: "example.com", + Enabled: true, + }, + { + ID: "resource5", + NetworkID: "network5", + Name: "Resource 5", + Type: "host", + Prefix: netip.MustParsePrefix("10.12.12.1/32"), + Enabled: true, + }, + { + ID: "resource6", + NetworkID: "network6", + Name: "Resource 6", + Type: "domain", + Domain: "*.google.com", + Enabled: true, + }, + }, + Policies: []*types.Policy{ + { + ID: "policyResource1", + Name: "Policy for resource 1", + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: "ruleResource1", + Name: "ruleResource1", + Bidirectional: true, + Enabled: true, + Protocol: types.PolicyRuleProtocolTCP, + Action: types.PolicyTrafficActionAccept, + PortRanges: []types.RulePortRange{ + { + Start: 80, + End: 350, + }, { + Start: 80, + End: 350, + }, + }, + Sources: []string{ + "finance", + }, + DestinationResource: types.Resource{ID: "resource1"}, + }, + }, + }, + { + ID: "policyResource2", + Name: "Policy for resource 2", + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: "ruleResource2", + Name: "ruleResource2", + Bidirectional: true, + Enabled: true, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, + Ports: []string{"80", "320"}, + Sources: []string{"dev"}, + Destinations: []string{"dev"}, + }, + }, + }, + { + ID: "policyResource3", + Name: "policyResource3", + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: "ruleResource3", + Name: "ruleResource3", + Bidirectional: true, + Enabled: true, + Protocol: types.PolicyRuleProtocolTCP, + Action: types.PolicyTrafficActionAccept, + Ports: []string{"80"}, + Sources: []string{"restrictQA"}, + Destinations: []string{"restrictQA"}, + }, + }, + }, + { + ID: "policyResource4", + Name: "policyResource4", + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: "ruleResource4", + Name: "ruleResource4", + Bidirectional: true, + Enabled: true, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, + Sources: []string{"unrestrictedQA"}, + Destinations: []string{"unrestrictedQA"}, + }, + }, + }, + { + ID: 
"policyResource5", + Name: "policyResource5", + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: "ruleResource5", + Name: "ruleResource5", + Bidirectional: true, + Enabled: true, + Protocol: types.PolicyRuleProtocolTCP, + Action: types.PolicyTrafficActionAccept, + Ports: []string{"8080"}, + Sources: []string{"pipeline"}, + DestinationResource: types.Resource{ID: "resource5"}, + }, + }, + }, + { + ID: "policyResource6", + Name: "policyResource6", + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: "ruleResource6", + Name: "ruleResource6", + Bidirectional: true, + Enabled: true, + Protocol: types.PolicyRuleProtocolTCP, + Action: types.PolicyTrafficActionAccept, + Ports: []string{"9090"}, + Sources: []string{"metrics"}, + Destinations: []string{"metrics"}, + }, + }, + }, + }, + } + + validatedPeers := make(map[string]struct{}) + for p := range account.Peers { + validatedPeers[p] = struct{}{} + } + + t.Run("validate applied policies for different network resources", func(t *testing.T) { + // Test case: Resource1 is directly applied to the policy (policyResource1) + policies := account.GetPoliciesForNetworkResource("resource1") + assert.Len(t, policies, 1, "resource1 should have exactly 1 policy applied directly") + + // Test case: Resource2 is applied to an access control group (dev), + // which is part of the destination in the policy (policyResource2) + policies = account.GetPoliciesForNetworkResource("resource2") + assert.Len(t, policies, 1, "resource2 should have exactly 1 policy applied via access control group") + + // Test case: Resource3 is not applied to any access control group or policy + policies = account.GetPoliciesForNetworkResource("resource3") + assert.Len(t, policies, 0, "resource3 should have no policies applied") + + // Test case: Resource4 is applied to the access control groups (restrictQA and unrestrictedQA), + // which is part of the destination in the policies (policyResource3 and policyResource4) + policies = account.GetPoliciesForNetworkResource("resource4") + assert.Len(t, policies, 2, "resource4 should have exactly 2 policy applied via access control groups") + + // Test case: Resource6 is applied to the access control groups (metrics), + policies = account.GetPoliciesForNetworkResource("resource6") + assert.Len(t, policies, 1, "resource6 should have exactly 1 policy applied via access control groups") + }) + + t.Run("validate routing peer firewall rules for network resources", func(t *testing.T) { + resourcePoliciesMap := account.GetResourcePoliciesMap() + resourceRoutersMap := account.GetResourceRoutersMap() + _, routes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), "peerA", resourcePoliciesMap, resourceRoutersMap) + firewallRules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerA"], validatedPeers, routes, resourcePoliciesMap) + assert.Len(t, firewallRules, 4) + assert.Len(t, sourcePeers, 5) + + expectedFirewallRules := []*types.RouteFirewallRule{ + { + SourceRanges: []string{ + fmt.Sprintf(types.AllowedIPsFormat, peerCIp), + fmt.Sprintf(types.AllowedIPsFormat, peerHIp), + fmt.Sprintf(types.AllowedIPsFormat, peerBIp), + }, + Action: "accept", + Destination: "192.168.0.0/16", + Protocol: "all", + Port: 80, + }, + { + SourceRanges: []string{ + fmt.Sprintf(types.AllowedIPsFormat, peerCIp), + fmt.Sprintf(types.AllowedIPsFormat, peerHIp), + fmt.Sprintf(types.AllowedIPsFormat, peerBIp), + }, + Action: "accept", + Destination: "192.168.0.0/16", + Protocol: "all", + Port: 320, + }, + } + + 
additionalFirewallRules := []*types.RouteFirewallRule{ + { + SourceRanges: []string{ + fmt.Sprintf(types.AllowedIPsFormat, peerJIp), + }, + Action: "accept", + Destination: "192.0.2.0/32", + Protocol: "tcp", + Port: 80, + Domains: domain.List{"example.com"}, + IsDynamic: true, + }, + { + SourceRanges: []string{ + fmt.Sprintf(types.AllowedIPsFormat, peerKIp), + }, + Action: "accept", + Destination: "192.0.2.0/32", + Protocol: "all", + Domains: domain.List{"example.com"}, + IsDynamic: true, + }, + } + assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(append(expectedFirewallRules, additionalFirewallRules...))) + + // peerD is also the routing peer for resource2 + _, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerD", resourcePoliciesMap, resourceRoutersMap) + firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerD"], validatedPeers, routes, resourcePoliciesMap) + assert.Len(t, firewallRules, 2) + assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules)) + assert.Len(t, sourcePeers, 3) + + // peerE is a single routing peer for resource1 and resource3 + // PeerE should only receive rules for resource1 since resource3 has no applied policy + _, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerE", resourcePoliciesMap, resourceRoutersMap) + firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerE"], validatedPeers, routes, resourcePoliciesMap) + assert.Len(t, firewallRules, 1) + assert.Len(t, sourcePeers, 2) + + expectedFirewallRules = []*types.RouteFirewallRule{ + { + SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"}, + Action: "accept", + Destination: "10.10.10.0/24", + Protocol: "tcp", + PortRange: types.RulePortRange{Start: 80, End: 350}, + }, + } + assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules)) + + // peerC is part of distribution groups for resource2 but should not receive the firewall rules + firewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers) + assert.Len(t, firewallRules, 0) + + // peerL is the single routing peer for resource5 + _, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerL", resourcePoliciesMap, resourceRoutersMap) + assert.Len(t, routes, 1) + firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerL"], validatedPeers, routes, resourcePoliciesMap) + assert.Len(t, firewallRules, 1) + assert.Len(t, sourcePeers, 1) + + expectedFirewallRules = []*types.RouteFirewallRule{ + { + SourceRanges: []string{"100.65.29.67/32"}, + Action: "accept", + Destination: "10.12.12.1/32", + Protocol: "tcp", + Port: 8080, + }, + } + assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules)) + + _, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerM", resourcePoliciesMap, resourceRoutersMap) + assert.Len(t, routes, 1) + assert.Len(t, sourcePeers, 0) + + _, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerN", resourcePoliciesMap, resourceRoutersMap) + assert.Len(t, routes, 1) + assert.Len(t, sourcePeers, 2) + }) +} diff --git a/management/server/settings/manager.go b/management/server/settings/manager.go new file mode 
100644 index 000000000..37bc9f549 --- /dev/null +++ b/management/server/settings/manager.go @@ -0,0 +1,37 @@ +package settings + +import ( + "context" + + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" +) + +type Manager interface { + GetSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) +} + +type managerImpl struct { + store store.Store +} + +type managerMock struct { +} + +func NewManager(store store.Store) Manager { + return &managerImpl{ + store: store, + } +} + +func (m *managerImpl) GetSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) { + return m.store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) +} + +func NewManagerMock() Manager { + return &managerMock{} +} + +func (m *managerMock) GetSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) { + return &types.Settings{}, nil +} diff --git a/management/server/setupkey.go b/management/server/setupkey.go index f055d877f..f2f1aad45 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -2,33 +2,16 @@ package server import ( "context" - "crypto/sha256" - b64 "encoding/base64" - "hash/fnv" "slices" - "strconv" - "strings" "time" - "unicode/utf8" - "github.com/google/uuid" + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/status" - log "github.com/sirupsen/logrus" -) - -const ( - // SetupKeyReusable is a multi-use key (can be used for multiple machines) - SetupKeyReusable SetupKeyType = "reusable" - // SetupKeyOneOff is a single use key (can be used only once) - SetupKeyOneOff SetupKeyType = "one-off" - - // DefaultSetupKeyDuration = 1 month - DefaultSetupKeyDuration = 24 * 30 * time.Hour - // DefaultSetupKeyName is a default name of the default setup key - DefaultSetupKeyName = "Default key" - // SetupKeyUnlimitedUsage indicates an unlimited usage of a setup key - SetupKeyUnlimitedUsage = 0 + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" ) const ( @@ -66,169 +49,14 @@ type SetupKeyUpdateOperation struct { Values []string } -// SetupKeyType is the type of setup key -type SetupKeyType string - -// SetupKey represents a pre-authorized key used to register machines (peers) -type SetupKey struct { - Id string - // AccountID is a reference to Account that this object belongs - AccountID string `json:"-" gorm:"index"` - Key string - KeySecret string - Name string - Type SetupKeyType - CreatedAt time.Time - ExpiresAt time.Time - UpdatedAt time.Time `gorm:"autoUpdateTime:false"` - // Revoked indicates whether the key was revoked or not (we don't remove them for tracking purposes) - Revoked bool - // UsedTimes indicates how many times the key was used - UsedTimes int - // LastUsed last time the key was used for peer registration - LastUsed time.Time - // AutoGroups is a list of Group IDs that are auto assigned to a Peer when it uses this key to register - AutoGroups []string `gorm:"serializer:json"` - // UsageLimit indicates the number of times this key can be used to enroll a machine. - // The value of 0 indicates the unlimited usage. 
- UsageLimit int - // Ephemeral indicate if the peers will be ephemeral or not - Ephemeral bool -} - -// Copy copies SetupKey to a new object -func (key *SetupKey) Copy() *SetupKey { - autoGroups := make([]string, len(key.AutoGroups)) - copy(autoGroups, key.AutoGroups) - if key.UpdatedAt.IsZero() { - key.UpdatedAt = key.CreatedAt - } - return &SetupKey{ - Id: key.Id, - AccountID: key.AccountID, - Key: key.Key, - KeySecret: key.KeySecret, - Name: key.Name, - Type: key.Type, - CreatedAt: key.CreatedAt, - ExpiresAt: key.ExpiresAt, - UpdatedAt: key.UpdatedAt, - Revoked: key.Revoked, - UsedTimes: key.UsedTimes, - LastUsed: key.LastUsed, - AutoGroups: autoGroups, - UsageLimit: key.UsageLimit, - Ephemeral: key.Ephemeral, - } -} - -// EventMeta returns activity event meta related to the setup key -func (key *SetupKey) EventMeta() map[string]any { - return map[string]any{"name": key.Name, "type": key.Type, "key": key.KeySecret} -} - -// hiddenKey returns the Key value hidden with "*" and a 5 character prefix. -// E.g., "831F6*******************************" -func hiddenKey(key string, length int) string { - prefix := key[0:5] - if length > utf8.RuneCountInString(key) { - length = utf8.RuneCountInString(key) - len(prefix) - } - return prefix + strings.Repeat("*", length) -} - -// IncrementUsage makes a copy of a key, increments the UsedTimes by 1 and sets LastUsed to now -func (key *SetupKey) IncrementUsage() *SetupKey { - c := key.Copy() - c.UsedTimes++ - c.LastUsed = time.Now().UTC() - return c -} - -// IsValid is true if the key was not revoked, is not expired and used not more than it was supposed to -func (key *SetupKey) IsValid() bool { - return !key.IsRevoked() && !key.IsExpired() && !key.IsOverUsed() -} - -// IsRevoked if key was revoked -func (key *SetupKey) IsRevoked() bool { - return key.Revoked -} - -// IsExpired if key was expired -func (key *SetupKey) IsExpired() bool { - if key.ExpiresAt.IsZero() { - return false - } - return time.Now().After(key.ExpiresAt) -} - -// IsOverUsed if the key was used too many times. SetupKey.UsageLimit == 0 indicates the unlimited usage. 
-func (key *SetupKey) IsOverUsed() bool { - limit := key.UsageLimit - if key.Type == SetupKeyOneOff { - limit = 1 - } - return limit > 0 && key.UsedTimes >= limit -} - -// GenerateSetupKey generates a new setup key -func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration, autoGroups []string, - usageLimit int, ephemeral bool) (*SetupKey, string) { - key := strings.ToUpper(uuid.New().String()) - limit := usageLimit - if t == SetupKeyOneOff { - limit = 1 - } - - expiresAt := time.Time{} - if validFor != 0 { - expiresAt = time.Now().UTC().Add(validFor) - } - - hashedKey := sha256.Sum256([]byte(key)) - encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) - - return &SetupKey{ - Id: strconv.Itoa(int(Hash(key))), - Key: encodedHashedKey, - KeySecret: hiddenKey(key, 4), - Name: name, - Type: t, - CreatedAt: time.Now().UTC(), - ExpiresAt: expiresAt, - UpdatedAt: time.Now().UTC(), - Revoked: false, - UsedTimes: 0, - AutoGroups: autoGroups, - UsageLimit: limit, - Ephemeral: ephemeral, - }, key -} - -// GenerateDefaultSetupKey generates a default reusable setup key with an unlimited usage and 30 days expiration -func GenerateDefaultSetupKey() (*SetupKey, string) { - return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration, []string{}, - SetupKeyUnlimitedUsage, false) -} - -func Hash(s string) uint32 { - h := fnv.New32a() - _, err := h.Write([]byte(s)) - if err != nil { - panic(err) - } - return h.Sum32() -} - // CreateSetupKey generates a new setup key with a given name, type, list of groups IDs to auto-assign to peers registered with this key, // and adds it to the specified account. A list of autoGroups IDs can be empty. -func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, - expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) { +func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType types.SetupKeyType, + expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*types.SetupKey, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -241,22 +69,22 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s return nil, status.NewAdminPermissionError() } - var setupKey *SetupKey + var setupKey *types.SetupKey var plainKey string var eventsToStore []func() - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, autoGroups); err != nil { - return err + return status.Errorf(status.InvalidArgument, "invalid auto groups: %v", err) } - setupKey, plainKey = GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral) + setupKey, plainKey = types.GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral) setupKey.AccountID = accountID events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, autoGroups, nil, setupKey) eventsToStore = append(eventsToStore, events...) 
- return transaction.SaveSetupKey(ctx, LockingStrengthUpdate, setupKey) + return transaction.SaveSetupKey(ctx, store.LockingStrengthUpdate, setupKey) }) if err != nil { return nil, err @@ -276,8 +104,8 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s // SaveSetupKey saves the provided SetupKey to the database overriding the existing one. // Due to the unique nature of a SetupKey certain properties must not be overwritten // (e.g. the key itself, creation date, ID, etc). -// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key. -func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) { +// These properties are overwritten: AutoGroups, Revoked (only from false to true), and the UpdatedAt. The rest is copied from the existing key. +func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *types.SetupKey, userID string) (*types.SetupKey, error) { if keyToSave == nil { return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil") } @@ -285,7 +113,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -298,34 +126,37 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str return nil, status.NewAdminPermissionError() } - var oldKey *SetupKey - var newKey *SetupKey + var oldKey *types.SetupKey + var newKey *types.SetupKey var eventsToStore []func() - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, keyToSave.AutoGroups); err != nil { - return err + return status.Errorf(status.InvalidArgument, "invalid auto groups: %v", err) } - oldKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyToSave.Id) + oldKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyToSave.Id) if err != nil { return err } - // only auto groups, revoked status, and name can be updated for now + if oldKey.Revoked && !keyToSave.Revoked { + return status.Errorf(status.InvalidArgument, "can't un-revoke a revoked setup key") + } + + // only auto groups, revoked status (from false to true) can be updated newKey = oldKey.Copy() - newKey.Name = keyToSave.Name newKey.AutoGroups = keyToSave.AutoGroups newKey.Revoked = keyToSave.Revoked newKey.UpdatedAt = time.Now().UTC() - addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups) - removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups) + addedGroups := util.Difference(newKey.AutoGroups, oldKey.AutoGroups) + removedGroups := util.Difference(oldKey.AutoGroups, newKey.AutoGroups) events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups, oldKey) eventsToStore = append(eventsToStore, events...) 
- return transaction.SaveSetupKey(ctx, LockingStrengthUpdate, newKey) + return transaction.SaveSetupKey(ctx, store.LockingStrengthUpdate, newKey) }) if err != nil { return nil, err @@ -343,8 +174,8 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str } // ListSetupKeys returns a list of all setup keys of the account -func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) +func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -357,12 +188,12 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u return nil, status.NewAdminPermissionError() } - return am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) + return am.Store.GetAccountSetupKeys(ctx, store.LockingStrengthShare, accountID) } // GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found. -func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) +func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -375,7 +206,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use return nil, status.NewAdminPermissionError() } - setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID) + setupKey, err := am.Store.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyID) if err != nil { return nil, err } @@ -390,7 +221,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use // DeleteSetupKey removes the setup key from the account func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -403,15 +234,15 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, return status.NewAdminPermissionError() } - var deletedSetupKey *SetupKey + var deletedSetupKey *types.SetupKey - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyID) if err != nil { return err } - return transaction.DeleteSetupKey(ctx, LockingStrengthUpdate, accountID, keyID) + return transaction.DeleteSetupKey(ctx, store.LockingStrengthUpdate, accountID, keyID) }) if err != nil { return err @@ -422,8 +253,8 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, return nil } -func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountID string, autoGroupIDs []string) error { - groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, 
autoGroupIDs) +func validateSetupKeyAutoGroups(ctx context.Context, transaction store.Store, accountID string, autoGroupIDs []string) error { + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, autoGroupIDs) if err != nil { return err } @@ -443,11 +274,11 @@ func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountI } // prepareSetupKeyEvents prepares a list of event functions to be stored. -func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string, key *SetupKey) []func() { +func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, transaction store.Store, accountID, userID string, addedGroups, removedGroups []string, key *types.SetupKey) []func() { var eventsToStore []func() modifiedGroups := slices.Concat(addedGroups, removedGroups) - groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups) + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, modifiedGroups) if err != nil { log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err) return nil diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index ea239ec0c..e225ec54b 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -15,7 +15,7 @@ import ( "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/types" ) func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { @@ -30,7 +30,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err = manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "group_1", Name: "group_name_1", @@ -49,18 +49,16 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { expiresIn := time.Hour keyName := "my-test-key" - key, err := manager.CreateSetupKey(context.Background(), account.Id, keyName, SetupKeyReusable, expiresIn, []string{}, - SetupKeyUnlimitedUsage, userID, false) + key, err := manager.CreateSetupKey(context.Background(), account.Id, keyName, types.SetupKeyReusable, expiresIn, []string{}, + types.SetupKeyUnlimitedUsage, userID, false) if err != nil { t.Fatal(err) } autoGroups := []string{"group_1", "group_2"} - newKeyName := "my-new-test-key" revoked := true - newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{ + newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &types.SetupKey{ Id: key.Id, - Name: newKeyName, Revoked: revoked, AutoGroups: autoGroups, }, userID) @@ -68,7 +66,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { t.Fatal(err) } - assertKey(t, newKey, newKeyName, revoked, "reusable", 0, key.CreatedAt, key.ExpiresAt, + assertKey(t, newKey, keyName, revoked, "reusable", 0, key.CreatedAt, key.GetExpiresAt(), key.Id, time.Now().UTC(), autoGroups, true) // check the corresponding events that should have been generated @@ -76,7 +74,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { assert.NotNil(t, ev) assert.Equal(t, account.Id, ev.AccountID) - assert.Equal(t, newKeyName, ev.Meta["name"]) + assert.Equal(t, keyName, ev.Meta["name"]) assert.Equal(t, fmt.Sprint(key.Type), 
fmt.Sprint(ev.Meta["type"])) assert.NotEmpty(t, ev.Meta["key"]) assert.Equal(t, userID, ev.InitiatorID) @@ -87,9 +85,8 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { // saving setup key with All group assigned to auto groups should return error autoGroups = append(autoGroups, groupAll.ID) - _, err = manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{ + _, err = manager.SaveSetupKey(context.Background(), account.Id, &types.SetupKey{ Id: key.Id, - Name: newKeyName, Revoked: revoked, AutoGroups: autoGroups, }, userID) @@ -108,7 +105,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "group_1", Name: "group_name_1", Peers: []string{}, @@ -117,7 +114,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "group_2", Name: "group_name_2", Peers: []string{}, @@ -170,8 +167,8 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { for _, tCase := range []testCase{testCase1, testCase2, testCase3} { t.Run(tCase.name, func(t *testing.T) { - key, err := manager.CreateSetupKey(context.Background(), account.Id, tCase.expectedKeyName, SetupKeyReusable, expiresIn, - tCase.expectedGroups, SetupKeyUnlimitedUsage, userID, false) + key, err := manager.CreateSetupKey(context.Background(), account.Id, tCase.expectedKeyName, types.SetupKeyReusable, expiresIn, + tCase.expectedGroups, types.SetupKeyUnlimitedUsage, userID, false) if tCase.expectedFailure { if err == nil { @@ -185,7 +182,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { } assertKey(t, key, tCase.expectedKeyName, false, tCase.expectedType, tCase.expectedUsedTimes, - tCase.expectedCreatedAt, tCase.expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))), + tCase.expectedCreatedAt, tCase.expectedExpiresAt, strconv.Itoa(int(types.Hash(key.Key))), tCase.expectedUpdatedAt, tCase.expectedGroups, false) // check the corresponding events that should have been generated @@ -213,22 +210,41 @@ func TestGetSetupKeys(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ - ID: "group_1", - Name: "group_name_1", - Peers: []string{}, - }) + plainKey, err := manager.CreateSetupKey(context.Background(), account.Id, "key1", types.SetupKeyReusable, time.Hour, nil, types.SetupKeyUnlimitedUsage, userID, false) if err != nil { t.Fatal(err) } - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ - ID: "group_2", - Name: "group_name_2", - Peers: []string{}, - }) - if err != nil { - t.Fatal(err) + type testCase struct { + name string + keyId string + expectedFailure bool + } + + testCase1 := testCase{ + name: "Should get existing Setup Key", + keyId: plainKey.Id, + expectedFailure: false, + } + testCase2 := testCase{ + name: "Should fail to get non-existent Setup Key", + keyId: "some key", + expectedFailure: true, + } + + for _, tCase := range []testCase{testCase1, testCase2} { + t.Run(tCase.name, func(t *testing.T) { + key, err := manager.GetSetupKey(context.Background(), account.Id, userID, tCase.keyId) + + if tCase.expectedFailure { + if err == nil { + t.Fatal("expected to fail") + } + return + } + + assert.NotEqual(t, plainKey.Key, key.Key) + 
}) } } @@ -242,10 +258,10 @@ func TestGenerateDefaultSetupKey(t *testing.T) { expectedExpiresAt := time.Now().UTC().Add(24 * 30 * time.Hour) var expectedAutoGroups []string - key, plainKey := GenerateDefaultSetupKey() + key, plainKey := types.GenerateDefaultSetupKey() assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt, - expectedExpiresAt, strconv.Itoa(int(Hash(plainKey))), expectedUpdatedAt, expectedAutoGroups, true) + expectedExpiresAt, strconv.Itoa(int(types.Hash(plainKey))), expectedUpdatedAt, expectedAutoGroups, true) } @@ -259,48 +275,48 @@ func TestGenerateSetupKey(t *testing.T) { expectedUpdatedAt := time.Now().UTC() var expectedAutoGroups []string - key, plain := GenerateSetupKey(expectedName, SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + key, plain := types.GenerateSetupKey(expectedName, types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false) assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt, - expectedExpiresAt, strconv.Itoa(int(Hash(plain))), expectedUpdatedAt, expectedAutoGroups, true) + expectedExpiresAt, strconv.Itoa(int(types.Hash(plain))), expectedUpdatedAt, expectedAutoGroups, true) } func TestSetupKey_IsValid(t *testing.T) { - validKey, _ := GenerateSetupKey("valid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + validKey, _ := types.GenerateSetupKey("valid key", types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false) if !validKey.IsValid() { t.Errorf("expected key to be valid, got invalid %v", validKey) } // expired - expiredKey, _ := GenerateSetupKey("invalid key", SetupKeyOneOff, -time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + expiredKey, _ := types.GenerateSetupKey("invalid key", types.SetupKeyOneOff, -time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false) if expiredKey.IsValid() { t.Errorf("expected key to be invalid due to expiration, got valid %v", expiredKey) } // revoked - revokedKey, _ := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + revokedKey, _ := types.GenerateSetupKey("invalid key", types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false) revokedKey.Revoked = true if revokedKey.IsValid() { t.Errorf("expected revoked key to be invalid, got valid %v", revokedKey) } // overused - overUsedKey, _ := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + overUsedKey, _ := types.GenerateSetupKey("invalid key", types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false) overUsedKey.UsedTimes = 1 if overUsedKey.IsValid() { t.Errorf("expected overused key to be invalid, got valid %v", overUsedKey) } // overused - reusableKey, _ := GenerateSetupKey("valid key", SetupKeyReusable, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + reusableKey, _ := types.GenerateSetupKey("valid key", types.SetupKeyReusable, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false) reusableKey.UsedTimes = 99 if !reusableKey.IsValid() { t.Errorf("expected reusable key to be valid when used many times, got valid %v", reusableKey) } } -func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke bool, expectedType string, +func assertKey(t *testing.T, key *types.SetupKey, expectedName string, expectedRevoke bool, expectedType string, expectedUsedTimes int, expectedCreatedAt time.Time, 
expectedExpiresAt time.Time, expectedID string, expectedUpdatedAt time.Time, expectedAutoGroups []string, expectHashedKey bool) { t.Helper() @@ -320,8 +336,8 @@ func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke t.Errorf("expected setup key to have UsedTimes = %v, got %v", expectedUsedTimes, key.UsedTimes) } - if key.ExpiresAt.Sub(expectedExpiresAt).Round(time.Hour) != 0 { - t.Errorf("expected setup key to have ExpiresAt ~ %v, got %v", expectedExpiresAt, key.ExpiresAt) + if key.GetExpiresAt().Sub(expectedExpiresAt).Round(time.Hour) != 0 { + t.Errorf("expected setup key to have ExpiresAt ~ %v, got %v", expectedExpiresAt, key.GetExpiresAt()) } if key.UpdatedAt.Sub(expectedUpdatedAt).Round(time.Hour) != 0 { @@ -372,10 +388,10 @@ func isValidBase64SHA256(encodedKey string) bool { func TestSetupKey_Copy(t *testing.T) { - key, _ := GenerateSetupKey("key name", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + key, _ := types.GenerateSetupKey("key name", types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false) keyCopy := key.Copy() - assertKey(t, keyCopy, key.Name, key.Revoked, string(key.Type), key.UsedTimes, key.CreatedAt, key.ExpiresAt, key.Id, + assertKey(t, keyCopy, key.Name, key.Revoked, string(key.Type), key.UsedTimes, key.CreatedAt, key.GetExpiresAt(), key.Id, key.UpdatedAt, key.AutoGroups, true) } @@ -383,22 +399,22 @@ func TestSetupKey_Copy(t *testing.T) { func TestSetupKeyAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, }) assert.NoError(t, err) - policy := &Policy{ + policy := &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"group"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, } @@ -410,7 +426,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) { manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) }) - var setupKey *SetupKey + var setupKey *types.SetupKey // Creating setup key should not update account peers and not send peer update t.Run("creating setup key", func(t *testing.T) { @@ -420,7 +436,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) { close(done) }() - setupKey, err = manager.CreateSetupKey(context.Background(), account.Id, "key1", SetupKeyReusable, time.Hour, nil, 999, userID, false) + setupKey, err = manager.CreateSetupKey(context.Background(), account.Id, "key1", types.SetupKeyReusable, time.Hour, nil, 999, userID, false) assert.NoError(t, err) select { @@ -448,3 +464,31 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) { } }) } + +func TestDefaultAccountManager_CreateSetupKey_ShouldNotAllowToUpdateRevokedKey(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + userID := "testingUser" + account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") + if err != nil { + t.Fatal(err) + } + + key, err := manager.CreateSetupKey(context.Background(), account.Id, "testName", types.SetupKeyReusable, time.Hour, nil, types.SetupKeyUnlimitedUsage, userID, false) + assert.NoError(t, err) + + // revoke the key + updateKey := key.Copy() + updateKey.Revoked = true + 
_, err = manager.SaveSetupKey(context.Background(), account.Id, updateKey, userID) + assert.NoError(t, err) + + // re-activate revoked key + updateKey.Revoked = false + _, err = manager.SaveSetupKey(context.Background(), account.Id, updateKey, userID) + assert.Error(t, err, "should not allow updating a revoked key") + +} diff --git a/management/server/status/error.go b/management/server/status/error.go index 045469306..96b103183 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -160,6 +160,41 @@ func NewNameServerGroupNotFoundError(nsGroupID string) error { return Errorf(NotFound, "nameserver group: %s not found", nsGroupID) } +// NewNetworkNotFoundError creates a new Error with NotFound type for a missing network. +func NewNetworkNotFoundError(networkID string) error { + return Errorf(NotFound, "network: %s not found", networkID) +} + +// NewNetworkRouterNotFoundError creates a new Error with NotFound type for a missing network router. +func NewNetworkRouterNotFoundError(routerID string) error { + return Errorf(NotFound, "network router: %s not found", routerID) +} + +// NewNetworkResourceNotFoundError creates a new Error with NotFound type for a missing network resource. +func NewNetworkResourceNotFoundError(resourceID string) error { + return Errorf(NotFound, "network resource: %s not found", resourceID) +} + +// NewPermissionDeniedError creates a new Error with PermissionDenied type for a permission denied error. +func NewPermissionDeniedError() error { + return Errorf(PermissionDenied, "permission denied") +} + +// NewPermissionValidationError creates a new Error with PermissionDenied type for a failed permission validation. +func NewPermissionValidationError(err error) error { + return Errorf(PermissionDenied, "failed to validate user permissions: %s", err) +} + +// NewResourceNotPartOfNetworkError creates a new Error with BadRequest type for a resource that is not part of the given network. +func NewResourceNotPartOfNetworkError(resourceID, networkID string) error { + return Errorf(BadRequest, "resource %s is not part of the network %s", resourceID, networkID) +} + +// NewRouterNotPartOfNetworkError creates a new Error with BadRequest type for a router that is not part of the given network. +func NewRouterNotPartOfNetworkError(routerID, networkID string) error { + return Errorf(BadRequest, "router %s is not part of the network %s", routerID, networkID) +} + // NewServiceUserRoleInvalidError creates a new Error with InvalidArgument type for creating a service user with owner role func NewServiceUserRoleInvalidError() error { return Errorf(InvalidArgument, "can't create a service user with owner role") diff --git a/management/server/file_store.go b/management/server/store/file_store.go similarity index 86% rename from management/server/file_store.go rename to management/server/store/file_store.go index 561e133ce..4c9134e41 100644 --- a/management/server/file_store.go +++ b/management/server/store/file_store.go @@ -1,4 +1,4 @@ -package server +package store import ( "context" @@ -11,9 +11,10 @@ import ( "github.com/rs/xid" log "github.com/sirupsen/logrus" - nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" + nbutil "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/util" ) @@ -22,7 +23,7 @@ const storeFileName = "store.json" // FileStore represents an account storage backed by a file persisted to disk type FileStore struct { - Accounts map[string]*Account + Accounts map[string]*types.Account SetupKeyID2AccountID map[string]string `json:"-"` PeerKeyID2AccountID map[string]string `json:"-"` PeerID2AccountID map[string]string `json:"-"` @@ -55,7 +56,7 @@ func restore(ctx context.Context, file string)
(*FileStore, error) { if _, err := os.Stat(file); os.IsNotExist(err) { // create a new FileStore if previously didn't exist (e.g. first run) s := &FileStore{ - Accounts: make(map[string]*Account), + Accounts: make(map[string]*types.Account), mux: sync.Mutex{}, SetupKeyID2AccountID: make(map[string]string), PeerKeyID2AccountID: make(map[string]string), @@ -92,12 +93,14 @@ func restore(ctx context.Context, file string) (*FileStore, error) { for accountID, account := range store.Accounts { if account.Settings == nil { - account.Settings = &Settings{ + account.Settings = &types.Settings{ PeerLoginExpirationEnabled: false, - PeerLoginExpiration: DefaultPeerLoginExpiration, + PeerLoginExpiration: types.DefaultPeerLoginExpiration, PeerInactivityExpirationEnabled: false, - PeerInactivityExpiration: DefaultPeerInactivityExpiration, + PeerInactivityExpiration: types.DefaultPeerInactivityExpiration, + + RoutingPeerDNSResolutionEnabled: true, } } @@ -112,7 +115,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) { for _, user := range account.Users { store.UserID2AccountID[user.Id] = accountID if user.Issued == "" { - user.Issued = UserIssuedAPI + user.Issued = types.UserIssuedAPI account.Users[user.Id] = user } @@ -122,7 +125,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) { } } - if account.Domain != "" && account.DomainCategory == PrivateCategory && + if account.Domain != "" && account.DomainCategory == types.PrivateCategory && account.IsDomainPrimaryAccount { store.PrivateDomain2AccountID[account.Domain] = accountID } @@ -134,20 +137,20 @@ func restore(ctx context.Context, file string) (*FileStore, error) { policy.UpgradeAndFix() } if account.Policies == nil { - account.Policies = make([]*Policy, 0) + account.Policies = make([]*types.Policy, 0) } // for data migration. 
Can be removed once most base will be with labels - existingLabels := account.getPeerDNSLabels() + existingLabels := account.GetPeerDNSLabels() if len(existingLabels) != len(account.Peers) { - addPeerLabelsToAccount(ctx, account, existingLabels) + types.AddPeerLabelsToAccount(ctx, account, existingLabels) } // TODO: delete this block after migration // Set API as issuer for groups which has not this field for _, group := range account.Groups { if group.Issued == "" { - group.Issued = nbgroup.GroupIssuedAPI + group.Issued = types.GroupIssuedAPI } } @@ -173,8 +176,8 @@ func restore(ctx context.Context, file string) (*FileStore, error) { migrationPeers := make(map[string]*nbpeer.Peer) // key to Peer for key, peer := range account.Peers { // set LastLogin for the peers that were onboarded before the peer login expiration feature - if peer.LastLogin.IsZero() { - peer.LastLogin = time.Now().UTC() + if peer.GetLastLogin().IsZero() { + peer.LastLogin = nbutil.ToPtr(time.Now().UTC()) } if peer.ID != "" { continue @@ -223,7 +226,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) { // It is recommended to call it with locking FileStore.mux func (s *FileStore) persist(ctx context.Context, file string) error { start := time.Now() - err := util.WriteJson(file, s) + err := util.WriteJson(context.Background(), file, s) if err != nil { return err } @@ -236,7 +239,7 @@ func (s *FileStore) persist(ctx context.Context, file string) error { } // GetAllAccounts returns all accounts -func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) { +func (s *FileStore) GetAllAccounts(_ context.Context) (all []*types.Account) { s.mux.Lock() defer s.mux.Unlock() for _, a := range s.Accounts { @@ -257,6 +260,6 @@ func (s *FileStore) Close(ctx context.Context) error { } // GetStoreEngine returns FileStoreEngine -func (s *FileStore) GetStoreEngine() StoreEngine { +func (s *FileStore) GetStoreEngine() Engine { return FileStoreEngine } diff --git a/management/server/sql_store.go b/management/server/store/sql_store.go similarity index 74% rename from management/server/sql_store.go rename to management/server/store/sql_store.go index 23b253c5a..d5cee567f 100644 --- a/management/server/sql_store.go +++ b/management/server/store/sql_store.go @@ -1,4 +1,4 @@ -package server +package store import ( "context" @@ -16,6 +16,7 @@ import ( "time" log "github.com/sirupsen/logrus" + "gorm.io/driver/mysql" "gorm.io/driver/postgres" "gorm.io/driver/sqlite" "gorm.io/gorm" @@ -24,11 +25,14 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/account" - nbgroup "github.com/netbirdio/netbird/management/server/group" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) @@ -36,6 +40,7 @@ const ( storeSqliteFileName = "store.db" idQueryCondition = "id = ?" keyQueryCondition = "key = ?" + mysqlKeyQueryCondition = "`key` = ?" accountAndIDQueryCondition = "account_id = ? and id = ?" accountAndIDsQueryCondition = "account_id = ? AND id IN ?" 
accountIDCondition = "account_id = ?" @@ -49,7 +54,7 @@ type SqlStore struct { globalAccountLock sync.Mutex metrics telemetry.AppMetrics installationPK int - storeEngine StoreEngine + storeEngine Engine } type installation struct { @@ -60,7 +65,7 @@ type installation struct { type migrationFunc func(*gorm.DB) error // NewSqlStore creates a new SqlStore instance. -func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine StoreEngine, metrics telemetry.AppMetrics) (*SqlStore, error) { +func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine Engine, metrics telemetry.AppMetrics) (*SqlStore, error) { sql, err := db.DB() if err != nil { return nil, err @@ -86,9 +91,10 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine StoreEngine, metr return nil, fmt.Errorf("migrate: %w", err) } err = db.AutoMigrate( - &SetupKey{}, &nbpeer.Peer{}, &User{}, &PersonalAccessToken{}, &nbgroup.Group{}, - &Account{}, &Policy{}, &PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, + &types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, + &types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &installation{}, &account.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{}, + &networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, ) if err != nil { return nil, fmt.Errorf("auto migrate: %w", err) @@ -97,6 +103,14 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine StoreEngine, metr return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil } +// GetKeyQueryCondition returns the key lookup condition for the configured store engine, quoting the reserved `key` column on MySQL. +func GetKeyQueryCondition(s *SqlStore) string { + if s.storeEngine == MysqlStoreEngine { + return mysqlKeyQueryCondition + } + return keyQueryCondition +} + // AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) { log.WithContext(ctx).Tracef("acquiring global lock") @@ -151,7 +164,7 @@ func (s *SqlStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (u return unlock } -func (s *SqlStore) SaveAccount(ctx context.Context, account *Account) error { +func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) error { start := time.Now() defer func() { elapsed := time.Since(start) @@ -201,7 +214,7 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *Account) error { } // generateAccountSQLTypes generates the GORM compatible types for the account -func generateAccountSQLTypes(account *Account) { +func generateAccountSQLTypes(account *types.Account) { for _, key := range account.SetupKeys { account.SetupKeysG = append(account.SetupKeysG, *key) } @@ -238,7 +251,7 @@ func generateAccountSQLTypes(account *Account) { // checkAccountDomainBeforeSave temporary method to troubleshoot an issue with domains getting blank func (s *SqlStore) checkAccountDomainBeforeSave(ctx context.Context, accountID, newDomain string) { - var acc Account + var acc types.Account var domain string result := s.db.Model(&acc).Select("domain").Where(idQueryCondition, accountID).First(&domain) if result.Error != nil { @@ -252,7 +265,7 @@ func (s *SqlStore) checkAccountDomainBeforeSave(ctx context.Context, accountID, } } -func (s *SqlStore) DeleteAccount(ctx context.Context, account *Account) error { +func (s *SqlStore) DeleteAccount(ctx context.Context, account *types.Account) error { start := time.Now() err := s.db.Transaction(func(tx *gorm.DB) error {
@@ -319,7 +332,7 @@ func (s *SqlStore) SavePeer(ctx context.Context, lockStrength LockingStrength, a result = tx.Model(&nbpeer.Peer{}).Where(accountAndIDQueryCondition, accountID, peer.ID).Save(peerCopy) if result.Error != nil { - return result.Error + return status.Errorf(status.Internal, "failed to save peer to store: %v", result.Error) } return nil @@ -333,19 +346,19 @@ func (s *SqlStore) SavePeer(ctx context.Context, lockStrength LockingStrength, a } func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error { - accountCopy := Account{ + accountCopy := types.Account{ Domain: domain, DomainCategory: category, IsDomainPrimaryAccount: isPrimaryDomain, } fieldsToUpdate := []string{"domain", "domain_category", "is_domain_primary_account"} - result := s.db.Model(&Account{}). + result := s.db.Model(&types.Account{}). Select(fieldsToUpdate). Where(idQueryCondition, accountID). Updates(&accountCopy) if result.Error != nil { - return result.Error + return status.Errorf(status.Internal, "failed to update account domain attributes to store: %v", result.Error) } if result.RowsAffected == 0 { @@ -368,7 +381,7 @@ func (s *SqlStore) SavePeerStatus(ctx context.Context, lockStrength LockingStren Where(accountAndIDQueryCondition, accountID, peerID). Updates(&peerCopy) if result.Error != nil { - return result.Error + return status.Errorf(status.Internal, "failed to save peer status to store: %v", result.Error) } if result.RowsAffected == 0 { @@ -390,7 +403,7 @@ func (s *SqlStore) SavePeerLocation(ctx context.Context, lockStrength LockingStr Updates(peerCopy) if result.Error != nil { - return result.Error + return status.Errorf(status.Internal, "failed to save peer locations to store: %v", result.Error) } if result.RowsAffected == 0 { @@ -401,7 +414,7 @@ func (s *SqlStore) SavePeerLocation(ctx context.Context, lockStrength LockingStr } // SaveUsers saves the given list of users to the database. -func (s *SqlStore) SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*User) error { +func (s *SqlStore) SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*types.User) error { if len(users) == 0 { return nil } @@ -415,7 +428,7 @@ func (s *SqlStore) SaveUsers(ctx context.Context, lockStrength LockingStrength, } // SaveUser saves the given user to the database. -func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error { +func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user) if result.Error != nil { log.WithContext(ctx).Errorf("failed to save user to store: %s", result.Error) @@ -425,7 +438,7 @@ func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, u } // SaveGroups saves the given list of groups to the database. 
-func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error { +func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*types.Group) error { if len(groups) == 0 { return nil } @@ -447,7 +460,7 @@ func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error { return nil } -func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) { +func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*types.Account, error) { accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) if err != nil { return nil, err @@ -459,9 +472,9 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) { var accountID string - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id"). + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).Select("id"). Where("domain = ? and is_domain_primary_account = ? and domain_category = ?", - strings.ToLower(domain), true, PrivateCategory, + strings.ToLower(domain), true, types.PrivateCategory, ).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -474,9 +487,9 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength return accountID, nil } -func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) { - var key SetupKey - result := s.db.Select("account_id").First(&key, keyQueryCondition, setupKey) +func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*types.Account, error) { + var key types.SetupKey + result := s.db.Select("account_id").First(&key, GetKeyQueryCondition(s), setupKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewSetupKeyNotFoundError(setupKey) @@ -493,7 +506,7 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (* } func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken string) (string, error) { - var token PersonalAccessToken + var token types.PersonalAccessToken result := s.db.First(&token, "hashed_token = ?", hashedToken) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -506,8 +519,8 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri return token.ID, nil } -func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*User, error) { - var user User +func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error) { + var user types.User result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). Joins("JOIN personal_access_tokens ON personal_access_tokens.user_id = users.id"). 
Where("personal_access_tokens.id = ?", patID).First(&user) @@ -522,8 +535,8 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren return &user, nil } -func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) { - var user User +func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) { + var user types.User result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&user, idQueryCondition, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -538,13 +551,13 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre func (s *SqlStore) DeleteUser(ctx context.Context, lockStrength LockingStrength, accountID, userID string) error { err := s.db.Transaction(func(tx *gorm.DB) error { result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}). - Delete(&PersonalAccessToken{}, "user_id = ?", userID) + Delete(&types.PersonalAccessToken{}, "user_id = ?", userID) if result.Error != nil { return result.Error } return tx.Clauses(clause.Locking{Strength: string(lockStrength)}). - Delete(&User{}, accountAndIDQueryCondition, accountID, userID).Error + Delete(&types.User{}, accountAndIDQueryCondition, accountID, userID).Error }) if err != nil { log.WithContext(ctx).Errorf("failed to delete user from the store: %s", err) @@ -568,8 +581,8 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre return users, nil } -func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) { - var groups []*nbgroup.Group +func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) { + var groups []*types.Group result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -582,8 +595,27 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr return groups, nil } -func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*Account) { - var accounts []Account +func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error) { + var groups []*types.Group + + likePattern := `%"ID":"` + resourceID + `"%` + + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Where("resources LIKE ?", likePattern). 
+ Find(&groups) + + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, result.Error + } + + return groups, nil +} + +func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*types.Account) { + var accounts []types.Account result := s.db.Find(&accounts) if result.Error != nil { return all @@ -598,7 +630,7 @@ func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*Account) { return all } -func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, error) { +func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Account, error) { start := time.Now() defer func() { elapsed := time.Since(start) @@ -607,7 +639,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, } }() - var account Account + var account types.Account result := s.db.Model(&account). Preload("UsersG.PATsG"). // have to be specifies as this is nester reference Preload(clause.Associations). @@ -622,15 +654,15 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, // we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us for i, policy := range account.Policies { - var rules []*PolicyRule - err := s.db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error + var rules []*types.PolicyRule + err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error if err != nil { return nil, status.Errorf(status.NotFound, "rule not found") } account.Policies[i].Rules = rules } - account.SetupKeys = make(map[string]*SetupKey, len(account.SetupKeysG)) + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) for _, key := range account.SetupKeysG { account.SetupKeys[key.Key] = key.Copy() } @@ -642,9 +674,9 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, } account.PeersG = nil - account.Users = make(map[string]*User, len(account.UsersG)) + account.Users = make(map[string]*types.User, len(account.UsersG)) for _, user := range account.UsersG { - user.PATs = make(map[string]*PersonalAccessToken, len(user.PATs)) + user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs)) for _, pat := range user.PATsG { user.PATs[pat.ID] = pat.Copy() } @@ -652,7 +684,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, } account.UsersG = nil - account.Groups = make(map[string]*nbgroup.Group, len(account.GroupsG)) + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) for _, group := range account.GroupsG { account.Groups[group.ID] = group.Copy() } @@ -673,8 +705,8 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, return &account, nil } -func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) { - var user User +func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) { + var user types.User result := s.db.Select("account_id").First(&user, idQueryCondition, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -690,7 +722,7 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun return s.GetAccount(ctx, user.AccountID) } -func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) { +func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) { var peer nbpeer.Peer 
result := s.db.Select("account_id").First(&peer, idQueryCondition, peerID) if result.Error != nil { @@ -707,9 +739,10 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco return s.GetAccount(ctx, peer.AccountID) } -func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) { +func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*types.Account, error) { var peer nbpeer.Peer - result := s.db.Select("account_id").First(&peer, keyQueryCondition, peerKey) + result := s.db.Select("account_id").First(&peer, GetKeyQueryCondition(s), peerKey) + if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") @@ -727,7 +760,7 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) ( func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) { var peer nbpeer.Peer var accountID string - result := s.db.Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID) + result := s.db.Model(&peer).Select("account_id").Where(GetKeyQueryCondition(s), peerKey).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") @@ -740,7 +773,7 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) func (s *SqlStore) GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error) { var accountID string - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&User{}). + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.User{}). 
Select("account_id").Where(idQueryCondition, userID).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -768,7 +801,7 @@ func (s *SqlStore) GetAccountIDByPeerID(ctx context.Context, lockStrength Lockin func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) { var accountID string - result := s.db.Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID) + result := s.db.Model(&types.SetupKey{}).Select("account_id").Where(GetKeyQueryCondition(s), setupKey).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.NewSetupKeyNotFoundError(setupKey) @@ -828,9 +861,9 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock return labels, nil } -func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) { - var accountNetwork AccountNetwork - if err := s.db.Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil { +func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) { + var accountNetwork types.AccountNetwork + if err := s.db.Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.NewAccountNotFoundError(accountID) } @@ -841,7 +874,8 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) { var peer nbpeer.Peer - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, keyQueryCondition, peerKey) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, GetKeyQueryCondition(s), peerKey) + if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewPeerNotFoundError(peerKey) @@ -852,9 +886,9 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking return &peer, nil } -func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) { - var accountSettings AccountSettings - if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil { +func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error) { + var accountSettings types.AccountSettings + if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "settings not found") } @@ -879,7 +913,7 @@ func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength Locking // SaveUserLastLogin stores the last login time for a user in DB. 
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error { - var user User + var user types.User result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -887,9 +921,13 @@ func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID stri } return status.NewGetUserFromStoreError() } - user.LastLogin = lastLogin - return s.db.Save(&user).Error + if !lastLogin.IsZero() { + user.LastLogin = &lastLogin + return s.db.Save(&user).Error + } + + return nil } func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) { @@ -917,7 +955,7 @@ func (s *SqlStore) Close(_ context.Context) error { } // GetStoreEngine returns underlying store engine -func (s *SqlStore) GetStoreEngine() StoreEngine { +func (s *SqlStore) GetStoreEngine() Engine { return s.storeEngine } @@ -948,6 +986,16 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe return NewSqlStore(ctx, db, PostgresStoreEngine, metrics) } +// NewMysqlStore creates a new MySQL store. +func NewMysqlStore(ctx context.Context, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { + db, err := gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), getGormConfig()) + if err != nil { + return nil, err + } + + return NewSqlStore(ctx, db, MysqlStoreEngine, metrics) +} + func getGormConfig() *gorm.Config { return &gorm.Config{ Logger: logger.Default.LogMode(logger.Silent), @@ -965,6 +1013,15 @@ func newPostgresStore(ctx context.Context, metrics telemetry.AppMetrics) (Store, return NewPostgresqlStore(ctx, dsn, metrics) } +// newMysqlStore initializes a new MySQL store. +func newMysqlStore(ctx context.Context, metrics telemetry.AppMetrics) (Store, error) { + dsn, ok := os.LookupEnv(mysqlDsnEnv) + if !ok { + return nil, fmt.Errorf("%s is not set", mysqlDsnEnv) + } + return NewMysqlStore(ctx, dsn, metrics) +} + // NewSqliteStoreFromFileStore restores a store from FileStore and stores SQLite DB in the file located in datadir. func NewSqliteStoreFromFileStore(ctx context.Context, fileStore *FileStore, dataDir string, metrics telemetry.AppMetrics) (*SqlStore, error) { store, err := NewSqliteStore(ctx, dataDir, metrics) @@ -1009,10 +1066,33 @@ func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, return store, nil } -func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) { - var setupKey SetupKey +// NewMysqlStoreFromSqlStore restores a store from SqlStore and stores MySQL DB. +func NewMysqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { + store, err := NewMysqlStore(ctx, dsn, metrics) + if err != nil { + return nil, err + } + + err = store.SaveInstallationID(ctx, sqliteStore.GetInstallationID()) + if err != nil { + return nil, err + } + + for _, account := range sqliteStore.GetAllAccounts(ctx) { + err := store.SaveAccount(ctx, account) + if err != nil { + return nil, err + } + } + + return store, nil +} + +func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) { + var setupKey types.SetupKey result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). 
- First(&setupKey, keyQueryCondition, key) + First(&setupKey, GetKeyQueryCondition(s), key) + if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewSetupKeyNotFoundError(key) @@ -1024,7 +1104,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking } func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error { - result := s.db.Model(&SetupKey{}). + result := s.db.Model(&types.SetupKey{}). Where(idQueryCondition, setupKeyID). Updates(map[string]interface{}{ "used_times": gorm.Expr("used_times + 1"), @@ -1042,9 +1122,11 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string return nil } -func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { - var group nbgroup.Group - result := s.db.Where("account_id = ? AND name = ?", accountID, "All").First(&group) +// AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction +func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error { + var group types.Group + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&group, "account_id = ? AND name = ?", accountID, "All") if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.Errorf(status.NotFound, "group 'All' not found for account") @@ -1060,16 +1142,18 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer group.Peers = append(group.Peers, peerID) - if err := s.db.Save(&group).Error; err != nil { + if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil { return status.Errorf(status.Internal, "issue updating group 'All': %s", err) } return nil } -func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error { - var group nbgroup.Group - result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group) +// AddPeerToGroup adds a peer to a group. Method always needs to run in a transaction +func (s *SqlStore) AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error { + var group types.Group + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Where(accountAndIDQueryCondition, accountId, groupID). + First(&group) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.NewGroupNotFoundError(groupID) @@ -1086,6 +1170,33 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId group.Peers = append(group.Peers, peerId) + if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil { + return status.Errorf(status.Internal, "issue updating group: %s", err) + } + + return nil +} + +// AddResourceToGroup adds a resource to a group. 
Method always needs to run in a transaction +func (s *SqlStore) AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error { + var group types.Group + result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return status.NewGroupNotFoundError(groupID) + } + + return status.Errorf(status.Internal, "issue finding group: %s", result.Error) + } + + for _, res := range group.Resources { + if res.ID == resource.ID { + return nil + } + } + + group.Resources = append(group.Resources, *resource) + if err := s.db.Save(&group).Error; err != nil { return status.Errorf(status.Internal, "issue updating group: %s", err) } @@ -1093,6 +1204,45 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId return nil } +// RemoveResourceFromGroup removes a resource from a group. Method always needs to run in a transaction +func (s *SqlStore) RemoveResourceFromGroup(ctx context.Context, accountId string, groupID string, resourceID string) error { + var group types.Group + result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return status.NewGroupNotFoundError(groupID) + } + + return status.Errorf(status.Internal, "issue finding group: %s", result.Error) + } + + for i, res := range group.Resources { + if res.ID == resourceID { + group.Resources = append(group.Resources[:i], group.Resources[i+1:]...) + break + } + } + + if err := s.db.Save(&group).Error; err != nil { + return status.Errorf(status.Internal, "issue updating group: %s", err) + } + + return nil +} + +// GetPeerGroups retrieves all groups assigned to a specific peer in a given account. +func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error) { + var groups []*types.Group + query := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Find(&groups, "account_id = ? AND peers LIKE ?", accountId, fmt.Sprintf(`%%"%s"%%`, peerId)) + + if query.Error != nil { + return nil, query.Error + } + + return groups, nil +} + // GetAccountPeers retrieves peers for an account. func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) { var peers []*nbpeer.Peer @@ -1108,6 +1258,12 @@ func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStre // GetUserPeers retrieves peers for a user. func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) { var peers []*nbpeer.Peer + + // Exclude peers added via setup keys, as they are not user-specific and have an empty user_id. + if userID == "" { + return peers, nil + } + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). Find(&peers, "account_id = ? 
AND user_id = ?", accountID, userID) if err := result.Error; err != nil { @@ -1118,8 +1274,8 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt return peers, nil } -func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { - if err := s.db.Create(peer).Error; err != nil { +func (s *SqlStore) AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error { + if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(peer).Error; err != nil { return status.Errorf(status.Internal, "issue adding peer to account: %s", err) } @@ -1223,7 +1379,7 @@ func (s *SqlStore) DeletePeer(ctx context.Context, lockStrength LockingStrength, func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). - Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) + Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) if result.Error != nil { log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error) return status.Errorf(status.Internal, "failed to increment network serial count in store") @@ -1232,6 +1388,7 @@ func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength Lock } func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error { + startTime := time.Now() tx := s.db.Begin() if tx.Error != nil { return tx.Error @@ -1242,7 +1399,15 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor tx.Rollback() return err } - return tx.Commit().Error + + err = tx.Commit().Error + + log.WithContext(ctx).Tracef("transaction took %v", time.Since(startTime)) + if s.metrics != nil { + s.metrics.StoreMetrics().CountTransactionDuration(time.Since(startTime)) + } + + return err } func (s *SqlStore) withTx(tx *gorm.DB) Store { @@ -1256,9 +1421,9 @@ func (s *SqlStore) GetDB() *gorm.DB { return s.db } -func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) { - var accountDNSSettings AccountDNSSettings - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). +func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.DNSSettings, error) { + var accountDNSSettings types.AccountDNSSettings + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}). First(&accountDNSSettings, idQueryCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -1273,7 +1438,7 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki // AccountExists checks whether an account exists by the given ID. func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) { var accountID string - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}). 
Select("id").First(&accountID, idQueryCondition, id) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -1287,8 +1452,8 @@ func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStreng // GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID. func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) { - var account Account - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category"). + var account types.Account + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).Select("domain", "domain_category"). Where(idQueryCondition, accountID).First(&account) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -1301,8 +1466,8 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength } // GetGroupByID retrieves a group by ID and account ID. -func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error) { - var group *nbgroup.Group +func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error) { + var group *types.Group result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&group, accountAndIDQueryCondition, accountID, groupID) if err := result.Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { @@ -1316,15 +1481,19 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt } // GetGroupByName retrieves a group by name and account ID. -func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error) { - var group nbgroup.Group +func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types.Group, error) { + var group types.Group // TODO: This fix is accepted for now, but if we need to handle this more frequently // we may need to reconsider changing the types. query := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations) - if s.storeEngine == PostgresStoreEngine { + + switch s.storeEngine { + case PostgresStoreEngine: query = query.Order("json_array_length(peers::json) DESC") - } else { + case MysqlStoreEngine: + query = query.Order("JSON_LENGTH(JSON_EXTRACT(peers, \"$\")) DESC") + default: query = query.Order("json_array_length(peers) DESC") } @@ -1340,15 +1509,15 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren } // GetGroupsByIDs retrieves groups by their IDs and account ID. 
-func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) { - var groups []*nbgroup.Group +func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error) { + var groups []*types.Group result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { log.WithContext(ctx).Errorf("failed to get groups by ID's from store: %s", result.Error) return nil, status.Errorf(status.Internal, "failed to get groups by ID's from store") } - groupsMap := make(map[string]*nbgroup.Group) + groupsMap := make(map[string]*types.Group) for _, group := range groups { groupsMap[group.ID] = group } @@ -1357,7 +1526,7 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren } // SaveGroup saves a group to the store. -func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error { +func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group) if result.Error != nil { log.WithContext(ctx).Errorf("failed to save group to store: %v", result.Error) @@ -1369,7 +1538,7 @@ func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, // DeleteGroup deletes a group from the database. func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). - Delete(&nbgroup.Group{}, accountAndIDQueryCondition, accountID, groupID) + Delete(&types.Group{}, accountAndIDQueryCondition, accountID, groupID) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to delete group from store: %s", result.Error) return status.Errorf(status.Internal, "failed to delete group from store") @@ -1385,18 +1554,18 @@ func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength // DeleteGroups deletes groups from the database. func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error { result := s.db.Clauses(clause.Locking{Strength: string(strength)}). - Delete(&nbgroup.Group{}, accountAndIDsQueryCondition, accountID, groupIDs) + Delete(&types.Group{}, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error) - return status.Errorf(status.Internal, "failed to delete groups from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete groups from store") } return nil } // GetAccountPolicies retrieves policies for an account. -func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { - var policies []*Policy +func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Policy, error) { + var policies []*types.Policy result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). 
Preload(clause.Associations).Find(&policies, accountIDCondition, accountID) if err := result.Error; err != nil { @@ -1408,8 +1577,8 @@ func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingS } // GetPolicyByID retrieves a policy by its ID and account ID. -func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error) { - var policy *Policy +func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*types.Policy, error) { + var policy *types.Policy result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations). First(&policy, accountAndIDQueryCondition, accountID, policyID) if err := result.Error; err != nil { @@ -1423,7 +1592,7 @@ func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStreng return policy, nil } -func (s *SqlStore) CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error { +func (s *SqlStore) CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(policy) if result.Error != nil { log.WithContext(ctx).Errorf("failed to create policy in store: %s", result.Error) @@ -1434,7 +1603,7 @@ func (s *SqlStore) CreatePolicy(ctx context.Context, lockStrength LockingStrengt } // SavePolicy saves a policy to the database. -func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error { +func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error { result := s.db.Session(&gorm.Session{FullSaveAssociations: true}). Clauses(clause.Locking{Strength: string(lockStrength)}).Save(policy) if err := result.Error; err != nil { @@ -1446,7 +1615,7 @@ func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). - Delete(&Policy{}, accountAndIDQueryCondition, accountID, policyID) + Delete(&types.Policy{}, accountAndIDQueryCondition, accountID, policyID) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to delete policy from store: %s", err) return status.Errorf(status.Internal, "failed to delete policy from store") @@ -1542,8 +1711,8 @@ func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrengt } // GetAccountSetupKeys retrieves setup keys for an account. -func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) { - var setupKeys []*SetupKey +func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.SetupKey, error) { + var setupKeys []*types.SetupKey result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). Find(&setupKeys, accountIDCondition, accountID) if err := result.Error; err != nil { @@ -1555,8 +1724,8 @@ func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength Locking } // GetSetupKeyByID retrieves a setup key by its ID and account ID. 
-func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error) { - var setupKey *SetupKey +func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types.SetupKey, error) { + var setupKey *types.SetupKey result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID) if err := result.Error; err != nil { @@ -1571,7 +1740,7 @@ func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStre } // SaveSetupKey saves a setup key to the database. -func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error { +func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *types.SetupKey) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(setupKey) if result.Error != nil { log.WithContext(ctx).Errorf("failed to save setup key to store: %s", result.Error) @@ -1583,7 +1752,7 @@ func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrengt // DeleteSetupKey deletes a setup key from the database. func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&SetupKey{}, accountAndIDQueryCondition, accountID, keyID) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&types.SetupKey{}, accountAndIDQueryCondition, accountID, keyID) if result.Error != nil { log.WithContext(ctx).Errorf("failed to delete setup key from store: %s", result.Error) return status.Errorf(status.Internal, "failed to delete setup key from store") @@ -1683,9 +1852,9 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a } // SaveDNSSettings saves the DNS settings to the store. -func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). - Where(idQueryCondition, accountID).Updates(&AccountDNSSettings{DNSSettings: *settings}) +func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}). 
+ Where(idQueryCondition, accountID).Updates(&types.AccountDNSSettings{DNSSettings: *settings}) if result.Error != nil { log.WithContext(ctx).Errorf("failed to save dns settings to store: %v", result.Error) return status.Errorf(status.Internal, "failed to save dns settings to store") @@ -1698,6 +1867,201 @@ func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStre return nil } +func (s *SqlStore) GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error) { + var networks []*networkTypes.Network + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&networks, accountIDCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get networks from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get networks from store") + } + + return networks, nil +} + +func (s *SqlStore) GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error) { + var network *networkTypes.Network + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&network, accountAndIDQueryCondition, accountID, networkID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewNetworkNotFoundError(networkID) + } + + log.WithContext(ctx).Errorf("failed to get network from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get network from store") + } + + return network, nil +} + +func (s *SqlStore) SaveNetwork(ctx context.Context, lockStrength LockingStrength, network *networkTypes.Network) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(network) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save network to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save network to store") + } + + return nil +} + +func (s *SqlStore) DeleteNetwork(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&networkTypes.Network{}, accountAndIDQueryCondition, accountID, networkID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete network from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete network from store") + } + + if result.RowsAffected == 0 { + return status.NewNetworkNotFoundError(networkID) + } + + return nil +} + +func (s *SqlStore) GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*routerTypes.NetworkRouter, error) { + var netRouters []*routerTypes.NetworkRouter + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Find(&netRouters, "account_id = ? AND network_id = ?", accountID, netID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get network routers from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get network routers from store") + } + + return netRouters, nil +} + +func (s *SqlStore) GetNetworkRoutersByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error) { + var netRouters []*routerTypes.NetworkRouter + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). 
+ Find(&netRouters, accountIDCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get network routers from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get network routers from store") + } + + return netRouters, nil +} + +func (s *SqlStore) GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*routerTypes.NetworkRouter, error) { + var netRouter *routerTypes.NetworkRouter + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&netRouter, accountAndIDQueryCondition, accountID, routerID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewNetworkRouterNotFoundError(routerID) + } + log.WithContext(ctx).Errorf("failed to get network router from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get network router from store") + } + + return netRouter, nil +} + +func (s *SqlStore) SaveNetworkRouter(ctx context.Context, lockStrength LockingStrength, router *routerTypes.NetworkRouter) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(router) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save network router to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save network router to store") + } + + return nil +} + +func (s *SqlStore) DeleteNetworkRouter(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&routerTypes.NetworkRouter{}, accountAndIDQueryCondition, accountID, routerID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete network router from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete network router from store") + } + + if result.RowsAffected == 0 { + return status.NewNetworkRouterNotFoundError(routerID) + } + + return nil +} + +func (s *SqlStore) GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) ([]*resourceTypes.NetworkResource, error) { + var netResources []*resourceTypes.NetworkResource + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Find(&netResources, "account_id = ? AND network_id = ?", accountID, networkID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get network resources from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get network resources from store") + } + + return netResources, nil +} + +func (s *SqlStore) GetNetworkResourcesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*resourceTypes.NetworkResource, error) { + var netResources []*resourceTypes.NetworkResource + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). 
+ Find(&netResources, accountIDCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get network resources from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get network resources from store") + } + + return netResources, nil +} + +func (s *SqlStore) GetNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*resourceTypes.NetworkResource, error) { + var netResources *resourceTypes.NetworkResource + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&netResources, accountAndIDQueryCondition, accountID, resourceID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewNetworkResourceNotFoundError(resourceID) + } + log.WithContext(ctx).Errorf("failed to get network resource from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get network resource from store") + } + + return netResources, nil +} + +func (s *SqlStore) GetNetworkResourceByName(ctx context.Context, lockStrength LockingStrength, accountID, resourceName string) (*resourceTypes.NetworkResource, error) { + var netResources *resourceTypes.NetworkResource + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&netResources, "account_id = ? AND name = ?", accountID, resourceName) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewNetworkResourceNotFoundError(resourceName) + } + log.WithContext(ctx).Errorf("failed to get network resource from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get network resource from store") + } + + return netResources, nil +} + +func (s *SqlStore) SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *resourceTypes.NetworkResource) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(resource) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save network resource to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save network resource to store") + } + + return nil +} + +func (s *SqlStore) DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&resourceTypes.NetworkResource{}, accountAndIDQueryCondition, accountID, resourceID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete network resource from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete network resource from store") + } + + if result.RowsAffected == 0 { + return status.NewNetworkResourceNotFoundError(resourceID) + } + + return nil +} + // GetPATByHashedToken returns a PersonalAccessToken by its hashed token. func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*PersonalAccessToken, error) { var pat PersonalAccessToken @@ -1776,7 +2140,7 @@ func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pa // DeletePAT deletes a personal access token from the database. func (s *SqlStore) DeletePAT(ctx context.Context, lockStrength LockingStrength, userID, patID string) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). - Delete(&PersonalAccessToken{}, "user_id = ? 
AND id = ?", userID, patID) + Delete(&types.PersonalAccessToken{}, "user_id = ? AND id = ?", userID, patID) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to delete pat from the store: %s", err) return status.Errorf(status.Internal, "failed to delete pat from store") diff --git a/management/server/sql_store_test.go b/management/server/store/sql_store_test.go similarity index 70% rename from management/server/sql_store_test.go rename to management/server/store/sql_store_test.go index 26cc653e8..490586271 100644 --- a/management/server/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -1,4 +1,4 @@ -package server +package store import ( "context" @@ -14,17 +14,25 @@ import ( "time" "github.com/google/uuid" - nbdns "github.com/netbirdio/netbird/dns" - nbgroup "github.com/netbirdio/netbird/management/server/group" - "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/util" + "github.com/rs/xid" + log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" + route2 "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/management/server/status" nbpeer "github.com/netbirdio/netbird/management/server/peer" + nbroute "github.com/netbirdio/netbird/route" ) func TestSqlite_NewStore(t *testing.T) { @@ -73,7 +81,7 @@ func runLargeTest(t *testing.T, store Store) { if err != nil { t.Fatal(err) } - setupKey, _ := GenerateDefaultSetupKey() + setupKey, _ := types.GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey const numPerAccount = 6000 for n := 0; n < numPerAccount; n++ { @@ -86,14 +94,14 @@ func runLargeTest(t *testing.T, store Store) { IP: netIP, Name: peerID, DNSLabel: peerID, - UserID: userID, + UserID: "testuser", Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()}, SSHEnabled: false, } account.Peers[peerID] = peer group, _ := account.GetGroupAll() group.Peers = append(group.Peers, peerID) - user := &User{ + user := &types.User{ Id: fmt.Sprintf("%s-user-%d", account.Id, n), AccountID: account.Id, } @@ -111,7 +119,7 @@ func runLargeTest(t *testing.T, store Store) { } account.Routes[route.ID] = route - group = &nbgroup.Group{ + group = &types.Group{ ID: fmt.Sprintf("group-id-%d", n), AccountID: account.Id, Name: fmt.Sprintf("group-id-%d", n), @@ -134,7 +142,7 @@ func runLargeTest(t *testing.T, store Store) { } account.NameServerGroups[nameserver.ID] = nameserver - setupKey, _ := GenerateDefaultSetupKey() + setupKey, _ := types.GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey } @@ -216,7 +224,7 @@ func TestSqlite_SaveAccount(t *testing.T) { assert.NoError(t, err) account := newAccountWithId(context.Background(), "account_id", "testuser", "") - setupKey, _ := GenerateDefaultSetupKey() + setupKey, _ := types.GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ Key: "peerkey", @@ -230,7 +238,7 @@ func TestSqlite_SaveAccount(t *testing.T) { require.NoError(t, err) account2 := newAccountWithId(context.Background(), "account_id2", 
"testuser2", "") - setupKey, _ = GenerateDefaultSetupKey() + setupKey, _ = types.GenerateDefaultSetupKey() account2.SetupKeys[setupKey.Key] = setupKey account2.Peers["testpeer2"] = &nbpeer.Peer{ Key: "peerkey2", @@ -289,14 +297,14 @@ func TestSqlite_DeleteAccount(t *testing.T) { assert.NoError(t, err) testUserID := "testuser" - user := NewAdminUser(testUserID) - user.PATs = map[string]*PersonalAccessToken{"testtoken": { + user := types.NewAdminUser(testUserID) + user.PATs = map[string]*types.PersonalAccessToken{"testtoken": { ID: "testtoken", Name: "test token", }} account := newAccountWithId(context.Background(), "account_id", testUserID, "") - setupKey, _ := GenerateDefaultSetupKey() + setupKey, _ := types.GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ Key: "peerkey", @@ -306,6 +314,35 @@ func TestSqlite_DeleteAccount(t *testing.T) { Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } account.Users[testUserID] = user + account.Networks = []*networkTypes.Network{ + { + ID: "network_id", + AccountID: account.Id, + Name: "network name", + Description: "network description", + }, + } + account.NetworkRouters = []*routerTypes.NetworkRouter{ + { + ID: "router_id", + NetworkID: account.Networks[0].ID, + AccountID: account.Id, + PeerGroups: []string{"group_id"}, + Masquerade: true, + Metric: 1, + }, + } + account.NetworkResources = []*resourceTypes.NetworkResource{ + { + ID: "resource_id", + NetworkID: account.Networks[0].ID, + AccountID: account.Id, + Name: "Name", + Description: "Description", + Type: "Domain", + Address: "example.com", + }, + } err = store.SaveAccount(context.Background(), account) require.NoError(t, err) @@ -337,21 +374,30 @@ func TestSqlite_DeleteAccount(t *testing.T) { require.Error(t, err, "expecting error after removing DeleteAccount when getting account by id") for _, policy := range account.Policies { - var rules []*PolicyRule - err = store.(*SqlStore).db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error + var rules []*types.PolicyRule + err = store.(*SqlStore).db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for policy rules") require.Len(t, rules, 0, "expecting no policy rules to be found after removing DeleteAccount") } for _, accountUser := range account.Users { - var pats []*PersonalAccessToken - err = store.(*SqlStore).db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error + var pats []*types.PersonalAccessToken + err = store.(*SqlStore).db.Model(&types.PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for personal access token") require.Len(t, pats, 0, "expecting no personal access token to be found after removing DeleteAccount") } + for _, network := range account.Networks { + routers, err := store.GetNetworkRoutersByNetID(context.Background(), LockingStrengthShare, account.Id, network.ID) + require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for network routers") + require.Len(t, routers, 0, "expecting no network routers to be found after DeleteAccount") + + resources, err := store.GetNetworkResourcesByNetID(context.Background(), LockingStrengthShare, account.Id, network.ID) + require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for 
network resources") + require.Len(t, resources, 0, "expecting no network resources to be found after DeleteAccount") + } } func TestSqlite_GetAccount(t *testing.T) { @@ -360,7 +406,7 @@ func TestSqlite_GetAccount(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -377,8 +423,8 @@ func TestSqlite_GetAccount(t *testing.T) { require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") } -func TestSqlite_SavePeer(t *testing.T) { - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) +func TestSqlStore_SavePeer(t *testing.T) { + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -387,12 +433,13 @@ func TestSqlite_SavePeer(t *testing.T) { // save status of non-existing peer peer := &nbpeer.Peer{ - Key: "peerkey", - ID: "testpeer", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{Hostname: "testingpeer"}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().Local()}, + Key: "peerkey", + ID: "testpeer", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{Hostname: "testingpeer"}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + CreatedAt: time.Now().UTC(), } ctx := context.Background() err = store.SavePeer(ctx, LockingStrengthUpdate, account.Id, peer) @@ -422,11 +469,11 @@ func TestSqlite_SavePeer(t *testing.T) { assert.Equal(t, updatedPeer.Status.Connected, actual.Status.Connected) assert.Equal(t, updatedPeer.Status.LoginExpired, actual.Status.LoginExpired) assert.Equal(t, updatedPeer.Status.RequiresApproval, actual.Status.RequiresApproval) - assert.WithinDurationf(t, updatedPeer.Status.LastSeen, actual.Status.LastSeen, time.Millisecond, "LastSeen should be equal") + assert.WithinDurationf(t, updatedPeer.Status.LastSeen, actual.Status.LastSeen.UTC(), time.Millisecond, "LastSeen should be equal") } -func TestSqlite_SavePeerStatus(t *testing.T) { - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) +func TestSqlStore_SavePeerStatus(t *testing.T) { + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -434,7 +481,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { require.NoError(t, err) // save status of non-existing peer - newStatus := nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().Local()} + newStatus := nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()} err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "non-existing-peer", newStatus) assert.Error(t, err) parsedErr, ok := status.FromError(err) @@ -448,7 +495,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { IP: net.IP{127, 0, 0, 1}, Meta: nbpeer.PeerSystemMeta{}, Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().Local()}, + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } err = store.SaveAccount(context.Background(), account) @@ -464,7 +511,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { assert.Equal(t, newStatus.Connected, actual.Connected) 
assert.Equal(t, newStatus.LoginExpired, actual.LoginExpired) assert.Equal(t, newStatus.RequiresApproval, actual.RequiresApproval) - assert.WithinDurationf(t, newStatus.LastSeen, actual.LastSeen, time.Millisecond, "LastSeen should be equal") + assert.WithinDurationf(t, newStatus.LastSeen, actual.LastSeen.UTC(), time.Millisecond, "LastSeen should be equal") newStatus.Connected = true @@ -478,11 +525,11 @@ func TestSqlite_SavePeerStatus(t *testing.T) { assert.Equal(t, newStatus.Connected, actual.Connected) assert.Equal(t, newStatus.LoginExpired, actual.LoginExpired) assert.Equal(t, newStatus.RequiresApproval, actual.RequiresApproval) - assert.WithinDurationf(t, newStatus.LastSeen, actual.LastSeen, time.Millisecond, "LastSeen should be equal") + assert.WithinDurationf(t, newStatus.LastSeen, actual.LastSeen.UTC(), time.Millisecond, "LastSeen should be equal") } -func TestSqlite_SavePeerLocation(t *testing.T) { - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) +func TestSqlStore_SavePeerLocation(t *testing.T) { + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -498,7 +545,8 @@ func TestSqlite_SavePeerLocation(t *testing.T) { CityName: "City", GeoNameID: 1, }, - Meta: nbpeer.PeerSystemMeta{}, + CreatedAt: time.Now().UTC(), + Meta: nbpeer.PeerSystemMeta{}, } // error is expected as peer is not in store yet err = store.SavePeerLocation(context.Background(), LockingStrengthUpdate, account.Id, peer) @@ -536,7 +584,7 @@ func TestSqlite_TestGetAccountByPrivateDomain(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -559,7 +607,7 @@ func TestSqlite_GetTokenIDByHashedToken(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -596,7 +644,7 @@ func TestMigrate(t *testing.T) { require.NoError(t, err, "Failed to parse CIDR") type network struct { - Network + types.Network Net net.IPNet `gorm:"serializer:gob"` } @@ -611,7 +659,7 @@ func TestMigrate(t *testing.T) { } type account struct { - Account + types.Account Network *network `gorm:"embedded;embeddedPrefix:network_"` Peers []peer `gorm:"foreignKey:AccountID;references:id"` } @@ -671,23 +719,10 @@ func TestMigrate(t *testing.T) { } -func newSqliteStore(t *testing.T) *SqlStore { - t.Helper() - - store, err := NewSqliteStore(context.Background(), t.TempDir(), nil) - t.Cleanup(func() { - store.Close(context.Background()) - }) - require.NoError(t, err) - require.NotNil(t, store) - - return store -} - func newAccount(store Store, id int) error { str := fmt.Sprintf("%s-%d", uuid.New().String(), id) account := newAccountWithId(context.Background(), str, str+"-testuser", "example.com") - setupKey, _ := GenerateDefaultSetupKey() + setupKey, _ := types.GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["p"+str] = &nbpeer.Peer{ Key: "peerkey" + str, @@ -726,7 +761,7 @@ func TestPostgresql_SaveAccount(t *testing.T) { assert.NoError(t, err) 
account := newAccountWithId(context.Background(), "account_id", "testuser", "") - setupKey, _ := GenerateDefaultSetupKey() + setupKey, _ := types.GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ Key: "peerkey", @@ -740,7 +775,7 @@ func TestPostgresql_SaveAccount(t *testing.T) { require.NoError(t, err) account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") - setupKey, _ = GenerateDefaultSetupKey() + setupKey, _ = types.GenerateDefaultSetupKey() account2.SetupKeys[setupKey.Key] = setupKey account2.Peers["testpeer2"] = &nbpeer.Peer{ Key: "peerkey2", @@ -799,14 +834,14 @@ func TestPostgresql_DeleteAccount(t *testing.T) { assert.NoError(t, err) testUserID := "testuser" - user := NewAdminUser(testUserID) - user.PATs = map[string]*PersonalAccessToken{"testtoken": { + user := types.NewAdminUser(testUserID) + user.PATs = map[string]*types.PersonalAccessToken{"testtoken": { ID: "testtoken", Name: "test token", }} account := newAccountWithId(context.Background(), "account_id", testUserID, "") - setupKey, _ := GenerateDefaultSetupKey() + setupKey, _ := types.GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ Key: "peerkey", @@ -847,16 +882,16 @@ func TestPostgresql_DeleteAccount(t *testing.T) { require.Error(t, err, "expecting error after removing DeleteAccount when getting account by id") for _, policy := range account.Policies { - var rules []*PolicyRule - err = store.(*SqlStore).db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error + var rules []*types.PolicyRule + err = store.(*SqlStore).db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for policy rules") require.Len(t, rules, 0, "expecting no policy rules to be found after removing DeleteAccount") } for _, accountUser := range account.Users { - var pats []*PersonalAccessToken - err = store.(*SqlStore).db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error + var pats []*types.PersonalAccessToken + err = store.(*SqlStore).db.Model(&types.PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for personal access token") require.Len(t, pats, 0, "expecting no personal access token to be found after removing DeleteAccount") @@ -864,54 +899,13 @@ func TestPostgresql_DeleteAccount(t *testing.T) { } -func TestPostgresql_SavePeerStatus(t *testing.T) { - if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { - t.Skip("skip CI tests on darwin and windows") - } - - t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) - t.Cleanup(cleanUp) - assert.NoError(t, err) - - account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") - require.NoError(t, err) - - // save status of non-existing peer - newStatus := nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()} - err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "non-existing-peer", newStatus) - assert.Error(t, err) - - // save new status of existing peer - account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - ID: "testpeer", - IP: net.IP{127, 0, 0, 1}, - Meta: 
nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()}, - } - - err = store.SaveAccount(context.Background(), account) - require.NoError(t, err) - - err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "testpeer", newStatus) - require.NoError(t, err) - - account, err = store.GetAccount(context.Background(), account.Id) - require.NoError(t, err) - - actual := account.Peers["testpeer"].Status - assert.Equal(t, newStatus.Connected, actual.Connected) -} - func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) { if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { t.Skip("skip CI tests on darwin and windows") } t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -931,7 +925,7 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -945,7 +939,7 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) { func TestSqlite_GetTakenIPs(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) defer cleanup() if err != nil { t.Fatal(err) @@ -965,7 +959,7 @@ func TestSqlite_GetTakenIPs(t *testing.T) { AccountID: existingAccountID, IP: net.IP{1, 1, 1, 1}, } - err = store.AddPeerToAccount(context.Background(), peer1) + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1) require.NoError(t, err) takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID) @@ -978,7 +972,7 @@ func TestSqlite_GetTakenIPs(t *testing.T) { AccountID: existingAccountID, IP: net.IP{2, 2, 2, 2}, } - err = store.AddPeerToAccount(context.Background(), peer2) + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2) require.NoError(t, err) takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID) @@ -990,7 +984,7 @@ func TestSqlite_GetTakenIPs(t *testing.T) { func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) if err != nil { return } @@ -1010,7 +1004,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { AccountID: existingAccountID, DNSLabel: "peer1.domain.test", } - err = store.AddPeerToAccount(context.Background(), peer1) + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1) require.NoError(t, err) labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) @@ -1022,7 +1016,7 @@ func 
TestSqlite_GetPeerLabelsInAccount(t *testing.T) { AccountID: existingAccountID, DNSLabel: "peer2.domain.test", } - err = store.AddPeerToAccount(context.Background(), peer2) + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2) require.NoError(t, err) labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) @@ -1032,7 +1026,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { func TestSqlite_GetAccountNetwork(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) @@ -1055,7 +1049,7 @@ func TestSqlite_GetAccountNetwork(t *testing.T) { func TestSqlite_GetSetupKeyBySecret(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) @@ -1073,14 +1067,14 @@ func TestSqlite_GetSetupKeyBySecret(t *testing.T) { setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, encodedHashedKey) require.NoError(t, err) assert.Equal(t, encodedHashedKey, setupKey.Key) - assert.Equal(t, hiddenKey(plainKey, 4), setupKey.KeySecret) + assert.Equal(t, types.HiddenKey(plainKey, 4), setupKey.KeySecret) assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", setupKey.AccountID) assert.Equal(t, "Default key", setupKey.Name) } func TestSqlite_incrementSetupKeyUsage(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) @@ -1116,13 +1110,13 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) { func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) } - group := &nbgroup.Group{ + group := &types.Group{ ID: "group-id", AccountID: "account-id", Name: "group-name", @@ -1161,7 +1155,7 @@ func TestSqlite_GetAccountUsers(t *testing.T) { } func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) @@ -1207,7 +1201,7 @@ func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) { } func TestSqlite_GetGroupByName(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", 
t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) @@ -1221,7 +1215,7 @@ func TestSqlite_GetGroupByName(t *testing.T) { func Test_DeleteSetupKeySuccessfully(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1237,7 +1231,7 @@ func Test_DeleteSetupKeySuccessfully(t *testing.T) { func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1249,7 +1243,7 @@ func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) { } func TestSqlStore_GetGroupsByIDs(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1292,13 +1286,13 @@ func TestSqlStore_GetGroupsByIDs(t *testing.T) { } func TestSqlStore_SaveGroup(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - group := &nbgroup.Group{ + group := &types.Group{ ID: "group-id", AccountID: accountID, Issued: "api", @@ -1313,13 +1307,13 @@ func TestSqlStore_SaveGroup(t *testing.T) { } func TestSqlStore_SaveGroups(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - groups := []*nbgroup.Group{ + groups := []*types.Group{ { ID: "group-1", AccountID: accountID, @@ -1338,7 +1332,7 @@ func TestSqlStore_SaveGroups(t *testing.T) { } func TestSqlStore_DeleteGroup(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1386,7 +1380,7 @@ func TestSqlStore_DeleteGroup(t *testing.T) { } func TestSqlStore_DeleteGroups(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1433,7 +1427,7 @@ func TestSqlStore_DeleteGroups(t *testing.T) { } func TestSqlStore_GetPeerByID(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), 
"../testdata/store_policy_migrate.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1479,7 +1473,7 @@ func TestSqlStore_GetPeerByID(t *testing.T) { } func TestSqlStore_GetPeersByIDs(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1521,7 +1515,7 @@ func TestSqlStore_GetPeersByIDs(t *testing.T) { } func TestSqlStore_GetPostureChecksByID(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1567,7 +1561,7 @@ func TestSqlStore_GetPostureChecksByID(t *testing.T) { } func TestSqlStore_GetPostureChecksByIDs(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1610,7 +1604,7 @@ func TestSqlStore_GetPostureChecksByIDs(t *testing.T) { } func TestSqlStore_SavePostureChecks(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1651,7 +1645,7 @@ func TestSqlStore_SavePostureChecks(t *testing.T) { } func TestSqlStore_DeletePostureChecks(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1698,7 +1692,7 @@ func TestSqlStore_DeletePostureChecks(t *testing.T) { } func TestSqlStore_GetPolicyByID(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1744,23 +1738,23 @@ func TestSqlStore_GetPolicyByID(t *testing.T) { } func TestSqlStore_CreatePolicy(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - policy := &Policy{ + policy := &types.Policy{ ID: "policy-id", AccountID: accountID, Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupC"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, } @@ -1774,7 +1768,7 @@ func TestSqlStore_CreatePolicy(t *testing.T) { } func TestSqlStore_SavePolicy(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", 
t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1786,6 +1780,8 @@ func TestSqlStore_SavePolicy(t *testing.T) { policy.Enabled = false policy.Description = "policy" + policy.Rules[0].Sources = []string{"group"} + policy.Rules[0].Ports = []string{"80", "443"} err = store.SavePolicy(context.Background(), LockingStrengthUpdate, policy) require.NoError(t, err) @@ -1795,7 +1791,7 @@ func TestSqlStore_SavePolicy(t *testing.T) { } func TestSqlStore_DeletePolicy(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1811,7 +1807,7 @@ func TestSqlStore_DeletePolicy(t *testing.T) { } func TestSqlStore_GetDNSSettings(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1855,7 +1851,7 @@ func TestSqlStore_GetDNSSettings(t *testing.T) { } func TestSqlStore_SaveDNSSettings(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1874,7 +1870,7 @@ func TestSqlStore_SaveDNSSettings(t *testing.T) { } func TestSqlStore_GetAccountNameServerGroups(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1911,7 +1907,7 @@ func TestSqlStore_GetAccountNameServerGroups(t *testing.T) { } func TestSqlStore_GetNameServerByID(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1957,7 +1953,7 @@ func TestSqlStore_GetNameServerByID(t *testing.T) { } func TestSqlStore_SaveNameServerGroup(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1989,7 +1985,7 @@ func TestSqlStore_SaveNameServerGroup(t *testing.T) { } func TestSqlStore_DeleteNameServerGroup(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -2004,8 +2000,97 @@ func TestSqlStore_DeleteNameServerGroup(t *testing.T) { require.Nil(t, nsGroup) } -func TestSqlStore_GetAccountPeers(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_with_expired_peers.sql", t.TempDir()) +// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id +func newAccountWithId(ctx context.Context, accountID, userID, domain string) 
*types.Account { + log.WithContext(ctx).Debugf("creating new account") + + network := types.NewNetwork() + peers := make(map[string]*nbpeer.Peer) + users := make(map[string]*types.User) + routes := make(map[nbroute.ID]*nbroute.Route) + setupKeys := map[string]*types.SetupKey{} + nameServersGroups := make(map[string]*nbdns.NameServerGroup) + + owner := types.NewOwnerUser(userID) + owner.AccountID = accountID + users[userID] = owner + + dnsSettings := types.DNSSettings{ + DisabledManagementGroups: make([]string, 0), + } + log.WithContext(ctx).Debugf("created new account %s", accountID) + + acc := &types.Account{ + Id: accountID, + CreatedAt: time.Now().UTC(), + SetupKeys: setupKeys, + Network: network, + Peers: peers, + Users: users, + CreatedBy: userID, + Domain: domain, + Routes: routes, + NameServerGroups: nameServersGroups, + DNSSettings: dnsSettings, + Settings: &types.Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: types.DefaultPeerLoginExpiration, + GroupsPropagationEnabled: true, + RegularUsersViewBlocked: true, + + PeerInactivityExpirationEnabled: false, + PeerInactivityExpiration: types.DefaultPeerInactivityExpiration, + }, + } + + if err := addAllGroup(acc); err != nil { + log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err) + } + return acc +} + +// addAllGroup to account object if it doesn't exist +func addAllGroup(account *types.Account) error { + if len(account.Groups) == 0 { + allGroup := &types.Group{ + ID: xid.New().String(), + Name: "All", + Issued: types.GroupIssuedAPI, + } + for _, peer := range account.Peers { + allGroup.Peers = append(allGroup.Peers, peer.ID) + } + account.Groups = map[string]*types.Group{allGroup.ID: allGroup} + + id := xid.New().String() + + defaultPolicy := &types.Policy{ + ID: id, + Name: types.DefaultRuleName, + Description: types.DefaultRuleDescription, + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: id, + Name: types.DefaultRuleName, + Description: types.DefaultRuleDescription, + Enabled: true, + Sources: []string{allGroup.ID}, + Destinations: []string{allGroup.ID}, + Bidirectional: true, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, + }, + }, + } + + account.Policies = []*types.Policy{defaultPolicy} + } + return nil +} + +func TestSqlStore_GetAccountNetworks(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -2015,17 +2100,544 @@ func TestSqlStore_GetAccountPeers(t *testing.T) { expectedCount int }{ { - name: "retrieve peers by existing account ID", + name: "retrieve networks by existing account ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectedCount: 1, + }, + + { + name: "retrieve networks by non-existing account ID", + accountID: "non-existent", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + networks, err := store.GetAccountNetworks(context.Background(), LockingStrengthShare, tt.accountID) + require.NoError(t, err) + require.Len(t, networks, tt.expectedCount) + }) + } +} + +func TestSqlStore_GetNetworkByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + networkID string + expectError bool + }{ + { + name: "retrieve existing network ID", + networkID: 
"ct286bi7qv930dsrrug0", + expectError: false, + }, + { + name: "retrieve non-existing network ID", + networkID: "non-existing", + expectError: true, + }, + { + name: "retrieve network with empty ID", + networkID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + network, err := store.GetNetworkByID(context.Background(), LockingStrengthShare, accountID, tt.networkID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, network) + } else { + require.NoError(t, err) + require.NotNil(t, network) + require.Equal(t, tt.networkID, network.ID) + } + }) + } +} + +func TestSqlStore_SaveNetwork(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + network := &networkTypes.Network{ + ID: "net-id", + AccountID: accountID, + Name: "net", + } + + err = store.SaveNetwork(context.Background(), LockingStrengthUpdate, network) + require.NoError(t, err) + + savedNet, err := store.GetNetworkByID(context.Background(), LockingStrengthShare, accountID, network.ID) + require.NoError(t, err) + require.Equal(t, network, savedNet) +} + +func TestSqlStore_DeleteNetwork(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + networkID := "ct286bi7qv930dsrrug0" + + err = store.DeleteNetwork(context.Background(), LockingStrengthUpdate, accountID, networkID) + require.NoError(t, err) + + network, err := store.GetNetworkByID(context.Background(), LockingStrengthShare, accountID, networkID) + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, status.NotFound, sErr.Type()) + require.Nil(t, network) +} + +func TestSqlStore_GetNetworkRoutersByNetID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + networkID string + expectedCount int + }{ + { + name: "retrieve routers by existing network ID", + networkID: "ct286bi7qv930dsrrug0", + expectedCount: 1, + }, + { + name: "retrieve routers by non-existing network ID", + networkID: "non-existent", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + routers, err := store.GetNetworkRoutersByNetID(context.Background(), LockingStrengthShare, accountID, tt.networkID) + require.NoError(t, err) + require.Len(t, routers, tt.expectedCount) + }) + } +} + +func TestSqlStore_GetNetworkRouterByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + networkRouterID string + expectError bool + }{ + { + name: "retrieve existing network router ID", + networkRouterID: "ctc20ji7qv9ck2sebc80", + expectError: false, + }, + { + name: "retrieve non-existing network router ID", + networkRouterID: "non-existing", + expectError: true, + }, + { + name: "retrieve network with empty router ID", + 
networkRouterID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + networkRouter, err := store.GetNetworkRouterByID(context.Background(), LockingStrengthShare, accountID, tt.networkRouterID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, networkRouter) + } else { + require.NoError(t, err) + require.NotNil(t, networkRouter) + require.Equal(t, tt.networkRouterID, networkRouter.ID) + } + }) + } +} + +func TestSqlStore_SaveNetworkRouter(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + networkID := "ct286bi7qv930dsrrug0" + + netRouter, err := routerTypes.NewNetworkRouter(accountID, networkID, "", []string{"net-router-grp"}, true, 0, true) + require.NoError(t, err) + + err = store.SaveNetworkRouter(context.Background(), LockingStrengthUpdate, netRouter) + require.NoError(t, err) + + savedNetRouter, err := store.GetNetworkRouterByID(context.Background(), LockingStrengthShare, accountID, netRouter.ID) + require.NoError(t, err) + require.Equal(t, netRouter, savedNetRouter) +} + +func TestSqlStore_DeleteNetworkRouter(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + netRouterID := "ctc20ji7qv9ck2sebc80" + + err = store.DeleteNetworkRouter(context.Background(), LockingStrengthUpdate, accountID, netRouterID) + require.NoError(t, err) + + netRouter, err := store.GetNetworkByID(context.Background(), LockingStrengthShare, accountID, netRouterID) + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, status.NotFound, sErr.Type()) + require.Nil(t, netRouter) +} + +func TestSqlStore_GetNetworkResourcesByNetID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + networkID string + expectedCount int + }{ + { + name: "retrieve resources by existing network ID", + networkID: "ct286bi7qv930dsrrug0", + expectedCount: 1, + }, + { + name: "retrieve resources by non-existing network ID", + networkID: "non-existent", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + netResources, err := store.GetNetworkResourcesByNetID(context.Background(), LockingStrengthShare, accountID, tt.networkID) + require.NoError(t, err) + require.Len(t, netResources, tt.expectedCount) + }) + } +} + +func TestSqlStore_GetNetworkResourceByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + netResourceID string + expectError bool + }{ + { + name: "retrieve existing network resource ID", + netResourceID: "ctc4nci7qv9061u6ilfg", + expectError: false, + }, + { + name: "retrieve non-existing network resource ID", + netResourceID: "non-existing", + expectError: true, + }, + { + name: "retrieve network with empty 
resource ID", + netResourceID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + netResource, err := store.GetNetworkResourceByID(context.Background(), LockingStrengthShare, accountID, tt.netResourceID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, netResource) + } else { + require.NoError(t, err) + require.NotNil(t, netResource) + require.Equal(t, tt.netResourceID, netResource.ID) + } + }) + } +} + +func TestSqlStore_SaveNetworkResource(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + networkID := "ct286bi7qv930dsrrug0" + + netResource, err := resourceTypes.NewNetworkResource(accountID, networkID, "resource-name", "", "example.com", []string{}, true) + require.NoError(t, err) + + err = store.SaveNetworkResource(context.Background(), LockingStrengthUpdate, netResource) + require.NoError(t, err) + + savedNetResource, err := store.GetNetworkResourceByID(context.Background(), LockingStrengthShare, accountID, netResource.ID) + require.NoError(t, err) + require.Equal(t, netResource.ID, savedNetResource.ID) + require.Equal(t, netResource.Name, savedNetResource.Name) + require.Equal(t, netResource.NetworkID, savedNetResource.NetworkID) + require.Equal(t, netResource.Type, resourceTypes.NetworkResourceType("domain")) + require.Equal(t, netResource.Domain, "example.com") + require.Equal(t, netResource.AccountID, savedNetResource.AccountID) + require.Equal(t, netResource.Prefix, netip.Prefix{}) +} + +func TestSqlStore_DeleteNetworkResource(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + netResourceID := "ctc4nci7qv9061u6ilfg" + + err = store.DeleteNetworkResource(context.Background(), LockingStrengthUpdate, accountID, netResourceID) + require.NoError(t, err) + + netResource, err := store.GetNetworkByID(context.Background(), LockingStrengthShare, accountID, netResourceID) + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, status.NotFound, sErr.Type()) + require.Nil(t, netResource) +} + +func TestSqlStore_AddAndRemoveResourceFromGroup(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + require.NoError(t, err) + t.Cleanup(cleanup) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + resourceId := "ctc4nci7qv9061u6ilfg" + groupID := "cs1tnh0hhcjnqoiuebeg" + + res := &types.Resource{ + ID: resourceId, + Type: "host", + } + err = store.AddResourceToGroup(context.Background(), accountID, groupID, res) + require.NoError(t, err) + + group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + require.NoError(t, err) + require.Contains(t, group.Resources, *res) + + groups, err := store.GetResourceGroups(context.Background(), LockingStrengthShare, accountID, resourceId) + require.NoError(t, err) + require.Len(t, groups, 1) + + err = store.RemoveResourceFromGroup(context.Background(), accountID, groupID, res.ID) + require.NoError(t, err) + + group, err = store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, 
groupID)
+	require.NoError(t, err)
+	require.NotContains(t, group.Resources, *res)
+}
+
+func TestSqlStore_AddPeerToGroup(t *testing.T) {
+	store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir())
+	t.Cleanup(cleanup)
+	require.NoError(t, err)
+
+	accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+	peerID := "cfefqs706sqkneg59g4g"
+	groupID := "cfefqs706sqkneg59g4h"
+
+	group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
+	require.NoError(t, err, "failed to get group")
+	require.Len(t, group.Peers, 0, "group should have 0 peers")
+
+	err = store.AddPeerToGroup(context.Background(), LockingStrengthUpdate, accountID, peerID, groupID)
+	require.NoError(t, err, "failed to add peer to group")
+
+	group, err = store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
+	require.NoError(t, err, "failed to get group")
+	require.Len(t, group.Peers, 1, "group should have 1 peer")
+	require.Contains(t, group.Peers, peerID)
+}
+
+func TestSqlStore_AddPeerToAllGroup(t *testing.T) {
+	store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir())
+	t.Cleanup(cleanup)
+	require.NoError(t, err)
+
+	accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+	groupID := "cfefqs706sqkneg59g3g"
+
+	peer := &nbpeer.Peer{
+		ID:        "peer1",
+		AccountID: accountID,
+		DNSLabel:  "peer1.domain.test",
+	}
+
+	group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
+	require.NoError(t, err, "failed to get group")
+	require.Len(t, group.Peers, 2, "group should have 2 peers")
+	require.NotContains(t, group.Peers, peer.ID)
+
+	err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer)
+	require.NoError(t, err, "failed to add peer to account")
+
+	err = store.AddPeerToAllGroup(context.Background(), LockingStrengthUpdate, accountID, peer.ID)
+	require.NoError(t, err, "failed to add peer to all group")
+
+	group, err = store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
+	require.NoError(t, err, "failed to get group")
+	require.Len(t, group.Peers, 3, "group should have 3 peers")
+	require.Contains(t, group.Peers, peer.ID)
+}
+
+func TestSqlStore_AddPeerToAccount(t *testing.T) {
+	store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir())
+	t.Cleanup(cleanup)
+	require.NoError(t, err)
+
+	accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+
+	peer := &nbpeer.Peer{
+		ID:        "peer1",
+		AccountID: accountID,
+		Key:       "key",
+		IP:        net.IP{1, 1, 1, 1},
+		Meta: nbpeer.PeerSystemMeta{
+			Hostname:  "hostname",
+			GoOS:      "linux",
+			Kernel:    "Linux",
+			Core:      "21.04",
+			Platform:  "x86_64",
+			OS:        "Ubuntu",
+			WtVersion: "development",
+			UIVersion: "development",
+		},
+		Name:     "peer.test",
+		DNSLabel: "peer",
+		Status: &nbpeer.PeerStatus{
+			LastSeen:         time.Now().UTC(),
+			Connected:        true,
+			LoginExpired:     false,
+			RequiresApproval: false,
+		},
+		SSHKey:                      "ssh-key",
+		SSHEnabled:                  false,
+		LoginExpirationEnabled:      true,
+		InactivityExpirationEnabled: false,
+		LastLogin:                   util.ToPtr(time.Now().UTC()),
+		CreatedAt:                   time.Now().UTC(),
+		Ephemeral:                   true,
+	}
+	err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer)
+	require.NoError(t, err, "failed to add peer to account")
+
+	storedPeer, err := store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peer.ID)
+	require.NoError(t, err, "failed to 
get peer") + + assert.Equal(t, peer.ID, storedPeer.ID) + assert.Equal(t, peer.AccountID, storedPeer.AccountID) + assert.Equal(t, peer.Key, storedPeer.Key) + assert.Equal(t, peer.IP.String(), storedPeer.IP.String()) + assert.Equal(t, peer.Meta, storedPeer.Meta) + assert.Equal(t, peer.Name, storedPeer.Name) + assert.Equal(t, peer.DNSLabel, storedPeer.DNSLabel) + assert.Equal(t, peer.SSHKey, storedPeer.SSHKey) + assert.Equal(t, peer.SSHEnabled, storedPeer.SSHEnabled) + assert.Equal(t, peer.LoginExpirationEnabled, storedPeer.LoginExpirationEnabled) + assert.Equal(t, peer.InactivityExpirationEnabled, storedPeer.InactivityExpirationEnabled) + assert.WithinDurationf(t, peer.GetLastLogin(), storedPeer.GetLastLogin().UTC(), time.Millisecond, "LastLogin should be equal") + assert.WithinDurationf(t, peer.CreatedAt, storedPeer.CreatedAt.UTC(), time.Millisecond, "CreatedAt should be equal") + assert.Equal(t, peer.Ephemeral, storedPeer.Ephemeral) + assert.Equal(t, peer.Status.Connected, storedPeer.Status.Connected) + assert.Equal(t, peer.Status.LoginExpired, storedPeer.Status.LoginExpired) + assert.Equal(t, peer.Status.RequiresApproval, storedPeer.Status.RequiresApproval) + assert.WithinDurationf(t, peer.Status.LastSeen, storedPeer.Status.LastSeen.UTC(), time.Millisecond, "LastSeen should be equal") +} + +func TestSqlStore_GetPeerGroups(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + peerID := "cfefqs706sqkneg59g4g" + + groups, err := store.GetPeerGroups(context.Background(), LockingStrengthShare, accountID, peerID) + require.NoError(t, err) + assert.Len(t, groups, 1) + assert.Equal(t, groups[0].Name, "All") + + err = store.AddPeerToGroup(context.Background(), LockingStrengthUpdate, accountID, peerID, "cfefqs706sqkneg59g4h") + require.NoError(t, err) + + groups, err = store.GetPeerGroups(context.Background(), LockingStrengthShare, accountID, peerID) + require.NoError(t, err) + assert.Len(t, groups, 2) +} + +func TestSqlStore_GetAccountPeers(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + expectedCount int + }{ + { + name: "should retrieve peers for an existing account ID", accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", expectedCount: 4, }, { - name: "non-existing account ID", + name: "should return no peers for a non-existing account ID", accountID: "nonexistent", expectedCount: 0, }, { - name: "empty account ID", + name: "should return no peers for an empty account ID", accountID: "", expectedCount: 0, }, @@ -2042,7 +2654,7 @@ func TestSqlStore_GetAccountPeers(t *testing.T) { } func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_with_expired_peers.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -2052,17 +2664,17 @@ func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) { expectedCount int }{ { - name: "retrieve peers with expiration by existing account ID", + name: "should retrieve peers with expiration for an existing account ID", accountID: 
"bf1c8084-ba50-4ce7-9439-34653001fc3b", expectedCount: 1, }, { - name: "non-existing account ID", + name: "should return no peers with expiration for a non-existing account ID", accountID: "nonexistent", expectedCount: 0, }, { - name: "empty account ID", + name: "should return no peers with expiration for a empty account ID", accountID: "", expectedCount: 0, }, @@ -2078,7 +2690,7 @@ func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) { } func TestSqlStore_GetAccountPeersWithInactivity(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_with_expired_peers.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -2088,17 +2700,17 @@ func TestSqlStore_GetAccountPeersWithInactivity(t *testing.T) { expectedCount int }{ { - name: "retrieve peers with inactivity by existing account ID", + name: "should retrieve peers with inactivity for an existing account ID", accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", expectedCount: 1, }, { - name: "non-existing account ID", + name: "should return no peers with inactivity for a non-existing account ID", accountID: "nonexistent", expectedCount: 0, }, { - name: "empty account ID", + name: "should return no peers with inactivity for an empty account ID", accountID: "", expectedCount: 0, }, @@ -2114,7 +2726,7 @@ func TestSqlStore_GetAccountPeersWithInactivity(t *testing.T) { } func TestSqlStore_GetAllEphemeralPeers(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/storev1.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/storev1.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -2124,8 +2736,60 @@ func TestSqlStore_GetAllEphemeralPeers(t *testing.T) { require.True(t, peers[0].Ephemeral) } +func TestSqlStore_GetUserPeers(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + userID string + expectedCount int + }{ + { + name: "should retrieve peers for existing account ID and user ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + userID: "f4f6d672-63fb-11ec-90d6-0242ac120003", + expectedCount: 1, + }, + { + name: "should return no peers for non-existing account ID with existing user ID", + accountID: "nonexistent", + userID: "f4f6d672-63fb-11ec-90d6-0242ac120003", + expectedCount: 0, + }, + { + name: "should return no peers for non-existing user ID with existing account ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + userID: "nonexistent_user", + expectedCount: 0, + }, + { + name: "should retrieve peers for another valid account ID and user ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + userID: "edafee4e-63fb-11ec-90d6-0242ac120003", + expectedCount: 2, + }, + { + name: "should return no peers for existing account ID with empty user ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + userID: "", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peers, err := store.GetUserPeers(context.Background(), LockingStrengthShare, tt.accountID, tt.userID) + require.NoError(t, err) + require.Len(t, peers, tt.expectedCount) + }) + } +} + func TestSqlStore_DeletePeer(t *testing.T) { - store, cleanup, err 
:= NewTestStoreFromSQL(context.Background(), "testdata/store_with_expired_peers.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -2141,7 +2805,7 @@ func TestSqlStore_DeletePeer(t *testing.T) { } func TestSqlStore_GetAccountCreatedBy(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -2189,7 +2853,7 @@ func TestSqlStore_GetAccountCreatedBy(t *testing.T) { } func TestSqlStore_GetUserByUserID(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -2234,7 +2898,7 @@ func TestSqlStore_GetUserByUserID(t *testing.T) { } func TestSqlStore_GetUserByPATID(t *testing.T) { - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -2246,7 +2910,7 @@ func TestSqlStore_GetUserByPATID(t *testing.T) { } func TestSqlStore_SaveUser(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -2280,7 +2944,7 @@ func TestSqlStore_SaveUser(t *testing.T) { } func TestSqlStore_SaveUsers(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -2313,7 +2977,7 @@ func TestSqlStore_SaveUsers(t *testing.T) { } func TestSqlStore_DeleteUser(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -2333,7 +2997,7 @@ func TestSqlStore_DeleteUser(t *testing.T) { } func TestSqlStore_GetPATByID(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -2380,7 +3044,7 @@ func TestSqlStore_GetPATByID(t *testing.T) { } func TestSqlStore_GetUserPATs(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -2390,7 +3054,7 @@ func TestSqlStore_GetUserPATs(t *testing.T) { } func TestSqlStore_GetPATByHashedToken(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + 
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -2400,7 +3064,7 @@ func TestSqlStore_GetPATByHashedToken(t *testing.T) { } func TestSqlStore_MarkPATUsed(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -2417,7 +3081,7 @@ func TestSqlStore_MarkPATUsed(t *testing.T) { } func TestSqlStore_SavePAT(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -2448,7 +3112,7 @@ func TestSqlStore_SavePAT(t *testing.T) { } func TestSqlStore_DeletePAT(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) diff --git a/management/server/store.go b/management/server/store/store.go similarity index 68% rename from management/server/store.go rename to management/server/store/store.go index 01a4955c1..29ed22fa5 100644 --- a/management/server/store.go +++ b/management/server/store/store.go @@ -1,4 +1,4 @@ -package server +package store import ( "context" @@ -18,16 +18,18 @@ import ( "gorm.io/gorm" "github.com/netbirdio/netbird/dns" - - nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/testutil" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/management/server/migration" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/management/server/testutil" "github.com/netbirdio/netbird/route" ) @@ -41,59 +43,60 @@ const ( ) type Store interface { - GetAllAccounts(ctx context.Context) []*Account - GetAccount(ctx context.Context, accountID string) (*Account, error) + GetAllAccounts(ctx context.Context) []*types.Account + GetAccount(ctx context.Context, accountID string) (*types.Account, error) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) - GetAccountByUser(ctx context.Context, userID string) (*Account, error) - GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) + GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) + GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*types.Account, error) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, 
error) GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error) GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error) - GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) - GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later - GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) + GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) + GetAccountBySetupKey(ctx context.Context, setupKey string) (*types.Account, error) // todo use key hash later + GetAccountByPrivateDomain(ctx context.Context, domain string) (*types.Account, error) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) - GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) - GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) + GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error) + GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.DNSSettings, error) GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error) - SaveAccount(ctx context.Context, account *Account) error - DeleteAccount(ctx context.Context, account *Account) error + SaveAccount(ctx context.Context, account *types.Account) error + DeleteAccount(ctx context.Context, account *types.Account) error UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error - SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error + SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error - GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*User, error) - GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) - GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) - SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*User) error - SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error + GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error) + GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) + GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) + SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*types.User) error + SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error DeleteUser(ctx context.Context, lockStrength LockingStrength, accountID, userID string) error GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) DeleteHashedPAT2TokenIDIndex(hashedToken string) error DeleteTokenID2UserIDIndex(tokenID string) error - GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*PersonalAccessToken, error) - GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) 
([]*PersonalAccessToken, error) - GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*PersonalAccessToken, error) + GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*types.PersonalAccessToken, error) + GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error) + GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error) MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error - SavePAT(ctx context.Context, strength LockingStrength, pat *PersonalAccessToken) error + SavePAT(ctx context.Context, strength LockingStrength, pat *types.PersonalAccessToken) error DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error - GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) - GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) - GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) - GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) - SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error - SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error + GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) + GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error) + GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error) + GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types.Group, error) + GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error) + SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*types.Group) error + SaveGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error - GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) - GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error) - CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error - SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error + GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Policy, error) + GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*types.Policy, error) + CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error + SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) 
(*posture.Checks, error) @@ -104,9 +107,12 @@ type Store interface { DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) - AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error - AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error - AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error + AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error + AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error + GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error) + AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error + RemoveResourceFromGroup(ctx context.Context, accountId string, groupID string, resourceID string) error + AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) @@ -120,11 +126,11 @@ type Store interface { SavePeerLocation(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error DeletePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error - GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) + GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error - GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) - GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error) - SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error + GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.SetupKey, error) + GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types.SetupKey, error) + SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *types.SetupKey) error DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) @@ -137,7 +143,7 @@ type Store interface { GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error - GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error) + GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*types.Network, error) GetInstallationID() string SaveInstallationID(ctx context.Context, ID string) error @@ -151,31 +157,51 @@ type Store interface { // Close should close the 
store persisting all unsaved data. Close(ctx context.Context) error - // GetStoreEngine should return StoreEngine of the current store implementation. + // GetStoreEngine should return Engine of the current store implementation. // This is also a method of metrics.DataSource interface. - GetStoreEngine() StoreEngine + GetStoreEngine() Engine ExecuteInTransaction(ctx context.Context, f func(store Store) error) error + + GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error) + GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error) + SaveNetwork(ctx context.Context, lockStrength LockingStrength, network *networkTypes.Network) error + DeleteNetwork(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) error + + GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*routerTypes.NetworkRouter, error) + GetNetworkRoutersByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error) + GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*routerTypes.NetworkRouter, error) + SaveNetworkRouter(ctx context.Context, lockStrength LockingStrength, router *routerTypes.NetworkRouter) error + DeleteNetworkRouter(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) error + + GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*resourceTypes.NetworkResource, error) + GetNetworkResourcesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*resourceTypes.NetworkResource, error) + GetNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*resourceTypes.NetworkResource, error) + GetNetworkResourceByName(ctx context.Context, lockStrength LockingStrength, accountID, resourceName string) (*resourceTypes.NetworkResource, error) + SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *resourceTypes.NetworkResource) error + DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error } -type StoreEngine string +type Engine string const ( - FileStoreEngine StoreEngine = "jsonfile" - SqliteStoreEngine StoreEngine = "sqlite" - PostgresStoreEngine StoreEngine = "postgres" + FileStoreEngine Engine = "jsonfile" + SqliteStoreEngine Engine = "sqlite" + PostgresStoreEngine Engine = "postgres" + MysqlStoreEngine Engine = "mysql" postgresDsnEnv = "NETBIRD_STORE_ENGINE_POSTGRES_DSN" + mysqlDsnEnv = "NETBIRD_STORE_ENGINE_MYSQL_DSN" ) -func getStoreEngineFromEnv() StoreEngine { +func getStoreEngineFromEnv() Engine { // NETBIRD_STORE_ENGINE supposed to be used in tests. Otherwise, rely on the config file. kind, ok := os.LookupEnv("NETBIRD_STORE_ENGINE") if !ok { return "" } - value := StoreEngine(strings.ToLower(kind)) - if value == SqliteStoreEngine || value == PostgresStoreEngine { + value := Engine(strings.ToLower(kind)) + if value == SqliteStoreEngine || value == PostgresStoreEngine || value == MysqlStoreEngine { return value } @@ -186,7 +212,7 @@ func getStoreEngineFromEnv() StoreEngine { // If no engine is specified, it attempts to retrieve it from the environment. // If still not specified, it defaults to using SQLite. 
// Additionally, it handles the migration from a JSON store file to SQLite if applicable. -func getStoreEngine(ctx context.Context, dataDir string, kind StoreEngine) StoreEngine { +func getStoreEngine(ctx context.Context, dataDir string, kind Engine) Engine { if kind == "" { kind = getStoreEngineFromEnv() if kind == "" { @@ -212,7 +238,7 @@ func getStoreEngine(ctx context.Context, dataDir string, kind StoreEngine) Store } // NewStore creates a new store based on the provided engine type, data directory, and telemetry metrics -func NewStore(ctx context.Context, kind StoreEngine, dataDir string, metrics telemetry.AppMetrics) (Store, error) { +func NewStore(ctx context.Context, kind Engine, dataDir string, metrics telemetry.AppMetrics) (Store, error) { kind = getStoreEngine(ctx, dataDir, kind) if err := checkFileStoreEngine(kind, dataDir); err != nil { @@ -226,12 +252,15 @@ func NewStore(ctx context.Context, kind StoreEngine, dataDir string, metrics tel case PostgresStoreEngine: log.WithContext(ctx).Info("using Postgres store engine") return newPostgresStore(ctx, metrics) + case MysqlStoreEngine: + log.WithContext(ctx).Info("using MySQL store engine") + return newMysqlStore(ctx, metrics) default: return nil, fmt.Errorf("unsupported kind of store: %s", kind) } } -func checkFileStoreEngine(kind StoreEngine, dataDir string) error { +func checkFileStoreEngine(kind Engine, dataDir string) error { if kind == FileStoreEngine { storeFile := filepath.Join(dataDir, storeFileName) if util.FileExists(storeFile) { @@ -258,7 +287,7 @@ func migrate(ctx context.Context, db *gorm.DB) error { func getMigrations(ctx context.Context) []migrationFunc { return []migrationFunc{ func(db *gorm.DB) error { - return migration.MigrateFieldFromGobToJSON[Account, net.IPNet](ctx, db, "network_net") + return migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](ctx, db, "network_net") }, func(db *gorm.DB) error { return migration.MigrateFieldFromGobToJSON[route.Route, netip.Prefix](ctx, db, "network") @@ -273,7 +302,13 @@ func getMigrations(ctx context.Context) []migrationFunc { return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](ctx, db, "ip", "idx_peers_account_id_ip") }, func(db *gorm.DB) error { - return migration.MigrateSetupKeyToHashedSetupKey[SetupKey](ctx, db) + return migration.MigrateSetupKeyToHashedSetupKey[types.SetupKey](ctx, db) + }, + func(db *gorm.DB) error { + return migration.MigrateNewField[resourceTypes.NetworkResource](ctx, db, "enabled", true) + }, + func(db *gorm.DB) error { + return migration.MigrateNewField[routerTypes.NetworkRouter](ctx, db, "enabled", true) }, } } @@ -309,12 +344,13 @@ func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) ( if err != nil { return nil, nil, fmt.Errorf("failed to create test store: %v", err) } - cleanUp := func() { - store.Close(ctx) - } + return getSqlStoreEngine(ctx, store, kind) +} + +func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store, func(), error) { if kind == PostgresStoreEngine { - cleanUp, err = testutil.CreatePGDB() + cleanUp, err := testutil.CreatePostgresTestContainer() if err != nil { return nil, nil, err } @@ -328,9 +364,34 @@ func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) ( if err != nil { return nil, nil, err } + + return store, cleanUp, nil } - return store, cleanUp, nil + if kind == MysqlStoreEngine { + cleanUp, err := testutil.CreateMysqlTestContainer() + if err != nil { + return nil, nil, err + } + + dsn, ok := os.LookupEnv(mysqlDsnEnv) + if 
!ok { + return nil, nil, fmt.Errorf("%s is not set", mysqlDsnEnv) + } + + store, err = NewMysqlStoreFromSqlStore(ctx, store, dsn, nil) + if err != nil { + return nil, nil, err + } + + return store, cleanUp, nil + } + + closeConnection := func() { + store.Close(ctx) + } + + return store, closeConnection, nil } func loadSQL(db *gorm.DB, filepath string) error { diff --git a/management/server/store_test.go b/management/server/store/store_test.go similarity index 93% rename from management/server/store_test.go rename to management/server/store/store_test.go index fc821670d..1d0026e3d 100644 --- a/management/server/store_test.go +++ b/management/server/store/store_test.go @@ -1,4 +1,4 @@ -package server +package store import ( "context" @@ -76,11 +76,3 @@ func BenchmarkTest_StoreRead(b *testing.B) { }) } } - -func newStore(t *testing.T) Store { - t.Helper() - - store := newSqliteStore(t) - - return store -} diff --git a/management/server/telemetry/accountmanager_metrics.go b/management/server/telemetry/accountmanager_metrics.go index e4bb4e3c3..4a5a31e2d 100644 --- a/management/server/telemetry/accountmanager_metrics.go +++ b/management/server/telemetry/accountmanager_metrics.go @@ -13,6 +13,7 @@ type AccountManagerMetrics struct { updateAccountPeersDurationMs metric.Float64Histogram getPeerNetworkMapDurationMs metric.Float64Histogram networkMapObjectCount metric.Int64Histogram + peerMetaUpdateCount metric.Int64Counter } // NewAccountManagerMetrics creates an instance of AccountManagerMetrics @@ -44,11 +45,17 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account return nil, err } + peerMetaUpdateCount, err := meter.Int64Counter("management.account.peer.meta.update.counter", metric.WithUnit("1")) + if err != nil { + return nil, err + } + return &AccountManagerMetrics{ ctx: ctx, getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs, updateAccountPeersDurationMs: updateAccountPeersDurationMs, networkMapObjectCount: networkMapObjectCount, + peerMetaUpdateCount: peerMetaUpdateCount, }, nil } @@ -67,3 +74,8 @@ func (metrics *AccountManagerMetrics) CountGetPeerNetworkMapDuration(duration ti func (metrics *AccountManagerMetrics) CountNetworkMapObjects(count int64) { metrics.networkMapObjectCount.Record(metrics.ctx, count) } + +// CountPeerMetUpdate counts the number of peer meta updates +func (metrics *AccountManagerMetrics) CountPeerMetUpdate() { + metrics.peerMetaUpdateCount.Add(metrics.ctx, 1) +} diff --git a/management/server/telemetry/store_metrics.go b/management/server/telemetry/store_metrics.go index b038c3d36..bb3745b5a 100644 --- a/management/server/telemetry/store_metrics.go +++ b/management/server/telemetry/store_metrics.go @@ -13,6 +13,7 @@ type StoreMetrics struct { globalLockAcquisitionDurationMs metric.Int64Histogram persistenceDurationMicro metric.Int64Histogram persistenceDurationMs metric.Int64Histogram + transactionDurationMs metric.Int64Histogram ctx context.Context } @@ -40,11 +41,17 @@ func NewStoreMetrics(ctx context.Context, meter metric.Meter) (*StoreMetrics, er return nil, err } + transactionDurationMs, err := meter.Int64Histogram("management.store.transaction.duration.ms") + if err != nil { + return nil, err + } + return &StoreMetrics{ globalLockAcquisitionDurationMicro: globalLockAcquisitionDurationMicro, globalLockAcquisitionDurationMs: globalLockAcquisitionDurationMs, persistenceDurationMicro: persistenceDurationMicro, persistenceDurationMs: persistenceDurationMs, + transactionDurationMs: transactionDurationMs, ctx: ctx, }, nil } @@ -60,3 
+67,8 @@ func (metrics *StoreMetrics) CountPersistenceDuration(duration time.Duration) { metrics.persistenceDurationMicro.Record(metrics.ctx, duration.Microseconds()) metrics.persistenceDurationMs.Record(metrics.ctx, duration.Milliseconds()) } + +// CountTransactionDuration counts the duration of a store transaction +func (metrics *StoreMetrics) CountTransactionDuration(duration time.Duration) { + metrics.transactionDurationMs.Record(metrics.ctx, duration.Milliseconds()) +} diff --git a/management/server/testdata/extended-store.sql b/management/server/testdata/extended-store.sql index 455111439..2859e82c8 100644 --- a/management/server/testdata/extended-store.sql +++ b/management/server/testdata/extended-store.sql @@ -1,7 +1,7 @@ CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); -CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); -CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id`
integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); @@ -26,10 +26,10 @@ CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`accoun CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:01:38.210014+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); -INSERT INTO setup_keys VALUES('A2C8E62B-38F5-4553-B31E-DD66C696CEBB','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','["cfefqs706sqkneg59g2g"]',0,0); -INSERT INTO setup_keys VALUES('A2C8E62B-38F5-4553-B31E-DD66C696CEBC','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBC','Faulty key with non existing group','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','["abcd"]',0,0); -INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','["cfefqs706sqkneg59g3g"]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.210678+02:00','api',0,''); -INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.210678+02:00','api',0,''); +INSERT INTO setup_keys VALUES('A2C8E62B-38F5-4553-B31E-DD66C696CEBB','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,NULL,'["cfefqs706sqkneg59g2g"]',0,0); +INSERT INTO setup_keys VALUES('A2C8E62B-38F5-4553-B31E-DD66C696CEBC','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBC','Faulty key with non existing group','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 
20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,NULL,'["abcd"]',0,0); +INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','["cfefqs706sqkneg59g3g"]',0,NULL,'2024-10-02 16:01:38.210678+02:00','api',0,''); +INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.210678+02:00','api',0,''); INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003','f4f6d672-63fb-11ec-90d6-0242ac120003','','SoMeHaShEdToKeN','2023-02-27 00:00:00+00:00','user','2023-01-01 00:00:00+00:00','2023-02-01 00:00:00+00:00'); INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,''); INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup1','api','[]',0,''); diff --git a/management/server/testdata/networks.sql b/management/server/testdata/networks.sql new file mode 100644 index 000000000..8138ce520 --- /dev/null +++ b/management/server/testdata/networks.sql @@ -0,0 +1,18 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); + +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','[]',0,''); + +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +INSERT INTO peers 
VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); + +CREATE TABLE `networks` (`id` text,`account_id` text,`name` text,`description` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_networks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +INSERT INTO networks VALUES('testNetworkId','testAccountId','some-name','some-description'); + +CREATE TABLE `network_routers` (`id` text,`network_id` text,`account_id` text,`peer` text,`peer_groups` text,`masquerade` numeric,`metric` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_network_routers` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +INSERT INTO network_routers VALUES('testRouterId','testNetworkId','testAccountId','','["csquuo4jcko732k1ag00"]',0,9999); + +CREATE TABLE `network_resources` (`id` text,`network_id` text,`account_id` text,`name` text,`description` text,`type` text,`address` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_network_resources` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +INSERT INTO network_resources VALUES('testResourceId','testNetworkId','testAccountId','some-name','some-description','host','3.3.3.3/32'); +INSERT INTO network_resources VALUES('anotherTestResourceId','testNetworkId','testAccountId','used-name','some-description','host','3.3.3.3/32'); diff --git a/management/server/testdata/store.sql b/management/server/testdata/store.sql index 2c55e2e31..84524127f 100644 --- a/management/server/testdata/store.sql +++ b/management/server/testdata/store.sql @@ -1,8 +1,8 @@ CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); -CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` 
text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); -CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); +CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime DEFAULT NULL,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE); @@ -12,6 +12,9 @@ CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text); CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `network_addresses` (`net_ip` text,`mac` text); +CREATE TABLE `network_routers` (`id` text,`network_id` text,`account_id` text,`peer` text,`peer_groups` text,`masquerade` numeric,`metric` integer,PRIMARY KEY (`id`)); +CREATE TABLE `network_resources` (`id` text,`network_id` text,`account_id` text,`type` text,`address` 
text,PRIMARY KEY (`id`)); +CREATE TABLE `networks` (`id` text,`account_id` text,`name` text,`description` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_networks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`); CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`); CREATE INDEX `idx_peers_key` ON `peers`(`key`); @@ -24,13 +27,24 @@ CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`); CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`); CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`); CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); +CREATE INDEX `idx_network_routers_id` ON `network_routers`(`id`); +CREATE INDEX `idx_network_routers_account_id` ON `network_routers`(`account_id`); +CREATE INDEX `idx_network_routers_network_id` ON `network_routers`(`network_id`); +CREATE INDEX `idx_network_resources_account_id` ON `network_resources`(`account_id`); +CREATE INDEX `idx_network_resources_network_id` ON `network_resources`(`network_id`); +CREATE INDEX `idx_network_resources_id` ON `network_resources`(`id`); +CREATE INDEX `idx_networks_id` ON `networks`(`id`); +CREATE INDEX `idx_networks_account_id` ON `networks`(`account_id`); INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','edafee4e-63fb-11ec-90d6-0242ac120003','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,''); -INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','["cs1tnh0hhcjnqoiuebeg"]',0,0); -INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:03:06.779156+02:00','api',0,''); -INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:03:06.779156+02:00','api',0,''); +INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,NULL,'["cs1tnh0hhcjnqoiuebeg"]',0,0); +INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 16:03:06.779156+02:00','api',0,''); +INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 16:03:06.779156+02:00','api',0,''); INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003','f4f6d672-63fb-11ec-90d6-0242ac120003','','SoMeHaShEdToKeN','2023-02-27 00:00:00+00:00','user','2023-01-01 00:00:00+00:00','2023-02-01 00:00:00+00:00'); INSERT INTO installations VALUES(1,''); INSERT INTO policies VALUES('cs1tnh0hhcjnqoiuebf0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Default','This is a default rule that allows connections between all the 
resources',1,'[]'); INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','Default','This is a default rule that allows connections between all the resources',1,'accept','["cs1tnh0hhcjnqoiuebeg"]','["cs1tnh0hhcjnqoiuebeg"]',1,'all',NULL,NULL); +INSERT INTO network_routers VALUES('ctc20ji7qv9ck2sebc80','ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','cs1tnh0hhcjnqoiuebeg',NULL,0,0); +INSERT INTO network_resources VALUES ('ctc4nci7qv9061u6ilfg','ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Host','192.168.1.1'); +INSERT INTO networks VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Test Network','Test Network'); diff --git a/management/server/testdata/store_policy_migrate.sql b/management/server/testdata/store_policy_migrate.sql index a9360e9d6..a88411795 100644 --- a/management/server/testdata/store_policy_migrate.sql +++ b/management/server/testdata/store_policy_migrate.sql @@ -1,7 +1,7 @@ CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); -CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); -CREATE TABLE `users` (`id` text,`account_id` text,`role` 
text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); @@ -26,10 +26,11 @@ CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`accoun CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:04:23.538411+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); -INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0); +INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,NULL,'[]',0,0); INSERT INTO peers VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','MI5mHfJhbggPfD3FqEIsXm8X5bSWeUI2LhO9MpEEtWA=','','"100.103.179.238"','Ubuntu-2204-jammy-amd64-base','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'crocodile','crocodile','2023-02-13 12:37:12.635454796+00:00',1,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','AAAAC3NzaC1lZDI1NTE5AAAAIJN1NM4bpB9K',0,0,'2024-10-02 14:04:23.523293+00:00','2024-10-02 16:04:23.538926+02:00',0,'""','','',0); INSERT INTO peers VALUES('cfeg6sf06sqkneg59g50','bf1c8084-ba50-4ce7-9439-34653001fc3b','zMAOKUeIYIuun4n0xPR1b3IdYZPmsyjYmB2jWCuloC4=','','"100.103.26.180"','borg','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'dingo','dingo','2023-02-21 
09:37:42.565899199+00:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','AAAAC3NzaC1lZDI1NTE5AAAAILHW',1,0,'2024-10-02 14:04:23.523293+00:00','2024-10-02 16:04:23.538926+02:00',0,'""','','',0); -INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:04:23.539152+02:00','api',0,''); -INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:04:23.539152+02:00','api',0,''); +INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 16:04:23.539152+02:00','api',0,''); +INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 16:04:23.539152+02:00','api',0,''); INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','["cfefqs706sqkneg59g4g","cfeg6sf06sqkneg59g50"]',0,''); +INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4h','bf1c8084-ba50-4ce7-9439-34653001fc3b','groupA','api','',0,''); INSERT INTO installations VALUES(1,''); diff --git a/management/server/testdata/store_with_expired_peers.sql b/management/server/testdata/store_with_expired_peers.sql index 54b946b5a..5990a0625 100644 --- a/management/server/testdata/store_with_expired_peers.sql +++ b/management/server/testdata/store_with_expired_peers.sql @@ -1,7 +1,7 @@ CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); -CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` 
text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`inactivity_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); -CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); @@ -26,11 +26,11 @@ CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`accoun CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 17:00:32.527528+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,3600000000000,0,0,0,'',NULL,NULL,NULL); -INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0); +INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,NULL,'[]',0,0); INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-34653001fc3b','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian 
GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,0,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,0,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); -INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,''); -INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,''); +INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,''); +INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,''); INSERT INTO installations VALUES(1,''); diff --git a/management/server/testdata/storev1.sql b/management/server/testdata/storev1.sql index 281fdac8a..cda333d4f 100644 --- a/management/server/testdata/storev1.sql +++ b/management/server/testdata/storev1.sql @@ -1,6 +1,6 @@ CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` 
numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); -CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY 
(`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); @@ -31,9 +31,9 @@ INSERT INTO setup_keys VALUES('831727121','auth0|61bf82ddeab084006aa1bccd','1B2B INSERT INTO setup_keys VALUES('1769568301','auth0|61bf82ddeab084006aa1bccd','EB51E9EB-A11F-4F6E-8E49-C982891B405A','Default key','reusable','2021-12-24 16:09:45.926073628+01:00','2022-01-23 16:09:45.926073628+01:00','2021-12-24 16:09:45.926073628+01:00',0,1,'2021-12-24 16:13:06.236748538+01:00','[]',0,0); INSERT INTO setup_keys VALUES('2485964613','google-oauth2|103201118415301331038','5AFB60DB-61F2-4251-8E11-494847EE88E9','Default key','reusable','2021-12-24 16:10:02.238476+01:00','2022-01-23 16:10:02.238476+01:00','2021-12-24 16:10:02.238476+01:00',0,1,'2021-12-24 16:12:05.994307717+01:00','[]',0,0); INSERT INTO setup_keys VALUES('3504804807','google-oauth2|103201118415301331038','A72E4DC2-00DE-4542-8A24-62945438104E','One-off key','one-off','2021-12-24 16:10:02.238478209+01:00','2022-01-23 16:10:02.238478209+01:00','2021-12-24 16:10:02.238478209+01:00',0,1,'2021-12-24 16:11:27.015741738+01:00','[]',0,0); -INSERT INTO peers VALUES('oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=','auth0|61bf82ddeab084006aa1bccd','oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=','EB51E9EB-A11F-4F6E-8E49-C982891B405A','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:13:11.244342541+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.182618+02:00',0,'""','','',0); -INSERT INTO peers VALUES('xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','auth0|61bf82ddeab084006aa1bccd','xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','1B2B50B0-B3E8-4B0C-A426-525EDB8481BD','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:12:49.089339333+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.182618+02:00',0,'""','','',0); -INSERT INTO peers VALUES('6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','google-oauth2|103201118415301331038','6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','5AFB60DB-61F2-4251-8E11-494847EE88E9','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:12:05.994305438+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.228182+02:00',0,'""','','',0); -INSERT INTO peers VALUES('Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','google-oauth2|103201118415301331038','Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','A72E4DC2-00DE-4542-8A24-62945438104E','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:11:27.015739803+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.228182+02:00',1,'""','','',0); +INSERT INTO peers VALUES('oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=','auth0|61bf82ddeab084006aa1bccd','oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=','EB51E9EB-A11F-4F6E-8E49-C982891B405A','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:13:11.244342541+01:00',0,0,0,'','',0,0,NULL,'2024-10-02 17:00:54.182618+02:00',0,'""','','',0); +INSERT INTO peers 
VALUES('xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','auth0|61bf82ddeab084006aa1bccd','xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','1B2B50B0-B3E8-4B0C-A426-525EDB8481BD','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:12:49.089339333+01:00',0,0,0,'','',0,0,NULL,'2024-10-02 17:00:54.182618+02:00',0,'""','','',0); +INSERT INTO peers VALUES('6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','google-oauth2|103201118415301331038','6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','5AFB60DB-61F2-4251-8E11-494847EE88E9','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:12:05.994305438+01:00',0,0,0,'','',0,0,NULL,'2024-10-02 17:00:54.228182+02:00',0,'""','','',0); +INSERT INTO peers VALUES('Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','google-oauth2|103201118415301331038','Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','A72E4DC2-00DE-4542-8A24-62945438104E','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:11:27.015739803+01:00',0,0,0,'','',0,0,NULL,'2024-10-02 17:00:54.228182+02:00',1,'""','','',0); INSERT INTO installations VALUES(1,''); diff --git a/management/server/testutil/store.go b/management/server/testutil/store.go index 156a762fb..16438cab8 100644 --- a/management/server/testutil/store.go +++ b/management/server/testutil/store.go @@ -10,36 +10,75 @@ import ( log "github.com/sirupsen/logrus" "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/modules/mysql" "github.com/testcontainers/testcontainers-go/modules/postgres" "github.com/testcontainers/testcontainers-go/wait" ) -func CreatePGDB() (func(), error) { +// CreateMysqlTestContainer creates a new MySQL container for testing. +func CreateMysqlTestContainer() (func(), error) { ctx := context.Background() - c, err := postgres.RunContainer(ctx, - testcontainers.WithImage("postgres:alpine"), - postgres.WithDatabase("test"), - postgres.WithUsername("postgres"), - postgres.WithPassword("postgres"), + + myContainer, err := mysql.RunContainer(ctx, + testcontainers.WithImage("mlsmaycon/warmed-mysql:8"), + mysql.WithDatabase("testing"), + mysql.WithUsername("testing"), + mysql.WithPassword("testing"), testcontainers.WithWaitStrategy( - wait.ForLog("database system is ready to accept connections"). - WithOccurrence(2).WithStartupTimeout(15*time.Second)), + wait.ForLog("/usr/sbin/mysqld: ready for connections"). 
+ WithOccurrence(1).WithStartupTimeout(15*time.Second).WithPollInterval(100*time.Millisecond), + ), ) if err != nil { return nil, err } cleanup := func() { - timeout := 10 * time.Second - err = c.Stop(ctx, &timeout) - if err != nil { - log.WithContext(ctx).Warnf("failed to stop container: %s", err) + timeoutCtx, cancelFunc := context.WithTimeout(ctx, 1*time.Second) + defer cancelFunc() + if err = myContainer.Terminate(timeoutCtx); err != nil { + log.WithContext(ctx).Warnf("failed to stop mysql container %s: %s", myContainer.GetContainerID(), err) } } - talksConn, err := c.ConnectionString(ctx) + talksConn, err := myContainer.ConnectionString(ctx) if err != nil { - return cleanup, err + return nil, err } + + return cleanup, os.Setenv("NETBIRD_STORE_ENGINE_MYSQL_DSN", talksConn) +} + +// CreatePostgresTestContainer creates a new PostgreSQL container for testing. +func CreatePostgresTestContainer() (func(), error) { + ctx := context.Background() + + pgContainer, err := postgres.RunContainer(ctx, + testcontainers.WithImage("postgres:16-alpine"), + postgres.WithDatabase("netbird"), + postgres.WithUsername("root"), + postgres.WithPassword("netbird"), + testcontainers.WithWaitStrategy( + wait.ForLog("database system is ready to accept connections"). + WithOccurrence(2).WithStartupTimeout(15*time.Second), + ), + ) + if err != nil { + return nil, err + } + + cleanup := func() { + timeoutCtx, cancelFunc := context.WithTimeout(ctx, 1*time.Second) + defer cancelFunc() + if err = pgContainer.Terminate(timeoutCtx); err != nil { + log.WithContext(ctx).Warnf("failed to stop postgres container %s: %s", pgContainer.GetContainerID(), err) + } + } + + talksConn, err := pgContainer.ConnectionString(ctx) + if err != nil { + return nil, err + } + return cleanup, os.Setenv("NETBIRD_STORE_ENGINE_POSTGRES_DSN", talksConn) } diff --git a/management/server/testutil/store_ios.go b/management/server/testutil/store_ios.go index af2cf7a3f..edde62f1e 100644 --- a/management/server/testutil/store_ios.go +++ b/management/server/testutil/store_ios.go @@ -3,4 +3,14 @@ package testutil -func CreatePGDB() (func(), error) { return func() {}, nil } +func CreatePostgresTestContainer() (func(), error) { + return func() { + // Empty function for Postgres + }, nil +} + +func CreateMysqlTestContainer() (func(), error) { + return func() { + // Empty function for MySQL + }, nil +} diff --git a/management/server/token_mgr.go b/management/server/token_mgr.go index ef8276b59..fd67fa3e3 100644 --- a/management/server/token_mgr.go +++ b/management/server/token_mgr.go @@ -158,7 +158,7 @@ func (m *TimeBasedAuthSecretsManager) refreshTURNTokens(ctx context.Context, pee log.WithContext(ctx).Debugf("stopping TURN refresh for %s", peerID) return case <-ticker.C: - m.pushNewTURNTokens(ctx, peerID) + m.pushNewTURNAndRelayTokens(ctx, peerID) } } } @@ -178,7 +178,7 @@ func (m *TimeBasedAuthSecretsManager) refreshRelayTokens(ctx context.Context, pe } } -func (m *TimeBasedAuthSecretsManager) pushNewTURNTokens(ctx context.Context, peerID string) { +func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Context, peerID string) { turnToken, err := m.turnHmacToken.GenerateToken(sha1.New) if err != nil { log.Errorf("failed to generate token for peer '%s': %s", peerID, err) @@ -201,10 +201,21 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNTokens(ctx context.Context, pee update := &proto.SyncResponse{ WiretrusteeConfig: &proto.WiretrusteeConfig{ Turns: turns, - // omit Relay to avoid updates there }, } + // workaround for the case 
when client is unable to handle turn and relay updates at different time + if m.relayCfg != nil { + token, err := m.GenerateRelayToken() + if err == nil { + update.WiretrusteeConfig.Relay = &proto.RelayConfig{ + Urls: m.relayCfg.Addresses, + TokenPayload: token.Payload, + TokenSignature: token.Signature, + } + } + } + log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID) m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update}) } diff --git a/management/server/token_mgr_test.go b/management/server/token_mgr_test.go index 3e63346c2..2aafb9f68 100644 --- a/management/server/token_mgr_test.go +++ b/management/server/token_mgr_test.go @@ -133,11 +133,14 @@ loop: } } if relay := update.Update.GetWiretrusteeConfig().GetRelay(); relay != nil { - relayUpdates++ - if relayUpdates == 1 { - firstRelayUpdate = relay - } else { - secondRelayUpdate = relay + // avoid updating on turn updates since they also send relay credentials + if update.Update.GetWiretrusteeConfig().GetTurns() == nil { + relayUpdates++ + if relayUpdates == 1 { + firstRelayUpdate = relay + } else { + secondRelayUpdate = relay + } } } } diff --git a/management/server/types/account.go b/management/server/types/account.go new file mode 100644 index 000000000..f74d38cb6 --- /dev/null +++ b/management/server/types/account.go @@ -0,0 +1,1512 @@ +package types + +import ( + "context" + "fmt" + "net" + "net/netip" + "slices" + "strconv" + "strings" + "time" + + "github.com/hashicorp/go-multierror" + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/domain" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/util" + "github.com/netbirdio/netbird/route" +) + +const ( + defaultTTL = 300 + DefaultPeerLoginExpiration = 24 * time.Hour + DefaultPeerInactivityExpiration = 10 * time.Minute + + PublicCategory = "public" + PrivateCategory = "private" + UnknownCategory = "unknown" +) + +type LookupMap map[string]struct{} + +// Account represents a unique account of the system +type Account struct { + // we have to name column to aid as it collides with Network.Id when work with associations + Id string `gorm:"primaryKey"` + + // User.Id it was created by + CreatedBy string + CreatedAt time.Time + Domain string `gorm:"index"` + DomainCategory string + IsDomainPrimaryAccount bool + SetupKeys map[string]*SetupKey `gorm:"-"` + SetupKeysG []SetupKey `json:"-" gorm:"foreignKey:AccountID;references:id"` + Network *Network `gorm:"embedded;embeddedPrefix:network_"` + Peers map[string]*nbpeer.Peer `gorm:"-"` + PeersG []nbpeer.Peer `json:"-" gorm:"foreignKey:AccountID;references:id"` + Users map[string]*User `gorm:"-"` + UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"` + Groups map[string]*Group `gorm:"-"` + GroupsG []Group `json:"-" gorm:"foreignKey:AccountID;references:id"` + Policies []*Policy `gorm:"foreignKey:AccountID;references:id"` + Routes map[route.ID]*route.Route `gorm:"-"` + RoutesG []route.Route `json:"-" 
gorm:"foreignKey:AccountID;references:id"` + NameServerGroups map[string]*nbdns.NameServerGroup `gorm:"-"` + NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"` + DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"` + PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"` + // Settings is a dictionary of Account settings + Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` + + Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"` + NetworkRouters []*routerTypes.NetworkRouter `gorm:"foreignKey:AccountID;references:id"` + NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"` +} + +// Subclass used in gorm to only load network and not whole account +type AccountNetwork struct { + Network *Network `gorm:"embedded;embeddedPrefix:network_"` +} + +// AccountDNSSettings used in gorm to only load dns settings and not whole account +type AccountDNSSettings struct { + DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"` +} + +// Subclass used in gorm to only load settings and not whole account +type AccountSettings struct { + Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` +} + +// GetRoutesToSync returns the enabled routes for the peer ID and the routes +// from the ACL peers that have distribution groups associated with the peer ID. +// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. +func (a *Account) GetRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer) []*route.Route { + routes, peerDisabledRoutes := a.getRoutingPeerRoutes(ctx, peerID) + peerRoutesMembership := make(LookupMap) + for _, r := range append(routes, peerDisabledRoutes...) { + peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{} + } + + groupListMap := a.GetPeerGroups(peerID) + for _, peer := range aclPeers { + activeRoutes, _ := a.getRoutingPeerRoutes(ctx, peer.ID) + groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, groupListMap) + filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership) + routes = append(routes, filteredRoutes...) + } + + return routes +} + +// filterRoutesFromPeersOfSameHAGroup filters and returns a list of routes that don't share the same HA route membership +func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships LookupMap) []*route.Route { + var filteredRoutes []*route.Route + for _, r := range routes { + _, found := peerMemberships[string(r.GetHAUniqueID())] + if !found { + filteredRoutes = append(filteredRoutes, r) + } + } + return filteredRoutes +} + +// filterRoutesByGroups returns a list with routes that have distribution groups in the group's map +func (a *Account) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route { + var filteredRoutes []*route.Route + for _, r := range routes { + for _, groupID := range r.Groups { + _, found := groupListMap[groupID] + if found { + filteredRoutes = append(filteredRoutes, r) + break + } + } + } + return filteredRoutes +} + +// getRoutingPeerRoutes returns the enabled and disabled lists of routes that the given routing peer serves +// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. +// If the given is not a routing peer, then the lists are empty. 
+func (a *Account) getRoutingPeerRoutes(ctx context.Context, peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) { + + peer := a.GetPeer(peerID) + if peer == nil { + log.WithContext(ctx).Errorf("peer %s that doesn't exist under account %s", peerID, a.Id) + return enabledRoutes, disabledRoutes + } + + // currently we support only linux routing peers + if peer.Meta.GoOS != "linux" { + return enabledRoutes, disabledRoutes + } + + seenRoute := make(map[route.ID]struct{}) + + takeRoute := func(r *route.Route, id string) { + if _, ok := seenRoute[r.ID]; ok { + return + } + seenRoute[r.ID] = struct{}{} + + if r.Enabled { + r.Peer = peer.Key + enabledRoutes = append(enabledRoutes, r) + return + } + disabledRoutes = append(disabledRoutes, r) + } + + for _, r := range a.Routes { + for _, groupID := range r.PeerGroups { + group := a.GetGroup(groupID) + if group == nil { + log.WithContext(ctx).Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id) + continue + } + for _, id := range group.Peers { + if id != peerID { + continue + } + + newPeerRoute := r.Copy() + newPeerRoute.Peer = id + newPeerRoute.PeerGroups = nil + newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) // we have to provide unique route id when distribute network map + takeRoute(newPeerRoute, id) + break + } + } + if r.Peer == peerID { + takeRoute(r.Copy(), peerID) + } + } + + return enabledRoutes, disabledRoutes +} + +// GetRoutesByPrefixOrDomains return list of routes by account and route prefix +func (a *Account) GetRoutesByPrefixOrDomains(prefix netip.Prefix, domains domain.List) []*route.Route { + var routes []*route.Route + for _, r := range a.Routes { + dynamic := r.IsDynamic() + if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() || + !dynamic && r.Network.String() == prefix.String() { + routes = append(routes, r) + } + } + + return routes +} + +// GetGroup returns a group by ID if exists, nil otherwise +func (a *Account) GetGroup(groupID string) *Group { + return a.Groups[groupID] +} + +// GetPeerNetworkMap returns the networkmap for the given peer ID. 
+func (a *Account) GetPeerNetworkMap( + ctx context.Context, + peerID string, + peersCustomZone nbdns.CustomZone, + validatedPeersMap map[string]struct{}, + resourcePolicies map[string][]*Policy, + routers map[string]map[string]*routerTypes.NetworkRouter, + metrics *telemetry.AccountManagerMetrics, +) *NetworkMap { + start := time.Now() + + peer := a.Peers[peerID] + if peer == nil { + return &NetworkMap{ + Network: a.Network.Copy(), + } + } + + if _, ok := validatedPeersMap[peerID]; !ok { + return &NetworkMap{ + Network: a.Network.Copy(), + } + } + + aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peerID, validatedPeersMap) + // exclude expired peers + var peersToConnect []*nbpeer.Peer + var expiredPeers []*nbpeer.Peer + for _, p := range aclPeers { + expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration) + if a.Settings.PeerLoginExpirationEnabled && expired { + expiredPeers = append(expiredPeers, p) + continue + } + peersToConnect = append(peersToConnect, p) + } + + routesUpdate := a.GetRoutesToSync(ctx, peerID, peersToConnect) + routesFirewallRules := a.GetPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap) + isRouter, networkResourcesRoutes, sourcePeers := a.GetNetworkResourcesRoutesToSync(ctx, peerID, resourcePolicies, routers) + var networkResourcesFirewallRules []*RouteFirewallRule + if isRouter { + networkResourcesFirewallRules = a.GetPeerNetworkResourceFirewallRules(ctx, peer, validatedPeersMap, networkResourcesRoutes, resourcePolicies) + } + peersToConnectIncludingRouters := a.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, isRouter, sourcePeers) + + dnsManagementStatus := a.getPeerDNSManagementStatus(peerID) + dnsUpdate := nbdns.Config{ + ServiceEnable: dnsManagementStatus, + } + + if dnsManagementStatus { + var zones []nbdns.CustomZone + + if peersCustomZone.Domain != "" { + zones = append(zones, peersCustomZone) + } + dnsUpdate.CustomZones = zones + dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID) + } + + nm := &NetworkMap{ + Peers: peersToConnectIncludingRouters, + Network: a.Network.Copy(), + Routes: slices.Concat(networkResourcesRoutes, routesUpdate), + DNSConfig: dnsUpdate, + OfflinePeers: expiredPeers, + FirewallRules: firewallRules, + RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules), + } + + if metrics != nil { + objectCount := int64(len(peersToConnect) + len(expiredPeers) + len(routesUpdate) + len(firewallRules)) + metrics.CountNetworkMapObjects(objectCount) + metrics.CountGetPeerNetworkMapDuration(time.Since(start)) + + if objectCount > 5000 { + log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects, "+ + "peers to connect: %d, expired peers: %d, routes: %d, firewall rules: %d", + a.Id, objectCount, len(peersToConnect), len(expiredPeers), len(routesUpdate), len(firewallRules)) + } + } + + return nm +} + +func (a *Account) addNetworksRoutingPeers( + networkResourcesRoutes []*route.Route, + peer *nbpeer.Peer, + peersToConnect []*nbpeer.Peer, + expiredPeers []*nbpeer.Peer, + isRouter bool, + sourcePeers map[string]struct{}, +) []*nbpeer.Peer { + + networkRoutesPeers := make(map[string]struct{}, len(networkResourcesRoutes)) + for _, r := range networkResourcesRoutes { + networkRoutesPeers[r.PeerID] = struct{}{} + } + + delete(sourcePeers, peer.ID) + delete(networkRoutesPeers, peer.ID) + + for _, existingPeer := range peersToConnect { + delete(sourcePeers, existingPeer.ID) + delete(networkRoutesPeers, existingPeer.ID) + } + for _, expPeer := 
range expiredPeers { + delete(sourcePeers, expPeer.ID) + delete(networkRoutesPeers, expPeer.ID) + } + + missingPeers := make(map[string]struct{}, len(sourcePeers)+len(networkRoutesPeers)) + if isRouter { + for p := range sourcePeers { + missingPeers[p] = struct{}{} + } + } + for p := range networkRoutesPeers { + missingPeers[p] = struct{}{} + } + + for p := range missingPeers { + if missingPeer := a.Peers[p]; missingPeer != nil { + peersToConnect = append(peersToConnect, missingPeer) + } + } + + return peersToConnect +} + +func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup { + groupList := account.GetPeerGroups(peerID) + + var peerNSGroups []*nbdns.NameServerGroup + + for _, nsGroup := range account.NameServerGroups { + if !nsGroup.Enabled { + continue + } + for _, gID := range nsGroup.Groups { + _, found := groupList[gID] + if found { + if !peerIsNameserver(account.GetPeer(peerID), nsGroup) { + peerNSGroups = append(peerNSGroups, nsGroup.Copy()) + break + } + } + } + } + + return peerNSGroups +} + +// peerIsNameserver returns true if the peer is a nameserver for a nsGroup +func peerIsNameserver(peer *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool { + for _, ns := range nsGroup.NameServers { + if peer.IP.Equal(ns.IP.AsSlice()) { + return true + } + } + return false +} + +func AddPeerLabelsToAccount(ctx context.Context, account *Account, peerLabels LookupMap) { + for _, peer := range account.Peers { + label, err := GetPeerHostLabel(peer.Name, peerLabels) + if err != nil { + log.WithContext(ctx).Errorf("got an error while generating a peer host label. Peer name %s, error: %v. Trying with the peer's meta hostname", peer.Name, err) + label, err = GetPeerHostLabel(peer.Meta.Hostname, peerLabels) + if err != nil { + log.WithContext(ctx).Errorf("got another error while generating a peer host label with hostname. Peer hostname %s, error: %v. Skipping", peer.Meta.Hostname, err) + continue + } + } + peer.DNSLabel = label + peerLabels[label] = struct{}{} + } +} + +func GetPeerHostLabel(name string, peerLabels LookupMap) (string, error) { + label, err := nbdns.GetParsedDomainLabel(name) + if err != nil { + return "", err + } + + uniqueLabel := getUniqueHostLabel(label, peerLabels) + if uniqueLabel == "" { + return "", fmt.Errorf("couldn't find a unique valid label for %s, parsed label %s", name, label) + } + return uniqueLabel, nil +} + +// getUniqueHostLabel look for a unique host label, and if doesn't find add a suffix up to 999 +func getUniqueHostLabel(name string, peerLabels LookupMap) string { + _, found := peerLabels[name] + if !found { + return name + } + for i := 1; i < 1000; i++ { + nameWithSuffix := name + "-" + strconv.Itoa(i) + _, found = peerLabels[nameWithSuffix] + if !found { + return nameWithSuffix + } + } + return "" +} + +func (a *Account) GetPeersCustomZone(ctx context.Context, dnsDomain string) nbdns.CustomZone { + var merr *multierror.Error + + if dnsDomain == "" { + log.WithContext(ctx).Error("no dns domain is set, returning empty zone") + return nbdns.CustomZone{} + } + + customZone := nbdns.CustomZone{ + Domain: dns.Fqdn(dnsDomain), + Records: make([]nbdns.SimpleRecord, 0, len(a.Peers)), + } + + domainSuffix := "." 
+ dnsDomain + + var sb strings.Builder + for _, peer := range a.Peers { + if peer.DNSLabel == "" { + merr = multierror.Append(merr, fmt.Errorf("peer %s has an empty DNS label", peer.Name)) + continue + } + + sb.Grow(len(peer.DNSLabel) + len(domainSuffix)) + sb.WriteString(peer.DNSLabel) + sb.WriteString(domainSuffix) + + customZone.Records = append(customZone.Records, nbdns.SimpleRecord{ + Name: sb.String(), + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: defaultTTL, + RData: peer.IP.String(), + }) + + sb.Reset() + } + + go func() { + if merr != nil { + log.WithContext(ctx).Errorf("error generating custom zone for account %s: %v", a.Id, merr) + } + }() + + return customZone +} + +// GetExpiredPeers returns peers that have been expired +func (a *Account) GetExpiredPeers() []*nbpeer.Peer { + var peers []*nbpeer.Peer + for _, peer := range a.GetPeersWithExpiration() { + expired, _ := peer.LoginExpired(a.Settings.PeerLoginExpiration) + if expired { + peers = append(peers, peer) + } + } + + return peers +} + +// GetNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. +// If there is no peer that expires this function returns false and a duration of 0. +// This function only considers peers that haven't been expired yet and that are connected. +func (a *Account) GetNextPeerExpiration() (time.Duration, bool) { + peersWithExpiry := a.GetPeersWithExpiration() + if len(peersWithExpiry) == 0 { + return 0, false + } + var nextExpiry *time.Duration + for _, peer := range peersWithExpiry { + // consider only connected peers because others will require login on connecting to the management server + if peer.Status.LoginExpired || !peer.Status.Connected { + continue + } + _, duration := peer.LoginExpired(a.Settings.PeerLoginExpiration) + if nextExpiry == nil || duration < *nextExpiry { + // if expiration is below 1s return 1s duration + // this avoids issues with ticker that can't be set to < 0 + if duration < time.Second { + return time.Second, true + } + nextExpiry = &duration + } + } + + if nextExpiry == nil { + return 0, false + } + + return *nextExpiry, true +} + +// GetPeersWithExpiration returns a list of peers that have Peer.LoginExpirationEnabled set to true and that were added by a user +func (a *Account) GetPeersWithExpiration() []*nbpeer.Peer { + peers := make([]*nbpeer.Peer, 0) + for _, peer := range a.Peers { + if peer.LoginExpirationEnabled && peer.AddedWithSSOLogin() { + peers = append(peers, peer) + } + } + return peers +} + +// GetInactivePeers returns peers that have been expired by inactivity +func (a *Account) GetInactivePeers() []*nbpeer.Peer { + var peers []*nbpeer.Peer + for _, inactivePeer := range a.GetPeersWithInactivity() { + inactive, _ := inactivePeer.SessionExpired(a.Settings.PeerInactivityExpiration) + if inactive { + peers = append(peers, inactivePeer) + } + } + return peers +} + +// GetNextInactivePeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. +// If there is no peer that expires this function returns false and a duration of 0. +// This function only considers peers that haven't been expired yet and that are not connected. 
+func (a *Account) GetNextInactivePeerExpiration() (time.Duration, bool) { + peersWithExpiry := a.GetPeersWithInactivity() + if len(peersWithExpiry) == 0 { + return 0, false + } + var nextExpiry *time.Duration + for _, peer := range peersWithExpiry { + if peer.Status.LoginExpired || peer.Status.Connected { + continue + } + _, duration := peer.SessionExpired(a.Settings.PeerInactivityExpiration) + if nextExpiry == nil || duration < *nextExpiry { + // if expiration is below 1s return 1s duration + // this avoids issues with ticker that can't be set to < 0 + if duration < time.Second { + return time.Second, true + } + nextExpiry = &duration + } + } + + if nextExpiry == nil { + return 0, false + } + + return *nextExpiry, true +} + +// GetPeersWithInactivity eturns a list of peers that have Peer.InactivityExpirationEnabled set to true and that were added by a user +func (a *Account) GetPeersWithInactivity() []*nbpeer.Peer { + peers := make([]*nbpeer.Peer, 0) + for _, peer := range a.Peers { + if peer.InactivityExpirationEnabled && peer.AddedWithSSOLogin() { + peers = append(peers, peer) + } + } + return peers +} + +// GetPeers returns a list of all Account peers +func (a *Account) GetPeers() []*nbpeer.Peer { + var peers []*nbpeer.Peer + for _, peer := range a.Peers { + peers = append(peers, peer) + } + return peers +} + +// UpdateSettings saves new account settings +func (a *Account) UpdateSettings(update *Settings) *Account { + a.Settings = update.Copy() + return a +} + +// UpdatePeer saves new or replaces existing peer +func (a *Account) UpdatePeer(update *nbpeer.Peer) { + a.Peers[update.ID] = update +} + +// DeletePeer deletes peer from the account cleaning up all the references +func (a *Account) DeletePeer(peerID string) { + // delete peer from groups + for _, g := range a.Groups { + for i, pk := range g.Peers { + if pk == peerID { + g.Peers = append(g.Peers[:i], g.Peers[i+1:]...) + break + } + } + } + + for _, r := range a.Routes { + if r.Peer == peerID { + r.Enabled = false + r.Peer = "" + } + } + + for i, r := range a.NetworkRouters { + if r.Peer == peerID { + a.NetworkRouters = append(a.NetworkRouters[:i], a.NetworkRouters[i+1:]...) + break + } + } + + delete(a.Peers, peerID) + a.Network.IncSerial() +} + +func (a *Account) DeleteResource(resourceID string) { + // delete resource from groups + for _, g := range a.Groups { + for i, pk := range g.Resources { + if pk.ID == resourceID { + g.Resources = append(g.Resources[:i], g.Resources[i+1:]...) + break + } + } + } +} + +// FindPeerByPubKey looks for a Peer by provided WireGuard public key in the Account or returns error if it wasn't found. +// It will return an object copy of the peer. +func (a *Account) FindPeerByPubKey(peerPubKey string) (*nbpeer.Peer, error) { + for _, peer := range a.Peers { + if peer.Key == peerPubKey { + return peer.Copy(), nil + } + } + + return nil, status.Errorf(status.NotFound, "peer with the public key %s not found", peerPubKey) +} + +// FindUserPeers returns a list of peers that user owns (created) +func (a *Account) FindUserPeers(userID string) ([]*nbpeer.Peer, error) { + peers := make([]*nbpeer.Peer, 0) + for _, peer := range a.Peers { + if peer.UserID == userID { + peers = append(peers, peer) + } + } + + return peers, nil +} + +// FindUser looks for a given user in the Account or returns error if user wasn't found. 
+func (a *Account) FindUser(userID string) (*User, error) { + user := a.Users[userID] + if user == nil { + return nil, status.Errorf(status.NotFound, "user %s not found", userID) + } + + return user, nil +} + +// FindGroupByName looks for a given group in the Account by name or returns error if the group wasn't found. +func (a *Account) FindGroupByName(groupName string) (*Group, error) { + for _, group := range a.Groups { + if group.Name == groupName { + return group, nil + } + } + return nil, status.Errorf(status.NotFound, "group %s not found", groupName) +} + +// FindSetupKey looks for a given SetupKey in the Account or returns error if it wasn't found. +func (a *Account) FindSetupKey(setupKey string) (*SetupKey, error) { + key := a.SetupKeys[setupKey] + if key == nil { + return nil, status.Errorf(status.NotFound, "setup key not found") + } + + return key, nil +} + +// GetPeerGroupsList return with the list of groups ID. +func (a *Account) GetPeerGroupsList(peerID string) []string { + var grps []string + for groupID, group := range a.Groups { + for _, id := range group.Peers { + if id == peerID { + grps = append(grps, groupID) + break + } + } + } + return grps +} + +func (a *Account) getPeerDNSManagementStatus(peerID string) bool { + peerGroups := a.GetPeerGroups(peerID) + enabled := true + for _, groupID := range a.DNSSettings.DisabledManagementGroups { + _, found := peerGroups[groupID] + if found { + enabled = false + break + } + } + return enabled +} + +func (a *Account) GetPeerGroups(peerID string) LookupMap { + groupList := make(LookupMap) + for groupID, group := range a.Groups { + for _, id := range group.Peers { + if id == peerID { + groupList[groupID] = struct{}{} + break + } + } + } + return groupList +} + +func (a *Account) GetTakenIPs() []net.IP { + var takenIps []net.IP + for _, existingPeer := range a.Peers { + takenIps = append(takenIps, existingPeer.IP) + } + + return takenIps +} + +func (a *Account) GetPeerDNSLabels() LookupMap { + existingLabels := make(LookupMap) + for _, peer := range a.Peers { + if peer.DNSLabel != "" { + existingLabels[peer.DNSLabel] = struct{}{} + } + } + return existingLabels +} + +func (a *Account) Copy() *Account { + peers := map[string]*nbpeer.Peer{} + for id, peer := range a.Peers { + peers[id] = peer.Copy() + } + + users := map[string]*User{} + for id, user := range a.Users { + users[id] = user.Copy() + } + + setupKeys := map[string]*SetupKey{} + for id, key := range a.SetupKeys { + setupKeys[id] = key.Copy() + } + + groups := map[string]*Group{} + for id, group := range a.Groups { + groups[id] = group.Copy() + } + + policies := []*Policy{} + for _, policy := range a.Policies { + policies = append(policies, policy.Copy()) + } + + routes := map[route.ID]*route.Route{} + for id, r := range a.Routes { + routes[id] = r.Copy() + } + + nsGroups := map[string]*nbdns.NameServerGroup{} + for id, nsGroup := range a.NameServerGroups { + nsGroups[id] = nsGroup.Copy() + } + + dnsSettings := a.DNSSettings.Copy() + + var settings *Settings + if a.Settings != nil { + settings = a.Settings.Copy() + } + + postureChecks := []*posture.Checks{} + for _, postureCheck := range a.PostureChecks { + postureChecks = append(postureChecks, postureCheck.Copy()) + } + + nets := []*networkTypes.Network{} + for _, network := range a.Networks { + nets = append(nets, network.Copy()) + } + + networkRouters := []*routerTypes.NetworkRouter{} + for _, router := range a.NetworkRouters { + networkRouters = append(networkRouters, router.Copy()) + } + + networkResources := 
[]*resourceTypes.NetworkResource{} + for _, resource := range a.NetworkResources { + networkResources = append(networkResources, resource.Copy()) + } + + return &Account{ + Id: a.Id, + CreatedBy: a.CreatedBy, + CreatedAt: a.CreatedAt, + Domain: a.Domain, + DomainCategory: a.DomainCategory, + IsDomainPrimaryAccount: a.IsDomainPrimaryAccount, + SetupKeys: setupKeys, + Network: a.Network.Copy(), + Peers: peers, + Users: users, + Groups: groups, + Policies: policies, + Routes: routes, + NameServerGroups: nsGroups, + DNSSettings: dnsSettings, + PostureChecks: postureChecks, + Settings: settings, + Networks: nets, + NetworkRouters: networkRouters, + NetworkResources: networkResources, + } +} + +func (a *Account) GetGroupAll() (*Group, error) { + for _, g := range a.Groups { + if g.Name == "All" { + return g, nil + } + } + return nil, fmt.Errorf("no group ALL found") +} + +// GetPeer looks up a Peer by ID +func (a *Account) GetPeer(peerID string) *nbpeer.Peer { + return a.Peers[peerID] +} + +// UserGroupsAddToPeers adds groups to all peers of user +func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) map[string][]string { + groupUpdates := make(map[string][]string) + + userPeers := make(map[string]struct{}) + for pid, peer := range a.Peers { + if peer.UserID == userID { + userPeers[pid] = struct{}{} + } + } + + for _, gid := range groups { + group, ok := a.Groups[gid] + if !ok { + continue + } + + oldPeers := group.Peers + + groupPeers := make(map[string]struct{}) + for _, pid := range group.Peers { + groupPeers[pid] = struct{}{} + } + + for pid := range userPeers { + groupPeers[pid] = struct{}{} + } + + group.Peers = group.Peers[:0] + for pid := range groupPeers { + group.Peers = append(group.Peers, pid) + } + + groupUpdates[gid] = util.Difference(group.Peers, oldPeers) + } + + return groupUpdates +} + +// UserGroupsRemoveFromPeers removes groups from all peers of user +func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map[string][]string { + groupUpdates := make(map[string][]string) + + for _, gid := range groups { + group, ok := a.Groups[gid] + if !ok || group.Name == "All" { + continue + } + + oldPeers := group.Peers + + update := make([]string, 0, len(group.Peers)) + for _, pid := range group.Peers { + peer, ok := a.Peers[pid] + if !ok { + continue + } + if peer.UserID != userID { + update = append(update, pid) + } + } + group.Peers = update + groupUpdates[gid] = util.Difference(oldPeers, group.Peers) + } + + return groupUpdates +} + +// GetPeerConnectionResources for a given peer +// +// This function returns the list of peers and firewall rules that are applicable to a given peer. 
+func (a *Account) GetPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) { + generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx) + for _, policy := range a.Policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) + destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap) + + if rule.Bidirectional { + if peerInSources { + generateResources(rule, destinationPeers, FirewallRuleDirectionIN) + } + if peerInDestinations { + generateResources(rule, sourcePeers, FirewallRuleDirectionOUT) + } + } + + if peerInSources { + generateResources(rule, destinationPeers, FirewallRuleDirectionOUT) + } + + if peerInDestinations { + generateResources(rule, sourcePeers, FirewallRuleDirectionIN) + } + } + } + + return getAccumulatedResources() +} + +// connResourcesGenerator returns generator and accumulator function which returns the result of generator calls +// +// The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer. +// It safe to call the generator function multiple times for same peer and different rules no duplicates will be +// generated. The accumulator function returns the result of all the generator calls. +func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) { + rulesExists := make(map[string]struct{}) + peersExists := make(map[string]struct{}) + rules := make([]*FirewallRule, 0) + peers := make([]*nbpeer.Peer, 0) + + all, err := a.GetGroupAll() + if err != nil { + log.WithContext(ctx).Errorf("failed to get group all: %v", err) + all = &Group{} + } + + return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) { + isAll := (len(all.Peers) - 1) == len(groupPeers) + for _, peer := range groupPeers { + if peer == nil { + continue + } + + if _, ok := peersExists[peer.ID]; !ok { + peers = append(peers, peer) + peersExists[peer.ID] = struct{}{} + } + + fr := FirewallRule{ + PeerIP: peer.IP.String(), + Direction: direction, + Action: string(rule.Action), + Protocol: string(rule.Protocol), + } + + if isAll { + fr.PeerIP = "0.0.0.0" + } + + ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) + + fr.Protocol + fr.Action + strings.Join(rule.Ports, ",") + if _, ok := rulesExists[ruleID]; ok { + continue + } + rulesExists[ruleID] = struct{}{} + + if len(rule.Ports) == 0 { + rules = append(rules, &fr) + continue + } + + for _, port := range rule.Ports { + pr := fr // clone rule and add set new port + pr.Port = port + rules = append(rules, &pr) + } + } + }, func() ([]*nbpeer.Peer, []*FirewallRule) { + return peers, rules + } +} + +// getAllPeersFromGroups for given peer ID and list of groups +// +// Returns a list of peers from specified groups that pass specified posture checks +// and a boolean indicating if the supplied peer ID exists within these groups. 
+// + // Important: Posture checks are applicable only to source group peers; + // for destination group peers, call this method with an empty list of sourcePostureChecksIDs +func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) { + peerInGroups := false + uniquePeerIDs := a.getUniquePeerIDsFromGroupsIDs(ctx, groups) + filteredPeers := make([]*nbpeer.Peer, 0, len(uniquePeerIDs)) + for _, p := range uniquePeerIDs { + peer, ok := a.Peers[p] + if !ok || peer == nil { + continue + } + + // validate the peer based on policy posture checks applied + isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID) + if !isValid { + continue + } + + if _, ok := validatedPeersMap[peer.ID]; !ok { + continue + } + + if peer.ID == peerID { + peerInGroups = true + continue + } + + filteredPeers = append(filteredPeers, peer) + } + + return filteredPeers, peerInGroups +} + +// validatePostureChecksOnPeer validates the posture checks on a peer +func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePostureChecksID []string, peerID string) bool { + peer, ok := a.Peers[peerID] + if !ok || peer == nil { + return false + } + + for _, postureChecksID := range sourcePostureChecksID { + postureChecks := a.GetPostureChecks(postureChecksID) + if postureChecks == nil { + continue + } + + for _, check := range postureChecks.GetChecks() { + isValid, err := check.Check(ctx, *peer) + if err != nil { + log.WithContext(ctx).Debugf("an error occurred during check %s on peer %s: %s", check.Name(), peer.ID, err.Error()) + } + if !isValid { + return false + } + } + } + return true +} + +func (a *Account) GetPostureChecks(postureChecksID string) *posture.Checks { + for _, postureChecks := range a.PostureChecks { + if postureChecks.ID == postureChecksID { + return postureChecks + } + } + return nil +} + +// GetPeerRoutesFirewallRules gets the routes firewall rules associated with a routing peer ID for the account. +func (a *Account) GetPeerRoutesFirewallRules(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule { + routesFirewallRules := make([]*RouteFirewallRule, 0, len(a.Routes)) + + enabledRoutes, _ := a.getRoutingPeerRoutes(ctx, peerID) + for _, route := range enabledRoutes { + // If no access control groups are specified, accept all traffic. + if len(route.AccessControlGroups) == 0 { + defaultPermit := getDefaultPermit(route) + routesFirewallRules = append(routesFirewallRules, defaultPermit...) + continue + } + + distributionPeers := a.getDistributionGroupsPeers(route) + + for _, accessGroup := range route.AccessControlGroups { + policies := GetAllRoutePoliciesFromGroups(a, []string{accessGroup}) + rules := a.getRouteFirewallRules(ctx, peerID, policies, route, validatedPeersMap, distributionPeers) + routesFirewallRules = append(routesFirewallRules, rules...)
+ } + } + + return routesFirewallRules +} + +func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}) []*RouteFirewallRule { + var fwRules []*RouteFirewallRule + for _, policy := range policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + rulePeers := a.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers, validatedPeersMap) + rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN) + fwRules = append(fwRules, rules...) + } + } + return fwRules +} + +func (a *Account) getRulePeers(rule *PolicyRule, postureChecks []string, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer { + distPeersWithPolicy := make(map[string]struct{}) + for _, id := range rule.Sources { + group := a.Groups[id] + if group == nil { + continue + } + + for _, pID := range group.Peers { + if pID == peerID { + continue + } + _, distPeer := distributionPeers[pID] + _, valid := validatedPeersMap[pID] + if distPeer && valid && a.validatePostureChecksOnPeer(context.Background(), postureChecks, pID) { + distPeersWithPolicy[pID] = struct{}{} + } + } + } + + distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy)) + for pID := range distPeersWithPolicy { + peer := a.Peers[pID] + if peer == nil { + continue + } + distributionGroupPeers = append(distributionGroupPeers, peer) + } + return distributionGroupPeers +} + +func (a *Account) getDistributionGroupsPeers(route *route.Route) map[string]struct{} { + distPeers := make(map[string]struct{}) + for _, id := range route.Groups { + group := a.Groups[id] + if group == nil { + continue + } + + for _, pID := range group.Peers { + distPeers[pID] = struct{}{} + } + } + return distPeers +} + +func getDefaultPermit(route *route.Route) []*RouteFirewallRule { + var rules []*RouteFirewallRule + + sources := []string{"0.0.0.0/0"} + if route.Network.Addr().Is6() { + sources = []string{"::/0"} + } + rule := RouteFirewallRule{ + SourceRanges: sources, + Action: string(PolicyTrafficActionAccept), + Destination: route.Network.String(), + Protocol: string(PolicyRuleProtocolALL), + Domains: route.Domains, + IsDynamic: route.IsDynamic(), + } + + rules = append(rules, &rule) + + // dynamic routes always contain an IPv4 placeholder as destination, hence we must add IPv6 rules additionally + if route.IsDynamic() { + ruleV6 := rule + ruleV6.SourceRanges = []string{"::/0"} + rules = append(rules, &ruleV6) + } + + return rules +} + +// GetAllRoutePoliciesFromGroups retrieves route policies associated with the specified access control groups +// and returns a list of policies that have rules with destinations matching the specified groups. 
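For illustration, a small hedged sketch of the selection behavior described above; the account contents and the helper name exampleRoutePolicySelection are invented for the example:

package types

// exampleRoutePolicySelection is illustrative only: the policy whose rule
// destinations include the access control group "acg-1" is returned, the
// unrelated policy is filtered out.
func exampleRoutePolicySelection() []*Policy {
	account := &Account{
		Groups: map[string]*Group{
			"acg-1": {ID: "acg-1", Name: "route-access"},
		},
		Policies: []*Policy{
			{ID: "p1", Enabled: true, Rules: []*PolicyRule{{ID: "r1", Enabled: true, Destinations: []string{"acg-1"}}}},
			{ID: "p2", Enabled: true, Rules: []*PolicyRule{{ID: "r2", Enabled: true, Destinations: []string{"other"}}}},
		},
	}
	return GetAllRoutePoliciesFromGroups(account, []string{"acg-1"}) // returns only p1
}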
+func GetAllRoutePoliciesFromGroups(account *Account, accessControlGroups []string) []*Policy { + routePolicies := make([]*Policy, 0) + for _, groupID := range accessControlGroups { + group, ok := account.Groups[groupID] + if !ok { + continue + } + + for _, policy := range account.Policies { + for _, rule := range policy.Rules { + exist := slices.ContainsFunc(rule.Destinations, func(groupID string) bool { + return groupID == group.ID + }) + if exist { + routePolicies = append(routePolicies, policy) + continue + } + } + } + } + + return routePolicies +} + +// GetPeerNetworkResourceFirewallRules gets the network resources firewall rules associated with a routing peer ID for the account. +func (a *Account) GetPeerNetworkResourceFirewallRules(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}, routes []*route.Route, resourcePolicies map[string][]*Policy) []*RouteFirewallRule { + routesFirewallRules := make([]*RouteFirewallRule, 0) + + for _, route := range routes { + if route.Peer != peer.Key { + continue + } + resourceAppliedPolicies := resourcePolicies[route.GetResourceID()] + distributionPeers := getPoliciesSourcePeers(resourceAppliedPolicies, a.Groups) + + rules := a.getRouteFirewallRules(ctx, peer.ID, resourceAppliedPolicies, route, validatedPeersMap, distributionPeers) + for _, rule := range rules { + if len(rule.SourceRanges) > 0 { + routesFirewallRules = append(routesFirewallRules, rule) + } + } + } + + return routesFirewallRules +} + +// getNetworkResourceGroups retrieves all groups associated with the given network resource. +func (a *Account) getNetworkResourceGroups(resourceID string) []*Group { + var networkResourceGroups []*Group + + for _, group := range a.Groups { + for _, resource := range group.Resources { + if resource.ID == resourceID { + networkResourceGroups = append(networkResourceGroups, group) + } + } + } + + return networkResourceGroups +} + +// GetResourcePoliciesMap returns a map of networks resource IDs and their associated policies. +func (a *Account) GetResourcePoliciesMap() map[string][]*Policy { + resourcePolicies := make(map[string][]*Policy) + for _, resource := range a.NetworkResources { + if !resource.Enabled { + continue + } + + resourceAppliedPolicies := a.GetPoliciesForNetworkResource(resource.ID) + resourcePolicies[resource.ID] = resourceAppliedPolicies + } + return resourcePolicies +} + +// GetNetworkResourcesRoutesToSync returns network routes for syncing with a specific peer and its ACL peers. +func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, map[string]struct{}) { + var isRoutingPeer bool + var routes []*route.Route + allSourcePeers := make(map[string]struct{}, len(a.Peers)) + + for _, resource := range a.NetworkResources { + if !resource.Enabled { + continue + } + + var addSourcePeers bool + + networkRoutingPeers, exists := routers[resource.NetworkID] + if exists { + if router, ok := networkRoutingPeers[peerID]; ok { + isRoutingPeer, addSourcePeers = true, true + routes = append(routes, a.getNetworkResourcesRoutes(resource, peerID, router, resourcePolicies)...) 
+ } + } + + addedResourceRoute := false + for _, policy := range resourcePolicies[resource.ID] { + peers := a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups()) + if addSourcePeers { + for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) { + allSourcePeers[pID] = struct{}{} + } + } else if slices.Contains(peers, peerID) && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) { + // add routes for the resource if the peer is in the distribution group + for peerId, router := range networkRoutingPeers { + routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...) + } + addedResourceRoute = true + } + if addedResourceRoute { + break + } + } + } + + return isRoutingPeer, routes, allSourcePeers +} + +func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string { + var dest []string + for _, peerID := range inputPeers { + if a.validatePostureChecksOnPeer(context.Background(), postureChecksIDs, peerID) { + dest = append(dest, peerID) + } + } + return dest +} + +func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string { + peerIDs := make(map[string]struct{}, len(groups)) // we expect at least one peer per group as initial capacity + for _, groupID := range groups { + group := a.GetGroup(groupID) + if group == nil { + log.WithContext(ctx).Warnf("group %s doesn't exist under account %s, will continue map generation without it", groupID, a.Id) + continue + } + + if group.IsGroupAll() || len(groups) == 1 { + return group.Peers + } + + for _, peerID := range group.Peers { + peerIDs[peerID] = struct{}{} + } + } + + ids := make([]string, 0, len(peerIDs)) + for peerID := range peerIDs { + ids = append(ids, peerID) + } + + return ids +} + +// getNetworkResources filters and returns a list of network resources associated with the given network ID. +func (a *Account) getNetworkResources(networkID string) []*resourceTypes.NetworkResource { + var resources []*resourceTypes.NetworkResource + for _, resource := range a.NetworkResources { + if resource.NetworkID == networkID { + resources = append(resources, resource) + } + } + return resources +} + +// GetPoliciesForNetworkResource retrieves the list of policies that apply to a specific network resource. +// A policy is deemed applicable if its destination groups include any of the given network resource groups +// or if its destination resource explicitly matches the provided resource. 
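For illustration, a hedged sketch of the two matching paths described above; the IDs and the helper name examplePoliciesForResource are invented for the example:

package types

// examplePoliciesForResource is illustrative only: p1 targets the resource
// directly via DestinationResource, p2 targets it indirectly through a
// destination group that contains the resource, so both apply.
func examplePoliciesForResource() []*Policy {
	account := &Account{
		Groups: map[string]*Group{
			"g1": {ID: "g1", Resources: []Resource{{ID: "resource-a", Type: "host"}}},
		},
		Policies: []*Policy{
			{ID: "p1", Enabled: true, Rules: []*PolicyRule{{ID: "r1", Enabled: true, DestinationResource: Resource{ID: "resource-a", Type: "host"}}}},
			{ID: "p2", Enabled: true, Rules: []*PolicyRule{{ID: "r2", Enabled: true, Destinations: []string{"g1"}}}},
		},
	}
	return account.GetPoliciesForNetworkResource("resource-a") // returns p1 and p2
}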
+func (a *Account) GetPoliciesForNetworkResource(resourceId string) []*Policy { + var resourceAppliedPolicies []*Policy + + networkResourceGroups := a.getNetworkResourceGroups(resourceId) + + for _, policy := range a.Policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + if rule.DestinationResource.ID == resourceId { + resourceAppliedPolicies = append(resourceAppliedPolicies, policy) + break + } + + for _, group := range networkResourceGroups { + if slices.Contains(rule.Destinations, group.ID) { + resourceAppliedPolicies = append(resourceAppliedPolicies, policy) + break + } + } + } + } + + return resourceAppliedPolicies +} + +func (a *Account) GetPoliciesAppliedInNetwork(networkID string) []string { + networkResources := a.getNetworkResources(networkID) + + policiesIDs := map[string]struct{}{} + for _, resource := range networkResources { + resourceAppliedPolicies := a.GetPoliciesForNetworkResource(resource.ID) + for _, policy := range resourceAppliedPolicies { + policiesIDs[policy.ID] = struct{}{} + } + } + + result := make([]string, 0, len(policiesIDs)) + for id := range policiesIDs { + result = append(result, id) + } + + return result +} + +// getNetworkResourcesRoutes convert the network resources list to routes list. +func (a *Account) getNetworkResourcesRoutes(resource *resourceTypes.NetworkResource, peerId string, router *routerTypes.NetworkRouter, resourcePolicies map[string][]*Policy) []*route.Route { + resourceAppliedPolicies := resourcePolicies[resource.ID] + + var routes []*route.Route + // distribute the resource routes only if there is policy applied to it + if len(resourceAppliedPolicies) > 0 { + peer := a.GetPeer(peerId) + if peer != nil { + routes = append(routes, resource.ToRoute(peer, router)) + } + } + + return routes +} + +func (a *Account) GetResourceRoutersMap() map[string]map[string]*routerTypes.NetworkRouter { + routers := make(map[string]map[string]*routerTypes.NetworkRouter) + + for _, router := range a.NetworkRouters { + if !router.Enabled { + continue + } + + if routers[router.NetworkID] == nil { + routers[router.NetworkID] = make(map[string]*routerTypes.NetworkRouter) + } + + if router.Peer != "" { + routers[router.NetworkID][router.Peer] = router + continue + } + + for _, peerGroup := range router.PeerGroups { + g := a.Groups[peerGroup] + if g != nil { + for _, peerID := range g.Peers { + routers[router.NetworkID][peerID] = router + } + } + } + } + + return routers +} + +// getPoliciesSourcePeers collects all unique peers from the source groups defined in the given policies. 
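For illustration, a hedged sketch of the de-duplication described above; the group and peer IDs and the helper name exampleSourcePeerCollection are invented:

package types

// exampleSourcePeerCollection is illustrative only: peers are gathered from
// the source groups of every rule and de-duplicated, so "peer-1" appears
// once even though two rules reference it.
func exampleSourcePeerCollection() map[string]struct{} {
	groups := map[string]*Group{
		"g1": {ID: "g1", Peers: []string{"peer-1", "peer-2"}},
		"g2": {ID: "g2", Peers: []string{"peer-1"}},
	}
	policies := []*Policy{
		{Rules: []*PolicyRule{{Sources: []string{"g1"}}, {Sources: []string{"g2"}}}},
	}
	return getPoliciesSourcePeers(policies, groups) // {"peer-1": {}, "peer-2": {}}
}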
+func getPoliciesSourcePeers(policies []*Policy, groups map[string]*Group) map[string]struct{} { + sourcePeers := make(map[string]struct{}) + + for _, policy := range policies { + for _, rule := range policy.Rules { + for _, sourceGroup := range rule.Sources { + group := groups[sourceGroup] + if group == nil { + continue + } + + for _, peer := range group.Peers { + sourcePeers[peer] = struct{}{} + } + } + } + } + + return sourcePeers +} diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go new file mode 100644 index 000000000..f8ab1d627 --- /dev/null +++ b/management/server/types/account_test.go @@ -0,0 +1,837 @@ +package types + +import ( + "context" + "net" + "net/netip" + "slices" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/route" +) + +func setupTestAccount() *Account { + return &Account{ + Id: "accountID", + Peers: map[string]*nbpeer.Peer{ + "peer1": { + ID: "peer1", + AccountID: "accountID", + Key: "peer1Key", + }, + "peer2": { + ID: "peer2", + AccountID: "accountID", + Key: "peer2Key", + }, + "peer3": { + ID: "peer3", + AccountID: "accountID", + Key: "peer3Key", + }, + "peer11": { + ID: "peer11", + AccountID: "accountID", + Key: "peer11Key", + }, + "peer12": { + ID: "peer12", + AccountID: "accountID", + Key: "peer12Key", + }, + "peer21": { + ID: "peer21", + AccountID: "accountID", + Key: "peer21Key", + }, + "peer31": { + ID: "peer31", + AccountID: "accountID", + Key: "peer31Key", + }, + "peer32": { + ID: "peer32", + AccountID: "accountID", + Key: "peer32Key", + }, + "peer41": { + ID: "peer41", + AccountID: "accountID", + Key: "peer41Key", + }, + "peer51": { + ID: "peer51", + AccountID: "accountID", + Key: "peer51Key", + }, + "peer61": { + ID: "peer61", + AccountID: "accountID", + Key: "peer61Key", + }, + }, + Groups: map[string]*Group{ + "group1": { + ID: "group1", + Peers: []string{"peer11", "peer12"}, + Resources: []Resource{ + { + ID: "resource1ID", + Type: "Host", + }, + }, + }, + "group2": { + ID: "group2", + Peers: []string{"peer21"}, + Resources: []Resource{ + { + ID: "resource2ID", + Type: "Domain", + }, + }, + }, + "group3": { + ID: "group3", + Peers: []string{"peer31", "peer32"}, + Resources: []Resource{ + { + ID: "resource3ID", + Type: "Subnet", + }, + }, + }, + "group4": { + ID: "group4", + Peers: []string{"peer41"}, + Resources: []Resource{ + { + ID: "resource3ID", + Type: "Subnet", + }, + }, + }, + "group5": { + ID: "group5", + Peers: []string{"peer51"}, + }, + "group6": { + ID: "group6", + Peers: []string{"peer61"}, + }, + }, + Networks: []*networkTypes.Network{ + { + ID: "network1ID", + AccountID: "accountID", + Name: "network1", + }, + { + ID: "network2ID", + AccountID: "accountID", + Name: "network2", + }, + { + ID: "network3ID", + AccountID: "accountID", + Name: "network3", + }, + }, + NetworkRouters: []*routerTypes.NetworkRouter{ + { + ID: "router1ID", + NetworkID: "network1ID", + AccountID: "accountID", + Peer: "peer1", + PeerGroups: []string{}, + Masquerade: false, + Metric: 100, + Enabled: true, + }, + { + ID: "router2ID", + NetworkID: "network2ID", + AccountID: 
"accountID", + Peer: "peer2", + PeerGroups: []string{}, + Masquerade: false, + Metric: 100, + Enabled: true, + }, + { + ID: "router3ID", + NetworkID: "network1ID", + AccountID: "accountID", + Peer: "peer3", + PeerGroups: []string{}, + Masquerade: false, + Metric: 100, + Enabled: true, + }, + { + ID: "router4ID", + NetworkID: "network1ID", + AccountID: "accountID", + Peer: "", + PeerGroups: []string{"group1"}, + Masquerade: false, + Metric: 100, + Enabled: true, + }, + { + ID: "router5ID", + NetworkID: "network1ID", + AccountID: "accountID", + Peer: "", + PeerGroups: []string{"group2", "group3"}, + Masquerade: false, + Metric: 100, + Enabled: true, + }, + { + ID: "router6ID", + NetworkID: "network2ID", + AccountID: "accountID", + Peer: "", + PeerGroups: []string{"group4"}, + Masquerade: false, + Metric: 100, + Enabled: true, + }, + { + ID: "router6ID", + NetworkID: "network3ID", + AccountID: "accountID", + Peer: "", + PeerGroups: []string{"group6"}, + Masquerade: false, + Metric: 100, + Enabled: false, + }, + }, + NetworkResources: []*resourceTypes.NetworkResource{ + { + ID: "resource1ID", + AccountID: "accountID", + NetworkID: "network1ID", + Enabled: true, + }, + { + ID: "resource2ID", + AccountID: "accountID", + NetworkID: "network2ID", + Enabled: true, + }, + { + ID: "resource3ID", + AccountID: "accountID", + NetworkID: "network1ID", + Enabled: true, + }, + { + ID: "resource4ID", + AccountID: "accountID", + NetworkID: "network1ID", + Enabled: true, + }, + { + ID: "resource5ID", + AccountID: "accountID", + NetworkID: "network3ID", + Enabled: false, + }, + }, + Policies: []*Policy{ + { + ID: "policy1ID", + AccountID: "accountID", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "rule1ID", + Enabled: true, + Destinations: []string{"group1"}, + }, + }, + }, + { + ID: "policy2ID", + AccountID: "accountID", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "rule2ID", + Enabled: true, + Destinations: []string{"group3"}, + }, + }, + }, + { + ID: "policy3ID", + AccountID: "accountID", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "rule3ID", + Enabled: true, + Destinations: []string{"group2", "group4"}, + }, + }, + }, + { + ID: "policy4ID", + AccountID: "accountID", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "rule4ID", + Enabled: true, + DestinationResource: Resource{ + ID: "resource4ID", + Type: "Host", + }, + }, + }, + }, + { + ID: "policy5ID", + AccountID: "accountID", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "rule5ID", + Enabled: true, + }, + }, + }, + { + ID: "policy6ID", + AccountID: "accountID", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "rule6ID", + Enabled: true, + }, + }, + }, + }, + } +} + +func Test_GetResourceRoutersMap(t *testing.T) { + account := setupTestAccount() + routers := account.GetResourceRoutersMap() + require.Equal(t, 2, len(routers)) + + require.Equal(t, 7, len(routers["network1ID"])) + require.NotNil(t, routers["network1ID"]["peer1"]) + require.NotNil(t, routers["network1ID"]["peer3"]) + require.NotNil(t, routers["network1ID"]["peer11"]) + require.NotNil(t, routers["network1ID"]["peer12"]) + require.NotNil(t, routers["network1ID"]["peer21"]) + require.NotNil(t, routers["network1ID"]["peer31"]) + require.NotNil(t, routers["network1ID"]["peer32"]) + + require.Equal(t, 2, len(routers["network2ID"])) + require.NotNil(t, routers["network2ID"]["peer2"]) + require.NotNil(t, routers["network2ID"]["peer41"]) + + require.Equal(t, 0, len(routers["network3ID"])) +} + +func Test_GetResourcePoliciesMap(t *testing.T) { + account := setupTestAccount() 
+ policies := account.GetResourcePoliciesMap() + require.Equal(t, 4, len(policies)) + require.Equal(t, 1, len(policies["resource1ID"])) + require.Equal(t, 1, len(policies["resource2ID"])) + require.Equal(t, 2, len(policies["resource3ID"])) + require.Equal(t, 1, len(policies["resource4ID"])) + require.Equal(t, 0, len(policies["resource5ID"])) +} + +func Test_AddNetworksRoutingPeersAddsMissingPeers(t *testing.T) { + account := setupTestAccount() + peer := &nbpeer.Peer{Key: "peer1Key", ID: "peer1"} + networkResourcesRoutes := []*route.Route{ + {Peer: "peer2Key", PeerID: "peer2"}, + {Peer: "peer3Key", PeerID: "peer3"}, + } + peersToConnect := []*nbpeer.Peer{ + {Key: "peer2Key", ID: "peer2"}, + } + expiredPeers := []*nbpeer.Peer{ + {Key: "peer4Key", ID: "peer4"}, + } + + result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{}) + require.Len(t, result, 2) + require.Equal(t, "peer2Key", result[0].Key) + require.Equal(t, "peer3Key", result[1].Key) +} + +func Test_AddNetworksRoutingPeersIgnoresExistingPeers(t *testing.T) { + account := setupTestAccount() + peer := &nbpeer.Peer{Key: "peer1Key", ID: "peer1"} + networkResourcesRoutes := []*route.Route{ + {Peer: "peer2Key"}, + } + peersToConnect := []*nbpeer.Peer{ + {Key: "peer2Key", ID: "peer2"}, + } + expiredPeers := []*nbpeer.Peer{} + + result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{}) + require.Len(t, result, 1) + require.Equal(t, "peer2Key", result[0].Key) +} + +func Test_AddNetworksRoutingPeersAddsExpiredPeers(t *testing.T) { + account := setupTestAccount() + peer := &nbpeer.Peer{Key: "peer1Key", ID: "peer1"} + networkResourcesRoutes := []*route.Route{ + {Peer: "peer2Key", PeerID: "peer2"}, + {Peer: "peer3Key", PeerID: "peer3"}, + } + peersToConnect := []*nbpeer.Peer{ + {Key: "peer2Key", ID: "peer2"}, + } + expiredPeers := []*nbpeer.Peer{ + {Key: "peer3Key", ID: "peer3"}, + } + + result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{}) + require.Len(t, result, 1) + require.Equal(t, "peer2Key", result[0].Key) +} + +func Test_AddNetworksRoutingPeersExcludesSelf(t *testing.T) { + account := setupTestAccount() + peer := &nbpeer.Peer{Key: "peer1Key", ID: "peer1"} + networkResourcesRoutes := []*route.Route{ + {Peer: "peer1Key", PeerID: "peer1"}, + {Peer: "peer2Key", PeerID: "peer2"}, + } + peersToConnect := []*nbpeer.Peer{} + expiredPeers := []*nbpeer.Peer{} + + result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, true, map[string]struct{}{}) + require.Len(t, result, 1) + require.Equal(t, "peer2Key", result[0].Key) +} + +func Test_AddNetworksRoutingPeersHandlesNoMissingPeers(t *testing.T) { + account := setupTestAccount() + peer := &nbpeer.Peer{Key: "peer1key", ID: "peer1"} + networkResourcesRoutes := []*route.Route{} + peersToConnect := []*nbpeer.Peer{} + expiredPeers := []*nbpeer.Peer{} + + result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{}) + require.Len(t, result, 0) +} + +const ( + accID = "accountID" + network1ID = "network1ID" + group1ID = "group1" + accNetResourcePeer1ID = "peer1" + accNetResourcePeer2ID = "peer2" + accNetResourceRouter1ID = "router1" + accNetResource1ID = "resource1ID" + accNetResourceRestrictPostureCheckID = "restrictPostureCheck" + accNetResourceRelaxedPostureCheckID = 
"relaxedPostureCheck" + accNetResourceLockedPostureCheckID = "lockedPostureCheck" + accNetResourceLinuxPostureCheckID = "linuxPostureCheck" +) + +var ( + accNetResourcePeer1IP = net.IP{192, 168, 1, 1} + accNetResourcePeer2IP = net.IP{192, 168, 1, 2} + accNetResourceRouter1IP = net.IP{192, 168, 1, 3} + accNetResourceValidPeers = map[string]struct{}{accNetResourcePeer1ID: {}, accNetResourcePeer2ID: {}} +) + +func getBasicAccountsWithResource() *Account { + return &Account{ + Id: accID, + Peers: map[string]*nbpeer.Peer{ + accNetResourcePeer1ID: { + ID: accNetResourcePeer1ID, + AccountID: accID, + Key: "peer1Key", + IP: accNetResourcePeer1IP, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + WtVersion: "0.35.1", + KernelVersion: "4.4.0", + }, + }, + accNetResourcePeer2ID: { + ID: accNetResourcePeer2ID, + AccountID: accID, + Key: "peer2Key", + IP: accNetResourcePeer2IP, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "windows", + WtVersion: "0.34.1", + KernelVersion: "4.4.0", + }, + }, + accNetResourceRouter1ID: { + ID: accNetResourceRouter1ID, + AccountID: accID, + Key: "router1Key", + IP: accNetResourceRouter1IP, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + WtVersion: "0.35.1", + KernelVersion: "4.4.0", + }, + }, + }, + Groups: map[string]*Group{ + group1ID: { + ID: group1ID, + Peers: []string{accNetResourcePeer1ID, accNetResourcePeer2ID}, + }, + }, + Networks: []*networkTypes.Network{ + { + ID: network1ID, + AccountID: accID, + Name: "network1", + }, + }, + NetworkRouters: []*routerTypes.NetworkRouter{ + { + ID: accNetResourceRouter1ID, + NetworkID: network1ID, + AccountID: accID, + Peer: accNetResourceRouter1ID, + PeerGroups: []string{}, + Masquerade: false, + Metric: 100, + Enabled: true, + }, + }, + NetworkResources: []*resourceTypes.NetworkResource{ + { + ID: accNetResource1ID, + AccountID: accID, + NetworkID: network1ID, + Address: "10.10.10.0/24", + Prefix: netip.MustParsePrefix("10.10.10.0/24"), + Type: resourceTypes.NetworkResourceType("subnet"), + Enabled: true, + }, + }, + Policies: []*Policy{ + { + ID: "policy1ID", + AccountID: accID, + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "rule1ID", + Enabled: true, + Sources: []string{group1ID}, + DestinationResource: Resource{ + ID: accNetResource1ID, + Type: "Host", + }, + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"80"}, + Action: PolicyTrafficActionAccept, + }, + }, + SourcePostureChecks: nil, + }, + }, + PostureChecks: []*posture.Checks{ + { + ID: accNetResourceRestrictPostureCheckID, + Name: accNetResourceRestrictPostureCheckID, + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.35.0", + }, + }, + }, + { + ID: accNetResourceRelaxedPostureCheckID, + Name: accNetResourceRelaxedPostureCheckID, + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.0.1", + }, + }, + }, + { + ID: accNetResourceLockedPostureCheckID, + Name: accNetResourceLockedPostureCheckID, + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "7.7.7", + }, + }, + }, + { + ID: accNetResourceLinuxPostureCheckID, + Name: accNetResourceLinuxPostureCheckID, + Checks: posture.ChecksDefinition{ + OSVersionCheck: &posture.OSVersionCheck{ + Linux: &posture.MinKernelVersionCheck{ + MinKernelVersion: "0.0.0"}, + }, + }, + }, + }, + } +} + +func Test_NetworksNetMapGenWithNoPostureChecks(t *testing.T) { + account := getBasicAccountsWithResource() + + // all peers should match the policy + + // validate for peer1 + isRouter, networkResourcesRoutes, 
sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate for peer2 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate routes for router1 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.True(t, isRouter, "should be router") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 2, "expected source peers don't match") + assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match") + assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match") + + // validate rules for router1 + rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) + assert.Len(t, rules, 1, "expected rules count don't match") + assert.Equal(t, uint16(80), rules[0].Port, "should have port 80") + assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp") + if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") { + t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String()) + } + if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") { + t.Errorf("%s should have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String()) + } +} + +func Test_NetworksNetMapGenWithPostureChecks(t *testing.T) { + account := getBasicAccountsWithResource() + + // should allow peer1 to match the policy + policy := account.Policies[0] + policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID} + + // validate for peer1 + isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate for peer2 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate routes for router1 + isRouter, networkResourcesRoutes, sourcePeers = 
account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.True(t, isRouter, "should be router") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 1, "expected source peers don't match") + assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match") + + // validate rules for router1 + rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) + assert.Len(t, rules, 1, "expected rules count don't match") + assert.Equal(t, uint16(80), rules[0].Port, "should have port 80") + assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp") + if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") { + t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String()) + } + if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") { + t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String()) + } +} + +func Test_NetworksNetMapGenWithNoMatchedPostureChecks(t *testing.T) { + account := getBasicAccountsWithResource() + + // should not match any peer + policy := account.Policies[0] + policy.SourcePostureChecks = []string{accNetResourceLockedPostureCheckID} + + // validate for peer1 + isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate for peer2 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate routes for router1 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.True(t, isRouter, "should be router") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate rules for router1 + rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) + assert.Len(t, rules, 0, "expected rules count don't match") +} + +func Test_NetworksNetMapGenWithTwoPoliciesAndPostureChecks(t *testing.T) { + account := getBasicAccountsWithResource() + + // should allow peer1 to match the policy + policy := account.Policies[0] + policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID} + + // should allow peer1 and peer2 to match the policy + newPolicy := &Policy{ + ID: "policy2ID", 
+ AccountID: accID, + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "policy2ID", + Enabled: true, + Sources: []string{group1ID}, + DestinationResource: Resource{ + ID: accNetResource1ID, + Type: "Host", + }, + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"22"}, + Action: PolicyTrafficActionAccept, + }, + }, + SourcePostureChecks: []string{accNetResourceRelaxedPostureCheckID}, + } + + account.Policies = append(account.Policies, newPolicy) + + // validate for peer1 + isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate for peer2 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate routes for router1 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.True(t, isRouter, "should be router") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 2, "expected source peers don't match") + assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match") + assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match") + + // validate rules for router1 + rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) + assert.Len(t, rules, 2, "expected rules count don't match") + assert.Equal(t, uint16(80), rules[0].Port, "should have port 80") + assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp") + if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") { + t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String()) + } + if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") { + t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String()) + } + + assert.Equal(t, uint16(22), rules[1].Port, "should have port 22") + assert.Equal(t, "tcp", rules[1].Protocol, "should have protocol tcp") + if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer1IP.String()+"/32") { + t.Errorf("%s should have source range of peer1 %s", rules[1].SourceRanges, accNetResourcePeer1IP.String()) + } + if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer2IP.String()+"/32") { + t.Errorf("%s should have source range of peer2 %s", rules[1].SourceRanges, accNetResourcePeer2IP.String()) + } +} + +func Test_NetworksNetMapGenWithTwoPostureChecks(t *testing.T) { + account := getBasicAccountsWithResource() + + // two posture checks should match only the peers that match both checks + policy := 
account.Policies[0] + policy.SourcePostureChecks = []string{accNetResourceRelaxedPostureCheckID, accNetResourceLinuxPostureCheckID} + + // validate for peer1 + isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate for peer2 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate routes for router1 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.True(t, isRouter, "should be router") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 1, "expected source peers don't match") + assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match") + + // validate rules for router1 + rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) + assert.Len(t, rules, 1, "expected rules count don't match") + assert.Equal(t, uint16(80), rules[0].Port, "should have port 80") + assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp") + if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") { + t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String()) + } + if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") { + t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String()) + } +} + +func Test_NetworksNetMapGenShouldExcludeOtherRouters(t *testing.T) { + account := getBasicAccountsWithResource() + + account.Peers["router2Id"] = &nbpeer.Peer{Key: "router2Key", ID: "router2Id", AccountID: accID, IP: net.IP{192, 168, 1, 4}} + account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{ + ID: "router2Id", + NetworkID: network1ID, + AccountID: accID, + Peer: "router2Id", + }) + + // validate routes for router1 + isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.True(t, isRouter, "should be router") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 2, "expected source peers don't match") +} diff --git a/management/server/types/dns_settings.go b/management/server/types/dns_settings.go new file mode 100644 index 000000000..1d33bb9fb --- /dev/null +++ b/management/server/types/dns_settings.go @@ -0,0 +1,16 @@ +package types + +// DNSSettings defines dns settings at the account level 
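As a small aside, a hedged sketch of the copy semantics of the type defined below; the helper name exampleDNSSettingsCopy and the group names are invented:

package types

// exampleDNSSettingsCopy is illustrative only: Copy allocates its own
// backing slice, so mutating the copy leaves the original untouched.
func exampleDNSSettingsCopy() (DNSSettings, DNSSettings) {
	original := DNSSettings{DisabledManagementGroups: []string{"group-a"}}
	clone := original.Copy()
	clone.DisabledManagementGroups[0] = "group-b" // original still holds "group-a"
	return original, clone
}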
+type DNSSettings struct { + // DisabledManagementGroups groups whose DNS management is disabled + DisabledManagementGroups []string `gorm:"serializer:json"` +} + +// Copy returns a copy of the DNS settings +func (d DNSSettings) Copy() DNSSettings { + settings := DNSSettings{ + DisabledManagementGroups: make([]string, len(d.DisabledManagementGroups)), + } + copy(settings.DisabledManagementGroups, d.DisabledManagementGroups) + return settings +} diff --git a/management/server/types/firewall_rule.go b/management/server/types/firewall_rule.go new file mode 100644 index 000000000..4e405152c --- /dev/null +++ b/management/server/types/firewall_rule.go @@ -0,0 +1,139 @@ +package types + +import ( + "context" + "fmt" + "strconv" + "strings" + + log "github.com/sirupsen/logrus" + + nbpeer "github.com/netbirdio/netbird/management/server/peer" + nbroute "github.com/netbirdio/netbird/route" +) + +const ( + FirewallRuleDirectionIN = 0 + FirewallRuleDirectionOUT = 1 +) + +// FirewallRule is a rule of the firewall. +type FirewallRule struct { + // PeerIP of the peer + PeerIP string + + // Direction of the traffic + Direction int + + // Action of the traffic + Action string + + // Protocol of the traffic + Protocol string + + // Port of the traffic + Port string +} + +// IsEqual checks if two firewall rules are equal. +func (r *FirewallRule) IsEqual(other *FirewallRule) bool { + return r.PeerIP == other.PeerIP && + r.Direction == other.Direction && + r.Action == other.Action && + r.Protocol == other.Protocol && + r.Port == other.Port +} + +// generateRouteFirewallRules generates a list of firewall rules for a given route. +func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule { + rulesExists := make(map[string]struct{}) + rules := make([]*RouteFirewallRule, 0) + + sourceRanges := make([]string, 0, len(groupPeers)) + for _, peer := range groupPeers { + if peer == nil { + continue + } + sourceRanges = append(sourceRanges, fmt.Sprintf(AllowedIPsFormat, peer.IP)) + } + + baseRule := RouteFirewallRule{ + SourceRanges: sourceRanges, + Action: string(rule.Action), + Destination: route.Network.String(), + Protocol: string(rule.Protocol), + Domains: route.Domains, + IsDynamic: route.IsDynamic(), + } + + // generate rules for the configured ports or port ranges + if len(rule.Ports) == 0 { + rules = append(rules, generateRulesWithPortRanges(baseRule, rule, rulesExists)...) + } else { + rules = append(rules, generateRulesWithPorts(ctx, baseRule, rule, rulesExists)...) + + } + + // TODO: generate IPv6 rules for dynamic routes + + return rules +} + +// generateRulesWithPortRanges generates rules from the rule's port ranges when no specific ports are provided.
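For illustration, a hedged sketch of the port-range expansion performed by the function below; the concrete values and the helper name examplePortRangeExpansion are invented:

package types

// examplePortRangeExpansion is illustrative only: with no explicit ports,
// each configured port range becomes its own rule derived from the base
// rule, and rulesExists prevents duplicates across repeated calls.
func examplePortRangeExpansion() []*RouteFirewallRule {
	base := RouteFirewallRule{
		SourceRanges: []string{"100.64.0.1/32"},
		Action:       string(PolicyTrafficActionAccept),
		Destination:  "10.10.10.0/24",
		Protocol:     string(PolicyRuleProtocolTCP),
	}
	rule := &PolicyRule{
		ID:         "rule-1",
		PortRanges: []RulePortRange{{Start: 8000, End: 8080}},
	}
	seen := make(map[string]struct{})
	return generateRulesWithPortRanges(base, rule, seen) // one rule covering ports 8000-8080
}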
+func generateRulesWithPortRanges(baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule { + rules := make([]*RouteFirewallRule, 0) + + ruleIDBase := generateRuleIDBase(rule, baseRule) + if len(rule.Ports) == 0 { + if len(rule.PortRanges) == 0 { + if _, ok := rulesExists[ruleIDBase]; !ok { + rulesExists[ruleIDBase] = struct{}{} + rules = append(rules, &baseRule) + } + } else { + for _, portRange := range rule.PortRanges { + ruleID := fmt.Sprintf("%s%d-%d", ruleIDBase, portRange.Start, portRange.End) + if _, ok := rulesExists[ruleID]; !ok { + rulesExists[ruleID] = struct{}{} + pr := baseRule + pr.PortRange = portRange + rules = append(rules, &pr) + } + } + } + return rules + } + + return rules +} + +// generateRulesWithPorts generates rules when specific ports are provided. +func generateRulesWithPorts(ctx context.Context, baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule { + rules := make([]*RouteFirewallRule, 0) + ruleIDBase := generateRuleIDBase(rule, baseRule) + + for _, port := range rule.Ports { + ruleID := ruleIDBase + port + if _, ok := rulesExists[ruleID]; ok { + continue + } + rulesExists[ruleID] = struct{}{} + + pr := baseRule + p, err := strconv.ParseUint(port, 10, 16) + if err != nil { + log.WithContext(ctx).Errorf("failed to parse port %s for rule: %s", port, rule.ID) + continue + } + + pr.Port = uint16(p) + rules = append(rules, &pr) + } + + return rules +} + +// generateRuleIDBase generates the base rule ID for checking duplicates. +func generateRuleIDBase(rule *PolicyRule, baseRule RouteFirewallRule) string { + return rule.ID + strings.Join(baseRule.SourceRanges, ",") + strconv.Itoa(FirewallRuleDirectionIN) + baseRule.Protocol + baseRule.Action +} diff --git a/management/server/group/group.go b/management/server/types/group.go similarity index 58% rename from management/server/group/group.go rename to management/server/types/group.go index 24c60d3ce..00a28fa77 100644 --- a/management/server/group/group.go +++ b/management/server/types/group.go @@ -1,6 +1,9 @@ -package group +package types -import "github.com/netbirdio/netbird/management/server/integration_reference" +import ( + "github.com/netbirdio/netbird/management/server/integration_reference" + "github.com/netbirdio/netbird/management/server/networks/resources/types" +) const ( GroupIssuedAPI = "api" @@ -25,6 +28,9 @@ type Group struct { // Peers list of the group Peers []string `gorm:"serializer:json"` + // Resources contains a list of resources in that group + Resources []Resource `gorm:"serializer:json"` + IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` } @@ -33,15 +39,21 @@ func (g *Group) EventMeta() map[string]any { return map[string]any{"name": g.Name} } +func (g *Group) EventMetaResource(resource *types.NetworkResource) map[string]any { + return map[string]any{"name": g.Name, "id": g.ID, "resource_name": resource.Name, "resource_id": resource.ID, "resource_type": resource.Type} +} + func (g *Group) Copy() *Group { group := &Group{ ID: g.ID, Name: g.Name, Issued: g.Issued, Peers: make([]string, len(g.Peers)), + Resources: make([]Resource, len(g.Resources)), IntegrationReference: g.IntegrationReference, } copy(group.Peers, g.Peers) + copy(group.Resources, g.Resources) return group } @@ -81,3 +93,31 @@ func (g *Group) RemovePeer(peerID string) bool { } return false } + +// AddResource adds resource to Resources if not present, returning true if added. 
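For illustration, a hedged sketch of the resource helpers added to Group; the group and resource values and the helper name exampleGroupResources are invented:

package types

// exampleGroupResources is illustrative only: AddResource is idempotent and
// RemoveResource reports whether an entry was actually removed.
func exampleGroupResources() bool {
	g := &Group{ID: "g1", Name: "servers"}
	res := Resource{ID: "resource-a", Type: "host"}

	g.AddResource(res) // true, resource appended
	g.AddResource(res) // false, already present

	return g.RemoveResource(res) // true, the group holds no resources afterwards
}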
+func (g *Group) AddResource(resource Resource) bool { + for _, item := range g.Resources { + if item == resource { + return false + } + } + + g.Resources = append(g.Resources, resource) + return true +} + +// RemoveResource removes resource from Resources if present, returning true if removed. +func (g *Group) RemoveResource(resource Resource) bool { + for i, item := range g.Resources { + if item == resource { + g.Resources = append(g.Resources[:i], g.Resources[i+1:]...) + return true + } + } + return false +} + +// HasResources checks if the group has any resources. +func (g *Group) HasResources() bool { + return len(g.Resources) > 0 +} diff --git a/management/server/group/group_test.go b/management/server/types/group_test.go similarity index 99% rename from management/server/group/group_test.go rename to management/server/types/group_test.go index cb002f8d9..12107c603 100644 --- a/management/server/group/group_test.go +++ b/management/server/types/group_test.go @@ -1,4 +1,4 @@ -package group +package types import ( "testing" diff --git a/management/server/network.go b/management/server/types/network.go similarity index 96% rename from management/server/network.go rename to management/server/types/network.go index a5b188b46..d1fccd149 100644 --- a/management/server/network.go +++ b/management/server/types/network.go @@ -1,4 +1,4 @@ -package server +package types import ( "math/rand" @@ -43,7 +43,7 @@ type Network struct { // Used to synchronize state to the client apps. Serial uint64 - mu sync.Mutex `json:"-" gorm:"-"` + Mu sync.Mutex `json:"-" gorm:"-"` } // NewNetwork creates a new Network initializing it with a Serial=0 @@ -66,15 +66,15 @@ func NewNetwork() *Network { // IncSerial increments Serial by 1 reflecting that the network state has been changed func (n *Network) IncSerial() { - n.mu.Lock() - defer n.mu.Unlock() + n.Mu.Lock() + defer n.Mu.Unlock() n.Serial++ } // CurrentSerial returns the Network.Serial of the network (latest state id) func (n *Network) CurrentSerial() uint64 { - n.mu.Lock() - defer n.mu.Unlock() + n.Mu.Lock() + defer n.Mu.Unlock() return n.Serial } diff --git a/management/server/network_test.go b/management/server/types/network_test.go similarity index 98% rename from management/server/network_test.go rename to management/server/types/network_test.go index b067c4991..d0b0894d4 100644 --- a/management/server/network_test.go +++ b/management/server/types/network_test.go @@ -1,4 +1,4 @@ -package server +package types import ( "net" diff --git a/management/server/personal_access_token.go b/management/server/types/personal_access_token.go similarity index 82% rename from management/server/personal_access_token.go rename to management/server/types/personal_access_token.go index e4b19da76..0aa6b152b 100644 --- a/management/server/personal_access_token.go +++ b/management/server/types/personal_access_token.go @@ -1,4 +1,4 @@ -package server +package types import ( "crypto/sha256" @@ -8,6 +8,7 @@ import ( "time" b "github.com/hashicorp/go-secure-stdlib/base62" + "github.com/netbirdio/netbird/management/server/util" "github.com/rs/xid" "github.com/netbirdio/netbird/base62" @@ -31,11 +32,11 @@ type PersonalAccessToken struct { UserID string `gorm:"index"` Name string HashedToken string - ExpirationDate time.Time + ExpirationDate *time.Time // scope could be added in future CreatedBy string CreatedAt time.Time - LastUsed time.Time + LastUsed *time.Time } func (t *PersonalAccessToken) Copy() *PersonalAccessToken { @@ -50,6 +51,22 @@ func (t *PersonalAccessToken) Copy() 
*PersonalAccessToken { } } +// GetExpirationDate returns the expiration time of the token. +func (t *PersonalAccessToken) GetExpirationDate() time.Time { + if t.ExpirationDate != nil { + return *t.ExpirationDate + } + return time.Time{} +} + +// GetLastUsed returns the last time the token was used. +func (t *PersonalAccessToken) GetLastUsed() time.Time { + if t.LastUsed != nil { + return *t.LastUsed + } + return time.Time{} +} + // PersonalAccessTokenGenerated holds the new PersonalAccessToken and the plain text version of it type PersonalAccessTokenGenerated struct { PlainToken string @@ -70,10 +87,9 @@ func CreateNewPAT(name string, expirationInDays int, targetID, createdBy string) UserID: targetID, Name: name, HashedToken: hashedToken, - ExpirationDate: currentTime.AddDate(0, 0, expirationInDays), + ExpirationDate: util.ToPtr(currentTime.AddDate(0, 0, expirationInDays)), CreatedBy: createdBy, CreatedAt: currentTime, - LastUsed: time.Time{}, }, PlainToken: plainToken, }, nil diff --git a/management/server/personal_access_token_test.go b/management/server/types/personal_access_token_test.go similarity index 98% rename from management/server/personal_access_token_test.go rename to management/server/types/personal_access_token_test.go index 311ffd9cf..ac3377151 100644 --- a/management/server/personal_access_token_test.go +++ b/management/server/types/personal_access_token_test.go @@ -1,4 +1,4 @@ -package server +package types import ( "crypto/sha256" diff --git a/management/server/types/policy.go b/management/server/types/policy.go new file mode 100644 index 000000000..17964ed1f --- /dev/null +++ b/management/server/types/policy.go @@ -0,0 +1,136 @@ +package types + +const ( + // PolicyTrafficActionAccept indicates that the traffic is accepted + PolicyTrafficActionAccept = PolicyTrafficActionType("accept") + // PolicyTrafficActionDrop indicates that the traffic is dropped + PolicyTrafficActionDrop = PolicyTrafficActionType("drop") +) + +const ( + // PolicyRuleProtocolALL type of traffic + PolicyRuleProtocolALL = PolicyRuleProtocolType("all") + // PolicyRuleProtocolTCP type of traffic + PolicyRuleProtocolTCP = PolicyRuleProtocolType("tcp") + // PolicyRuleProtocolUDP type of traffic + PolicyRuleProtocolUDP = PolicyRuleProtocolType("udp") + // PolicyRuleProtocolICMP type of traffic + PolicyRuleProtocolICMP = PolicyRuleProtocolType("icmp") +) + +const ( + // PolicyRuleFlowDirect allows traffic from source to destination + PolicyRuleFlowDirect = PolicyRuleDirection("direct") + // PolicyRuleFlowBidirect allows traffic to both directions + PolicyRuleFlowBidirect = PolicyRuleDirection("bidirect") +) + +const ( + // DefaultRuleName is a name for the Default rule that is created for every account + DefaultRuleName = "Default" + // DefaultRuleDescription is a description for the Default rule that is created for every account + DefaultRuleDescription = "This is a default rule that allows connections between all the resources" + // DefaultPolicyName is a name for the Default policy that is created for every account + DefaultPolicyName = "Default" + // DefaultPolicyDescription is a description for the Default policy that is created for every account + DefaultPolicyDescription = "This is a default policy that allows connections between all the resources" +) + +// PolicyUpdateOperation operation object with type and values to be applied +type PolicyUpdateOperation struct { + Type PolicyUpdateOperationType + Values []string +} + +// Policy of the Rego query +type Policy struct { + // ID of the policy' + ID 
string `gorm:"primaryKey"` + + // AccountID is a reference to Account that this object belongs + AccountID string `json:"-" gorm:"index"` + + // Name of the Policy + Name string + + // Description of the policy visible in the UI + Description string + + // Enabled status of the policy + Enabled bool + + // Rules of the policy + Rules []*PolicyRule `gorm:"foreignKey:PolicyID;references:id;constraint:OnDelete:CASCADE;"` + + // SourcePostureChecks are ID references to Posture checks for policy source groups + SourcePostureChecks []string `gorm:"serializer:json"` +} + +// Copy returns a copy of the policy. +func (p *Policy) Copy() *Policy { + c := &Policy{ + ID: p.ID, + AccountID: p.AccountID, + Name: p.Name, + Description: p.Description, + Enabled: p.Enabled, + Rules: make([]*PolicyRule, len(p.Rules)), + SourcePostureChecks: make([]string, len(p.SourcePostureChecks)), + } + for i, r := range p.Rules { + c.Rules[i] = r.Copy() + } + copy(c.SourcePostureChecks, p.SourcePostureChecks) + return c +} + +// EventMeta returns activity event meta related to this policy +func (p *Policy) EventMeta() map[string]any { + return map[string]any{"name": p.Name} +} + +// UpgradeAndFix different version of policies to latest version +func (p *Policy) UpgradeAndFix() { + for _, r := range p.Rules { + // start migrate from version v0.20.3 + if r.Protocol == "" { + r.Protocol = PolicyRuleProtocolALL + } + if r.Protocol == PolicyRuleProtocolALL && !r.Bidirectional { + r.Bidirectional = true + } + // -- v0.20.4 + } +} + +// RuleGroups returns a list of all groups referenced in the policy's rules, +// including sources and destinations. +func (p *Policy) RuleGroups() []string { + groups := make([]string, 0) + for _, rule := range p.Rules { + groups = append(groups, rule.Sources...) + groups = append(groups, rule.Destinations...) + } + + return groups +} + +// SourceGroups returns a slice of all unique source groups referenced in the policy's rules. +func (p *Policy) SourceGroups() []string { + if len(p.Rules) == 1 { + return p.Rules[0].Sources + } + groups := make(map[string]struct{}, len(p.Rules)) + for _, rule := range p.Rules { + for _, source := range rule.Sources { + groups[source] = struct{}{} + } + } + + groupIDs := make([]string, 0, len(groups)) + for groupID := range groups { + groupIDs = append(groupIDs, groupID) + } + + return groupIDs +} diff --git a/management/server/types/policyrule.go b/management/server/types/policyrule.go new file mode 100644 index 000000000..bd9a99292 --- /dev/null +++ b/management/server/types/policyrule.go @@ -0,0 +1,87 @@ +package types + +// PolicyUpdateOperationType operation type +type PolicyUpdateOperationType int + +// PolicyTrafficActionType action type for the firewall +type PolicyTrafficActionType string + +// PolicyRuleProtocolType type of traffic +type PolicyRuleProtocolType string + +// PolicyRuleDirection direction of traffic +type PolicyRuleDirection string + +// RulePortRange represents a range of ports for a firewall rule. 
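Stepping back to the Policy helpers above for a moment, a hedged sketch of how RuleGroups and SourceGroups differ; the helper name examplePolicyGroupHelpers and the group IDs are invented:

package types

// examplePolicyGroupHelpers is illustrative only: RuleGroups keeps
// duplicates across sources and destinations, while SourceGroups returns
// the de-duplicated source group IDs.
func examplePolicyGroupHelpers() ([]string, []string) {
	p := &Policy{
		Rules: []*PolicyRule{
			{Sources: []string{"g1"}, Destinations: []string{"g2"}},
			{Sources: []string{"g1"}, Destinations: []string{"g3"}},
		},
	}
	return p.RuleGroups(), p.SourceGroups() // ["g1" "g2" "g1" "g3"], ["g1"]
}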
+type RulePortRange struct { + Start uint16 + End uint16 +} + +// PolicyRule is the metadata of the policy +type PolicyRule struct { + // ID of the policy rule + ID string `gorm:"primaryKey"` + + // PolicyID is a reference to Policy that this object belongs + PolicyID string `json:"-" gorm:"index"` + + // Name of the rule visible in the UI + Name string + + // Description of the rule visible in the UI + Description string + + // Enabled status of rule in the system + Enabled bool + + // Action policy accept or drops packets + Action PolicyTrafficActionType + + // Destinations policy destination groups + Destinations []string `gorm:"serializer:json"` + + // DestinationResource policy destination resource that the rule is applied to + DestinationResource Resource `gorm:"serializer:json"` + + // Sources policy source groups + Sources []string `gorm:"serializer:json"` + + // SourceResource policy source resource that the rule is applied to + SourceResource Resource `gorm:"serializer:json"` + + // Bidirectional define if the rule is applicable in both directions, sources, and destinations + Bidirectional bool + + // Protocol type of the traffic + Protocol PolicyRuleProtocolType + + // Ports or it ranges list + Ports []string `gorm:"serializer:json"` + + // PortRanges a list of port ranges. + PortRanges []RulePortRange `gorm:"serializer:json"` +} + +// Copy returns a copy of a policy rule +func (pm *PolicyRule) Copy() *PolicyRule { + rule := &PolicyRule{ + ID: pm.ID, + PolicyID: pm.PolicyID, + Name: pm.Name, + Description: pm.Description, + Enabled: pm.Enabled, + Action: pm.Action, + Destinations: make([]string, len(pm.Destinations)), + Sources: make([]string, len(pm.Sources)), + Bidirectional: pm.Bidirectional, + Protocol: pm.Protocol, + Ports: make([]string, len(pm.Ports)), + PortRanges: make([]RulePortRange, len(pm.PortRanges)), + } + copy(rule.Destinations, pm.Destinations) + copy(rule.Sources, pm.Sources) + copy(rule.Ports, pm.Ports) + copy(rule.PortRanges, pm.PortRanges) + return rule +} diff --git a/management/server/types/resource.go b/management/server/types/resource.go new file mode 100644 index 000000000..820872f20 --- /dev/null +++ b/management/server/types/resource.go @@ -0,0 +1,30 @@ +package types + +import ( + "github.com/netbirdio/netbird/management/server/http/api" +) + +type Resource struct { + ID string + Type string +} + +func (r *Resource) ToAPIResponse() *api.Resource { + if r.ID == "" && r.Type == "" { + return nil + } + + return &api.Resource{ + Id: r.ID, + Type: api.ResourceType(r.Type), + } +} + +func (r *Resource) FromAPIRequest(req *api.Resource) { + if req == nil { + return + } + + r.ID = req.Id + r.Type = string(req.Type) +} diff --git a/management/server/types/route_firewall_rule.go b/management/server/types/route_firewall_rule.go new file mode 100644 index 000000000..64708d68a --- /dev/null +++ b/management/server/types/route_firewall_rule.go @@ -0,0 +1,32 @@ +package types + +import ( + "github.com/netbirdio/netbird/management/domain" +) + +// RouteFirewallRule a firewall rule applicable for a routed network. +type RouteFirewallRule struct { + // SourceRanges IP ranges of the routing peers. 
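
// Illustrative sketch (not part of the original change): the nil-safe round trip
// between the Resource type above and the generated api.Resource. The "host" type
// value is hypothetical.
func exampleResourceRoundTrip() {
	var empty Resource
	_ = empty.ToAPIResponse() // nil, because both ID and Type are empty

	r := Resource{ID: "res-1", Type: "host"}
	apiResource := r.ToAPIResponse() // &api.Resource{Id: "res-1", Type: "host"}

	var back Resource
	back.FromAPIRequest(apiResource) // back now mirrors r
	back.FromAPIRequest(nil)         // nil requests are ignored
}
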
+	SourceRanges []string
+
+	// Action of the traffic when the rule is applicable
+	Action string
+
+	// Destination a network prefix for the routed traffic
+	Destination string
+
+	// Protocol of the traffic
+	Protocol string
+
+	// Port of the traffic
+	Port uint16
+
+	// PortRange represents the range of ports for a firewall rule
+	PortRange RulePortRange
+
+	// Domains list of network domains for the routed traffic
+	Domains domain.List
+
+	// isDynamic indicates whether the rule is for DNS routing
+	IsDynamic bool
+}
diff --git a/management/server/types/settings.go b/management/server/types/settings.go
new file mode 100644
index 000000000..0ce5a6133
--- /dev/null
+++ b/management/server/types/settings.go
@@ -0,0 +1,68 @@
+package types
+
+import (
+	"time"
+
+	"github.com/netbirdio/netbird/management/server/account"
+)
+
+// Settings represents Account settings structure that can be modified via API and Dashboard
+type Settings struct {
+	// PeerLoginExpirationEnabled globally enables or disables peer login expiration
+	PeerLoginExpirationEnabled bool
+
+	// PeerLoginExpiration is a setting that indicates when peer login expires.
+	// Applies to all peers that have Peer.LoginExpirationEnabled set to true.
+	PeerLoginExpiration time.Duration
+
+	// PeerInactivityExpirationEnabled globally enables or disables peer inactivity expiration
+	PeerInactivityExpirationEnabled bool
+
+	// PeerInactivityExpiration is a setting that indicates when peer inactivity expires.
+	// Applies to all peers that have Peer.PeerInactivityExpirationEnabled set to true.
+	PeerInactivityExpiration time.Duration
+
+	// RegularUsersViewBlocked allows to block regular users from viewing even their own peers and some UI elements
+	RegularUsersViewBlocked bool
+
+	// GroupsPropagationEnabled allows to propagate auto groups from the user to the peer
+	GroupsPropagationEnabled bool
+
+	// JWTGroupsEnabled allows extract groups from JWT claim, which name defined in the JWTGroupsClaimName
+	// and add it to account groups.
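
// Illustrative sketch (not part of the original change): Settings.Copy, defined just
// below, yields a value whose scalar fields can be adjusted without touching the
// original settings object. The durations are hypothetical.
func exampleSettingsCopy() {
	s := &Settings{
		PeerLoginExpirationEnabled: true,
		PeerLoginExpiration:        24 * time.Hour,
	}

	c := s.Copy()
	c.PeerLoginExpiration = time.Hour
	// s.PeerLoginExpiration is still 24 hours; only the copy was modified.
}
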
+ JWTGroupsEnabled bool + + // JWTGroupsClaimName from which we extract groups name to add it to account groups + JWTGroupsClaimName string + + // JWTAllowGroups list of groups to which users are allowed access + JWTAllowGroups []string `gorm:"serializer:json"` + + // RoutingPeerDNSResolutionEnabled enabled the DNS resolution on the routing peers + RoutingPeerDNSResolutionEnabled bool + + // Extra is a dictionary of Account settings + Extra *account.ExtraSettings `gorm:"embedded;embeddedPrefix:extra_"` +} + +// Copy copies the Settings struct +func (s *Settings) Copy() *Settings { + settings := &Settings{ + PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled, + PeerLoginExpiration: s.PeerLoginExpiration, + JWTGroupsEnabled: s.JWTGroupsEnabled, + JWTGroupsClaimName: s.JWTGroupsClaimName, + GroupsPropagationEnabled: s.GroupsPropagationEnabled, + JWTAllowGroups: s.JWTAllowGroups, + RegularUsersViewBlocked: s.RegularUsersViewBlocked, + + PeerInactivityExpirationEnabled: s.PeerInactivityExpirationEnabled, + PeerInactivityExpiration: s.PeerInactivityExpiration, + + RoutingPeerDNSResolutionEnabled: s.RoutingPeerDNSResolutionEnabled, + } + if s.Extra != nil { + settings.Extra = s.Extra.Copy() + } + return settings +} diff --git a/management/server/types/setupkey.go b/management/server/types/setupkey.go new file mode 100644 index 000000000..2cd835289 --- /dev/null +++ b/management/server/types/setupkey.go @@ -0,0 +1,198 @@ +package types + +import ( + "crypto/sha256" + b64 "encoding/base64" + "hash/fnv" + "strconv" + "strings" + "time" + "unicode/utf8" + + "github.com/google/uuid" + "github.com/netbirdio/netbird/management/server/util" +) + +const ( + // SetupKeyReusable is a multi-use key (can be used for multiple machines) + SetupKeyReusable SetupKeyType = "reusable" + // SetupKeyOneOff is a single use key (can be used only once) + SetupKeyOneOff SetupKeyType = "one-off" + // DefaultSetupKeyDuration = 1 month + DefaultSetupKeyDuration = 24 * 30 * time.Hour + // DefaultSetupKeyName is a default name of the default setup key + DefaultSetupKeyName = "Default key" + // SetupKeyUnlimitedUsage indicates an unlimited usage of a setup key + SetupKeyUnlimitedUsage = 0 +) + +// SetupKeyType is the type of setup key +type SetupKeyType string + +// SetupKey represents a pre-authorized key used to register machines (peers) +type SetupKey struct { + Id string + // AccountID is a reference to Account that this object belongs + AccountID string `json:"-" gorm:"index"` + Key string + KeySecret string + Name string + Type SetupKeyType + CreatedAt time.Time + ExpiresAt *time.Time + UpdatedAt time.Time `gorm:"autoUpdateTime:false"` + // Revoked indicates whether the key was revoked or not (we don't remove them for tracking purposes) + Revoked bool + // UsedTimes indicates how many times the key was used + UsedTimes int + // LastUsed last time the key was used for peer registration + LastUsed *time.Time + // AutoGroups is a list of Group IDs that are auto assigned to a Peer when it uses this key to register + AutoGroups []string `gorm:"serializer:json"` + // UsageLimit indicates the number of times this key can be used to enroll a machine. + // The value of 0 indicates the unlimited usage. 
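
// Illustrative sketch (not part of the original change): how Type and UsageLimit
// interact in IsOverUsed, defined further down in this file. The usage counts are
// hypothetical.
func exampleSetupKeyUsage() {
	oneOff := &SetupKey{Type: SetupKeyOneOff, UsedTimes: 1}
	reusable := &SetupKey{Type: SetupKeyReusable, UsageLimit: SetupKeyUnlimitedUsage, UsedTimes: 100}

	_ = oneOff.IsOverUsed()   // true: one-off keys have an implicit limit of one use
	_ = reusable.IsOverUsed() // false: a zero usage limit means unlimited
}
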
+ UsageLimit int + // Ephemeral indicate if the peers will be ephemeral or not + Ephemeral bool +} + +// Copy copies SetupKey to a new object +func (key *SetupKey) Copy() *SetupKey { + autoGroups := make([]string, len(key.AutoGroups)) + copy(autoGroups, key.AutoGroups) + if key.UpdatedAt.IsZero() { + key.UpdatedAt = key.CreatedAt + } + return &SetupKey{ + Id: key.Id, + AccountID: key.AccountID, + Key: key.Key, + KeySecret: key.KeySecret, + Name: key.Name, + Type: key.Type, + CreatedAt: key.CreatedAt, + ExpiresAt: key.ExpiresAt, + UpdatedAt: key.UpdatedAt, + Revoked: key.Revoked, + UsedTimes: key.UsedTimes, + LastUsed: key.LastUsed, + AutoGroups: autoGroups, + UsageLimit: key.UsageLimit, + Ephemeral: key.Ephemeral, + } +} + +// EventMeta returns activity event meta related to the setup key +func (key *SetupKey) EventMeta() map[string]any { + return map[string]any{"name": key.Name, "type": key.Type, "key": key.KeySecret} +} + +// GetLastUsed returns the last used time of the setup key. +func (key *SetupKey) GetLastUsed() time.Time { + if key.LastUsed != nil { + return *key.LastUsed + } + return time.Time{} +} + +// GetExpiresAt returns the expiration time of the setup key. +func (key *SetupKey) GetExpiresAt() time.Time { + if key.ExpiresAt != nil { + return *key.ExpiresAt + } + return time.Time{} +} + +// HiddenKey returns the Key value hidden with "*" and a 5 character prefix. +// E.g., "831F6*******************************" +func HiddenKey(key string, length int) string { + prefix := key[0:5] + if length > utf8.RuneCountInString(key) { + length = utf8.RuneCountInString(key) - len(prefix) + } + return prefix + strings.Repeat("*", length) +} + +// IncrementUsage makes a copy of a key, increments the UsedTimes by 1 and sets LastUsed to now +func (key *SetupKey) IncrementUsage() *SetupKey { + c := key.Copy() + c.UsedTimes++ + c.LastUsed = util.ToPtr(time.Now().UTC()) + return c +} + +// IsValid is true if the key was not revoked, is not expired and used not more than it was supposed to +func (key *SetupKey) IsValid() bool { + return !key.IsRevoked() && !key.IsExpired() && !key.IsOverUsed() +} + +// IsRevoked if key was revoked +func (key *SetupKey) IsRevoked() bool { + return key.Revoked +} + +// IsExpired if key was expired +func (key *SetupKey) IsExpired() bool { + if key.GetExpiresAt().IsZero() { + return false + } + return time.Now().After(key.GetExpiresAt()) +} + +// IsOverUsed if the key was used too many times. SetupKey.UsageLimit == 0 indicates the unlimited usage. 
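
// Illustrative sketch (not part of the original change): with ExpiresAt now a
// pointer, a key without an expiration timestamp never expires, while a past
// timestamp invalidates the key. The durations are hypothetical.
func exampleSetupKeyExpiry() {
	neverExpires := &SetupKey{}                                             // ExpiresAt is nil
	expired := &SetupKey{ExpiresAt: util.ToPtr(time.Now().Add(-time.Hour))} // already in the past

	_ = neverExpires.IsExpired() // false: a zero expiration time is treated as "no expiry"
	_ = expired.IsExpired()      // true
	_ = expired.IsValid()        // false: an expired key is no longer valid
}
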
+func (key *SetupKey) IsOverUsed() bool { + limit := key.UsageLimit + if key.Type == SetupKeyOneOff { + limit = 1 + } + return limit > 0 && key.UsedTimes >= limit +} + +// GenerateSetupKey generates a new setup key +func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration, autoGroups []string, + usageLimit int, ephemeral bool) (*SetupKey, string) { + key := strings.ToUpper(uuid.New().String()) + limit := usageLimit + if t == SetupKeyOneOff { + limit = 1 + } + + var expiresAt *time.Time + if validFor != 0 { + expiresAt = util.ToPtr(time.Now().UTC().Add(validFor)) + } + + hashedKey := sha256.Sum256([]byte(key)) + encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) + + return &SetupKey{ + Id: strconv.Itoa(int(Hash(key))), + Key: encodedHashedKey, + KeySecret: HiddenKey(key, 4), + Name: name, + Type: t, + CreatedAt: time.Now().UTC(), + ExpiresAt: expiresAt, + UpdatedAt: time.Now().UTC(), + Revoked: false, + UsedTimes: 0, + AutoGroups: autoGroups, + UsageLimit: limit, + Ephemeral: ephemeral, + }, key +} + +// GenerateDefaultSetupKey generates a default reusable setup key with an unlimited usage and 30 days expiration +func GenerateDefaultSetupKey() (*SetupKey, string) { + return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration, []string{}, + SetupKeyUnlimitedUsage, false) +} + +func Hash(s string) uint32 { + h := fnv.New32a() + _, err := h.Write([]byte(s)) + if err != nil { + panic(err) + } + return h.Sum32() +} diff --git a/management/server/types/user.go b/management/server/types/user.go new file mode 100644 index 000000000..348fbfb22 --- /dev/null +++ b/management/server/types/user.go @@ -0,0 +1,239 @@ +package types + +import ( + "fmt" + "strings" + "time" + + "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/integration_reference" +) + +const ( + UserRoleOwner UserRole = "owner" + UserRoleAdmin UserRole = "admin" + UserRoleUser UserRole = "user" + UserRoleUnknown UserRole = "unknown" + UserRoleBillingAdmin UserRole = "billing_admin" + + UserStatusActive UserStatus = "active" + UserStatusDisabled UserStatus = "disabled" + UserStatusInvited UserStatus = "invited" + + UserIssuedAPI = "api" + UserIssuedIntegration = "integration" +) + +// StrRoleToUserRole returns UserRole for a given strRole or UserRoleUnknown if the specified role is unknown +func StrRoleToUserRole(strRole string) UserRole { + switch strings.ToLower(strRole) { + case "owner": + return UserRoleOwner + case "admin": + return UserRoleAdmin + case "user": + return UserRoleUser + case "billing_admin": + return UserRoleBillingAdmin + default: + return UserRoleUnknown + } +} + +// UserStatus is the status of a User +type UserStatus string + +// UserRole is the role of a User +type UserRole string + +type UserInfo struct { + ID string `json:"id"` + Email string `json:"email"` + Name string `json:"name"` + Role string `json:"role"` + AutoGroups []string `json:"auto_groups"` + Status string `json:"-"` + IsServiceUser bool `json:"is_service_user"` + IsBlocked bool `json:"is_blocked"` + NonDeletable bool `json:"non_deletable"` + LastLogin time.Time `json:"last_login"` + Issued string `json:"issued"` + IntegrationReference integration_reference.IntegrationReference `json:"-"` + Permissions UserPermissions `json:"permissions"` +} + +type UserPermissions struct { + DashboardView string `json:"dashboard_view"` +} + +// User represents a user of the system +type User struct { + Id string `gorm:"primaryKey"` + // AccountID is a 
reference to Account that this object belongs + AccountID string `json:"-" gorm:"index"` + Role UserRole + IsServiceUser bool + // NonDeletable indicates whether the service user can be deleted + NonDeletable bool + // ServiceUserName is only set if IsServiceUser is true + ServiceUserName string + // AutoGroups is a list of Group IDs to auto-assign to peers registered by this user + AutoGroups []string `gorm:"serializer:json"` + PATs map[string]*PersonalAccessToken `gorm:"-"` + PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id"` + // Blocked indicates whether the user is blocked. Blocked users can't use the system. + Blocked bool + // LastLogin is the last time the user logged in to IdP + LastLogin *time.Time + // CreatedAt records the time the user was created + CreatedAt time.Time + + // Issued of the user + Issued string `gorm:"default:api"` + + IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` +} + +// IsBlocked returns true if the user is blocked, false otherwise +func (u *User) IsBlocked() bool { + return u.Blocked +} + +func (u *User) LastDashboardLoginChanged(lastLogin time.Time) bool { + return lastLogin.After(u.GetLastLogin()) && !u.GetLastLogin().IsZero() +} + +// GetLastLogin returns the last login time of the user. +func (u *User) GetLastLogin() time.Time { + if u.LastLogin != nil { + return *u.LastLogin + } + return time.Time{} +} + +// HasAdminPower returns true if the user has admin or owner roles, false otherwise +func (u *User) HasAdminPower() bool { + return u.Role == UserRoleAdmin || u.Role == UserRoleOwner +} + +// IsAdminOrServiceUser checks if the user has admin power or is a service user. +func (u *User) IsAdminOrServiceUser() bool { + return u.HasAdminPower() || u.IsServiceUser +} + +// IsRegularUser checks if the user is a regular user. +func (u *User) IsRegularUser() bool { + return !u.HasAdminPower() && !u.IsServiceUser +} + +// ToUserInfo converts a User object to a UserInfo object. 
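
// Illustrative sketch (not part of the original change): how the role helpers above
// and the New*User constructors at the bottom of this file classify users. The IDs
// and the service user name are hypothetical.
func exampleUserRoles() {
	owner := NewOwnerUser("user-owner")
	regular := NewRegularUser("user-regular")
	service := NewUser("svc-1", UserRoleAdmin, true, false, "ci-bot", nil, UserIssuedAPI)

	_ = owner.HasAdminPower()          // true: owners count as having admin power
	_ = regular.IsRegularUser()        // true: neither admin power nor a service user
	_ = service.IsAdminOrServiceUser() // true
	_ = regular.GetLastLogin()         // zero time: the LastLogin pointer is still nil
}
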
+func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) { + autoGroups := u.AutoGroups + if autoGroups == nil { + autoGroups = []string{} + } + + dashboardViewPermissions := "full" + if !u.HasAdminPower() { + dashboardViewPermissions = "limited" + if settings.RegularUsersViewBlocked { + dashboardViewPermissions = "blocked" + } + } + + if userData == nil { + return &UserInfo{ + ID: u.Id, + Email: "", + Name: u.ServiceUserName, + Role: string(u.Role), + AutoGroups: u.AutoGroups, + Status: string(UserStatusActive), + IsServiceUser: u.IsServiceUser, + IsBlocked: u.Blocked, + LastLogin: u.GetLastLogin(), + Issued: u.Issued, + Permissions: UserPermissions{ + DashboardView: dashboardViewPermissions, + }, + }, nil + } + if userData.ID != u.Id { + return nil, fmt.Errorf("wrong UserData provided for user %s", u.Id) + } + + userStatus := UserStatusActive + if userData.AppMetadata.WTPendingInvite != nil && *userData.AppMetadata.WTPendingInvite { + userStatus = UserStatusInvited + } + + return &UserInfo{ + ID: u.Id, + Email: userData.Email, + Name: userData.Name, + Role: string(u.Role), + AutoGroups: autoGroups, + Status: string(userStatus), + IsServiceUser: u.IsServiceUser, + IsBlocked: u.Blocked, + LastLogin: u.GetLastLogin(), + Issued: u.Issued, + Permissions: UserPermissions{ + DashboardView: dashboardViewPermissions, + }, + }, nil +} + +// Copy the user +func (u *User) Copy() *User { + autoGroups := make([]string, len(u.AutoGroups)) + copy(autoGroups, u.AutoGroups) + pats := make(map[string]*PersonalAccessToken, len(u.PATs)) + for k, v := range u.PATs { + pats[k] = v.Copy() + } + return &User{ + Id: u.Id, + AccountID: u.AccountID, + Role: u.Role, + AutoGroups: autoGroups, + IsServiceUser: u.IsServiceUser, + NonDeletable: u.NonDeletable, + ServiceUserName: u.ServiceUserName, + PATs: pats, + Blocked: u.Blocked, + LastLogin: u.LastLogin, + CreatedAt: u.CreatedAt, + Issued: u.Issued, + IntegrationReference: u.IntegrationReference, + } +} + +// NewUser creates a new user +func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string) *User { + return &User{ + Id: id, + Role: role, + IsServiceUser: isServiceUser, + NonDeletable: nonDeletable, + ServiceUserName: serviceUserName, + AutoGroups: autoGroups, + Issued: issued, + CreatedAt: time.Now().UTC(), + } +} + +// NewRegularUser creates a new user with role UserRoleUser +func NewRegularUser(id string) *User { + return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI) +} + +// NewAdminUser creates a new user with role UserRoleAdmin +func NewAdminUser(id string) *User { + return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI) +} + +// NewOwnerUser creates a new user with role UserRoleOwner +func NewOwnerUser(id string) *User { + return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI) +} diff --git a/management/server/updatechannel.go b/management/server/updatechannel.go index d338b84b1..de7dd57df 100644 --- a/management/server/updatechannel.go +++ b/management/server/updatechannel.go @@ -9,13 +9,14 @@ import ( "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" ) const channelBufferSize = 100 type UpdateMessage struct { Update *proto.SyncResponse - NetworkMap *NetworkMap + NetworkMap *types.NetworkMap } type PeersUpdateManager struct { diff --git 
a/management/server/user.go b/management/server/user.go index 1639ec50f..22cd785eb 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -10,222 +10,22 @@ import ( "github.com/google/uuid" "github.com/netbirdio/netbird/management/server/activity" nbContext "github.com/netbirdio/netbird/management/server/context" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/idp" - "github.com/netbirdio/netbird/management/server/integration_reference" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" log "github.com/sirupsen/logrus" ) -const ( - UserRoleOwner UserRole = "owner" - UserRoleAdmin UserRole = "admin" - UserRoleUser UserRole = "user" - UserRoleUnknown UserRole = "unknown" - UserRoleBillingAdmin UserRole = "billing_admin" - - UserStatusActive UserStatus = "active" - UserStatusDisabled UserStatus = "disabled" - UserStatusInvited UserStatus = "invited" - - UserIssuedAPI = "api" - UserIssuedIntegration = "integration" -) - -// StrRoleToUserRole returns UserRole for a given strRole or UserRoleUnknown if the specified role is unknown -func StrRoleToUserRole(strRole string) UserRole { - switch strings.ToLower(strRole) { - case "owner": - return UserRoleOwner - case "admin": - return UserRoleAdmin - case "user": - return UserRoleUser - case "billing_admin": - return UserRoleBillingAdmin - default: - return UserRoleUnknown - } -} - -// UserStatus is the status of a User -type UserStatus string - -// UserRole is the role of a User -type UserRole string - -// User represents a user of the system -type User struct { - Id string `gorm:"primaryKey"` - // AccountID is a reference to Account that this object belongs - AccountID string `json:"-" gorm:"index"` - Role UserRole - IsServiceUser bool - // NonDeletable indicates whether the service user can be deleted - NonDeletable bool - // ServiceUserName is only set if IsServiceUser is true - ServiceUserName string - // AutoGroups is a list of Group IDs to auto-assign to peers registered by this user - AutoGroups []string `gorm:"serializer:json"` - PATs map[string]*PersonalAccessToken `gorm:"-"` - PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id;constraint:OnDelete:CASCADE;"` - // Blocked indicates whether the user is blocked. Blocked users can't use the system. - Blocked bool - // LastLogin is the last time the user logged in to IdP - LastLogin time.Time - // CreatedAt records the time the user was created - CreatedAt time.Time - - // Issued of the user - Issued string `gorm:"default:api"` - - IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` -} - -// IsBlocked returns true if the user is blocked, false otherwise -func (u *User) IsBlocked() bool { - return u.Blocked -} - -func (u *User) LastDashboardLoginChanged(LastLogin time.Time) bool { - return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero() -} - -// HasAdminPower returns true if the user has admin or owner roles, false otherwise -func (u *User) HasAdminPower() bool { - return u.Role == UserRoleAdmin || u.Role == UserRoleOwner -} - -// IsAdminOrServiceUser checks if the user has admin power or is a service user. 
-func (u *User) IsAdminOrServiceUser() bool { - return u.HasAdminPower() || u.IsServiceUser -} - -// IsRegularUser checks if the user is a regular user. -func (u *User) IsRegularUser() bool { - return !u.HasAdminPower() && !u.IsServiceUser -} - -// ToUserInfo converts a User object to a UserInfo object. -func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) { - autoGroups := u.AutoGroups - if autoGroups == nil { - autoGroups = []string{} - } - - dashboardViewPermissions := "full" - if !u.HasAdminPower() { - dashboardViewPermissions = "limited" - if settings.RegularUsersViewBlocked { - dashboardViewPermissions = "blocked" - } - } - - if userData == nil { - return &UserInfo{ - ID: u.Id, - Email: "", - Name: u.ServiceUserName, - Role: string(u.Role), - AutoGroups: u.AutoGroups, - Status: string(UserStatusActive), - IsServiceUser: u.IsServiceUser, - IsBlocked: u.Blocked, - LastLogin: u.LastLogin, - Issued: u.Issued, - Permissions: UserPermissions{ - DashboardView: dashboardViewPermissions, - }, - }, nil - } - if userData.ID != u.Id { - return nil, fmt.Errorf("wrong UserData provided for user %s", u.Id) - } - - userStatus := UserStatusActive - if userData.AppMetadata.WTPendingInvite != nil && *userData.AppMetadata.WTPendingInvite { - userStatus = UserStatusInvited - } - - return &UserInfo{ - ID: u.Id, - Email: userData.Email, - Name: userData.Name, - Role: string(u.Role), - AutoGroups: autoGroups, - Status: string(userStatus), - IsServiceUser: u.IsServiceUser, - IsBlocked: u.Blocked, - LastLogin: u.LastLogin, - Issued: u.Issued, - Permissions: UserPermissions{ - DashboardView: dashboardViewPermissions, - }, - }, nil -} - -// Copy the user -func (u *User) Copy() *User { - autoGroups := make([]string, len(u.AutoGroups)) - copy(autoGroups, u.AutoGroups) - pats := make(map[string]*PersonalAccessToken, len(u.PATs)) - for k, v := range u.PATs { - pats[k] = v.Copy() - } - return &User{ - Id: u.Id, - AccountID: u.AccountID, - Role: u.Role, - AutoGroups: autoGroups, - IsServiceUser: u.IsServiceUser, - NonDeletable: u.NonDeletable, - ServiceUserName: u.ServiceUserName, - PATs: pats, - Blocked: u.Blocked, - LastLogin: u.LastLogin, - CreatedAt: u.CreatedAt, - Issued: u.Issued, - IntegrationReference: u.IntegrationReference, - } -} - -// NewUser creates a new user -func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string) *User { - return &User{ - Id: id, - Role: role, - IsServiceUser: isServiceUser, - NonDeletable: nonDeletable, - ServiceUserName: serviceUserName, - AutoGroups: autoGroups, - Issued: issued, - CreatedAt: time.Now().UTC(), - } -} - -// NewRegularUser creates a new user with role UserRoleUser -func NewRegularUser(id string) *User { - return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI) -} - -// NewAdminUser creates a new user with role UserRoleAdmin -func NewAdminUser(id string) *User { - return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI) -} - -// NewOwnerUser creates a new user with role UserRoleOwner -func NewOwnerUser(id string) *User { - return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI) -} - // createServiceUser creates a new service user under the given account. 
-func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) { +func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountID string, initiatorUserID string, role types.UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*types.UserInfo, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID) + initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) if err != nil { return nil, err } @@ -238,45 +38,45 @@ func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountI return nil, status.NewAdminPermissionError() } - if role == UserRoleOwner { + if role == types.UserRoleOwner { return nil, status.NewServiceUserRoleInvalidError() } newUserID := uuid.New().String() - newUser := NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, UserIssuedAPI) + newUser := types.NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, types.UserIssuedAPI) newUser.AccountID = accountID log.WithContext(ctx).Debugf("New User: %v", newUser) - if err = am.Store.SaveUser(ctx, LockingStrengthUpdate, newUser); err != nil { + if err = am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser); err != nil { return nil, err } meta := map[string]any{"name": newUser.ServiceUserName} am.StoreEvent(ctx, initiatorUserID, newUser.Id, accountID, activity.ServiceUserCreated, meta) - return &UserInfo{ + return &types.UserInfo{ ID: newUser.Id, Email: "", Name: newUser.ServiceUserName, Role: string(newUser.Role), AutoGroups: newUser.AutoGroups, - Status: string(UserStatusActive), + Status: string(types.UserStatusActive), IsServiceUser: true, LastLogin: time.Time{}, - Issued: UserIssuedAPI, + Issued: types.UserIssuedAPI, }, nil } // CreateUser creates a new user under the given account. Effectively this is a user invite. 
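
// Illustrative sketch (not part of the original change): after the move to the types
// package, callers hand the account manager a *types.UserInfo and receive one back.
// The account manager value, context, and IDs here are hypothetical.
func exampleCreateServiceUser(ctx context.Context, am *DefaultAccountManager) {
	info, err := am.CreateUser(ctx, "account-1", "admin-user", &types.UserInfo{
		Name:          "ci-bot",
		Role:          string(types.UserRoleAdmin),
		IsServiceUser: true,
		AutoGroups:    []string{"group-ci"},
	})
	if err != nil {
		return
	}
	_ = info.Status // string(types.UserStatusActive) for a newly created service user
}
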
-func (am *DefaultAccountManager) CreateUser(ctx context.Context, accountID, userID string, user *UserInfo) (*UserInfo, error) { +func (am *DefaultAccountManager) CreateUser(ctx context.Context, accountID, userID string, user *types.UserInfo) (*types.UserInfo, error) { if user.IsServiceUser { - return am.createServiceUser(ctx, accountID, userID, StrRoleToUserRole(user.Role), user.Name, user.NonDeletable, user.AutoGroups) + return am.createServiceUser(ctx, accountID, userID, types.StrRoleToUserRole(user.Role), user.Name, user.NonDeletable, user.AutoGroups) } return am.inviteNewUser(ctx, accountID, userID, user) } // inviteNewUser Invites a USer to a given account and creates reference in datastore -func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, userID string, invite *UserInfo) (*UserInfo, error) { +func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, userID string, invite *types.UserInfo) (*types.UserInfo, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -288,7 +88,7 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u return nil, err } - initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -299,7 +99,7 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u inviterID := userID if initiatorUser.IsServiceUser { - createdBy, err := am.Store.GetAccountCreatedBy(ctx, LockingStrengthShare, accountID) + createdBy, err := am.Store.GetAccountCreatedBy(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, err } @@ -336,7 +136,7 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u return nil, err } - newUser := &User{ + newUser := &types.User{ Id: idpUser.ID, AccountID: accountID, Role: StrRoleToUserRole(invite.Role), @@ -346,12 +146,12 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u CreatedAt: time.Now().UTC(), } - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, err } - if err = am.Store.SaveUser(ctx, LockingStrengthUpdate, newUser); err != nil { + if err = am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser); err != nil { return nil, err } @@ -365,19 +165,19 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u return newUser.ToUserInfo(idpUser, settings) } -func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*User, error) { - return am.Store.GetUserByUserID(ctx, LockingStrengthShare, id) +func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) { + return am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id) } // GetUser looks up a user by provided authorization claims. // It will also create an account if didn't exist for this user before. 
-func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) { +func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) { accountID, userID, err := am.GetAccountIDFromToken(ctx, claims) if err != nil { return nil, fmt.Errorf("failed to get account with token claims %v", err) } - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -401,14 +201,14 @@ func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.A // ListUsers returns lists of all users under the account. // It doesn't populate user information such as email or name. -func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*User, error) { +func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - return am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID) + return am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) } -func (am *DefaultAccountManager) deleteServiceUser(ctx context.Context, accountID string, initiatorUserID string, targetUser *User) error { - if err := am.Store.DeleteUser(ctx, LockingStrengthUpdate, accountID, targetUser.Id); err != nil { +func (am *DefaultAccountManager) deleteServiceUser(ctx context.Context, accountID string, initiatorUserID string, targetUser *types.User) error { + if err := am.Store.DeleteUser(ctx, store.LockingStrengthUpdate, accountID, targetUser.Id); err != nil { return err } meta := map[string]any{"name": targetUser.ServiceUserName, "created_at": targetUser.CreatedAt} @@ -425,7 +225,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID) + initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) if err != nil { return err } @@ -438,17 +238,17 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init return status.NewAdminPermissionError() } - targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID) + targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) if err != nil { return err } - if targetUser.Role == UserRoleOwner { + if targetUser.Role == types.UserRoleOwner { return status.NewOwnerDeletePermissionError() } // disable deleting integration user if the initiator is not admin service user - if targetUser.Issued == UserIssuedIntegration && !initiatorUser.IsServiceUser { + if targetUser.Issued == types.UserIssuedIntegration && !initiatorUser.IsServiceUser { return status.Errorf(status.PermissionDenied, "only integration service user can delete this user") } @@ -467,7 +267,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init } if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil @@ -482,7 +282,7 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin return status.Errorf(status.PreconditionFailed, "IdP manager must be enabled to send user invites") } - initiatorUser, err := 
am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID) + initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) if err != nil { return err } @@ -518,7 +318,7 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin } // CreatePAT creates a new PAT for the given user -func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) { +func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -530,7 +330,7 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string return nil, status.Errorf(status.InvalidArgument, "expiration has to be between 1 and 365") } - initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID) + initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) if err != nil { return nil, err } @@ -539,7 +339,7 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string return nil, status.NewUserNotPartOfAccountError() } - targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID) + targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) if err != nil { return nil, err } @@ -548,12 +348,12 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string return nil, status.NewAdminPermissionError() } - pat, err := CreateNewPAT(tokenName, expiresIn, targetUserID, initiatorUser.Id) + pat, err := types.CreateNewPAT(tokenName, expiresIn, targetUserID, initiatorUser.Id) if err != nil { return nil, status.Errorf(status.Internal, "failed to create PAT: %v", err) } - if err = am.Store.SavePAT(ctx, LockingStrengthUpdate, &pat.PersonalAccessToken); err != nil { + if err = am.Store.SavePAT(ctx, store.LockingStrengthUpdate, &pat.PersonalAccessToken); err != nil { return nil, err } @@ -568,7 +368,7 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID) + initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) if err != nil { return err } @@ -581,17 +381,17 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string return status.NewAdminPermissionError() } - pat, err := am.Store.GetPATByID(ctx, LockingStrengthShare, targetUserID, tokenID) + pat, err := am.Store.GetPATByID(ctx, store.LockingStrengthShare, targetUserID, tokenID) if err != nil { return err } - targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID) + targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) if err != nil { return err } - if err = am.Store.DeletePAT(ctx, LockingStrengthUpdate, targetUserID, tokenID); err != nil { + if err = am.Store.DeletePAT(ctx, store.LockingStrengthUpdate, targetUserID, tokenID); err != nil { return err } @@ -602,8 +402,8 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string } // GetPAT returns a specific PAT from a 
user -func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) { - initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID) +func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error) { + initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) if err != nil { return nil, err } @@ -616,12 +416,12 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i return nil, status.NewAdminPermissionError() } - return am.Store.GetPATByID(ctx, LockingStrengthShare, targetUserID, tokenID) + return am.Store.GetPATByID(ctx, store.LockingStrengthShare, targetUserID, tokenID) } // GetAllPATs returns all PATs for a user -func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) { - initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID) +func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error) { + initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) if err != nil { return nil, err } @@ -634,21 +434,21 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin return nil, status.NewAdminPermissionError() } - return am.Store.GetUserPATs(ctx, LockingStrengthShare, targetUserID) + return am.Store.GetUserPATs(ctx, store.LockingStrengthShare, targetUserID) } // SaveUser saves updates to the given user. If the user doesn't exist, it will throw status.NotFound error. -func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error) { +func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *types.User) (*types.UserInfo, error) { return am.SaveOrAddUser(ctx, accountID, initiatorUserID, update, false) // false means do not create user and throw status.NotFound } // SaveOrAddUser updates the given user. If addIfNotExists is set to true it will add user when no exist // Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now. -func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) { +func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *types.User, addIfNotExists bool) (*types.UserInfo, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - updatedUsers, err := am.SaveOrAddUsers(ctx, accountID, initiatorUserID, []*User{update}, addIfNotExists) + updatedUsers, err := am.SaveOrAddUsers(ctx, accountID, initiatorUserID, []*types.User{update}, addIfNotExists) if err != nil { return nil, err } @@ -663,12 +463,12 @@ func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, i // SaveOrAddUsers updates existing users or adds new users to the account. // Note: This function does not acquire the global lock. // It is the caller's responsibility to ensure proper locking is in place before invoking this method. 
-func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) { +func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*types.User, addIfNotExists bool) ([]*types.UserInfo, error) { if len(updates) == 0 { return nil, nil //nolint:nilnil } - initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID) + initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) if err != nil { return nil, err } @@ -681,7 +481,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, return nil, status.NewAdminPermissionError() } - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, err } @@ -692,7 +492,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, var usersToSave = make([]*User, 0, len(updates)) var updatedUsersInfo = make([]*UserInfo, 0, len(updates)) - groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) + groups, err := am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, fmt.Errorf("error getting account groups: %w", err) } @@ -729,7 +529,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, updatedUsersInfo = append(updatedUsersInfo, updatedUserInfo) } - return transaction.SaveUsers(ctx, LockingStrengthUpdate, usersToSave) + return transaction.SaveUsers(ctx, store.LockingStrengthUpdate, usersToSave) }) if err != nil { return nil, err @@ -747,7 +547,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } if settings.GroupsPropagationEnabled && updateAccountPeers { - if err = am.Store.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = am.Store.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return nil, fmt.Errorf("failed to increment network serial: %w", err) } am.updateAccountPeers(ctx, accountID) @@ -757,7 +557,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } // prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data. 
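
// Illustrative sketch (not part of the original change): the update path collects
// activity events as closures and replays them only after the user changes have been
// persisted, mirroring how SaveOrAddUsers consumes the events prepared above. The
// helper and its parameters are hypothetical.
func exampleDeferredEvents(saveUsers func() error, storeEvent func()) error {
	var eventsToStore []func()

	// While computing the update, record what should be logged, but do not log yet.
	eventsToStore = append(eventsToStore, storeEvent)

	// Persist first; only a successful save emits the collected activity events.
	if err := saveUsers(); err != nil {
		return err
	}
	for _, event := range eventsToStore {
		event()
	}
	return nil
}
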
-func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, groupsMap map[string]*nbgroup.Group, accountID string, initiatorUserID string, oldUser, newUser *User, transferredOwnerRole bool) []func() { +func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, groupsMap map[string]*types.Group, accountID string, initiatorUserID string, oldUser, newUser *types.User, transferredOwnerRole bool) []func() { var eventsToStore []func() if oldUser.IsBlocked() != newUser.IsBlocked() { @@ -783,9 +583,14 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, gr }) } + return eventsToStore +} + +func (am *DefaultAccountManager) prepareUserGroupsEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *types.User, account *types.Account, peerGroupsAdded, peerGroupsRemoved map[string][]string) []func() { + var eventsToStore []func() if newUser.AutoGroups != nil { - removedGroups := difference(oldUser.AutoGroups, newUser.AutoGroups) - addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups) + removedGroups := utildifference(oldUser.AutoGroups, newUser.AutoGroups) + addedGroups := utildifference(newUser.AutoGroups, oldUser.AutoGroups) for _, g := range removedGroups { group, ok := groupsMap[g] if ok { @@ -807,12 +612,11 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, gr } } } - return eventsToStore } -func processUserUpdate(ctx context.Context, am *DefaultAccountManager, transaction Store, groupsMap map[string]*nbgroup.Group, - initiatorUser, update *User, addIfNotExists bool, settings *Settings) (bool, *User, []*nbpeer.Peer, []func(), error) { +func processUserUpdate(ctx context.Context, am *DefaultAccountManager, transaction Store, groupsMap map[string]*types.Group, + initiatorUser, update *types.User, addIfNotExists bool, settings *types.Settings) (bool, *types.User, []*nbpeer.Peer, []func(), error) { if update == nil { return false, nil, nil, nil, status.Errorf(status.InvalidArgument, "provided user update is nil") @@ -842,7 +646,7 @@ func processUserUpdate(ctx context.Context, am *DefaultAccountManager, transacti return false, nil, nil, nil, err } - userPeers, err := transaction.GetUserPeers(ctx, LockingStrengthUpdate, updatedUser.AccountID, update.Id) + userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthUpdate, updatedUser.AccountID, update.Id) if err != nil { return false, nil, nil, nil, err } @@ -854,13 +658,13 @@ func processUserUpdate(ctx context.Context, am *DefaultAccountManager, transacti } if update.AutoGroups != nil && settings.GroupsPropagationEnabled { - removedGroups := difference(oldUser.AutoGroups, update.AutoGroups) + removedGroups := util.difference(oldUser.AutoGroups, update.AutoGroups) updatedGroups, err := am.updateUserPeersInGroups(groupsMap, userPeers, update.AutoGroups, removedGroups) if err != nil { return false, nil, nil, nil, fmt.Errorf("error modifying user peers in groups: %w", err) } - if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, updatedGroups); err != nil { + if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, updatedGroups); err != nil { return false, nil, nil, nil, fmt.Errorf("error saving groups: %w", err) } } @@ -873,7 +677,7 @@ func processUserUpdate(ctx context.Context, am *DefaultAccountManager, transacti // getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist. 
func getUserOrCreateIfNotExists(ctx context.Context, transaction Store, update *User, addIfNotExists bool) (*User, error) { - existingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, update.Id) + existingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, update.Id) if err != nil { if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { if !addIfNotExists { @@ -886,12 +690,12 @@ func getUserOrCreateIfNotExists(ctx context.Context, transaction Store, update * return existingUser, nil } -func handleOwnerRoleTransfer(ctx context.Context, transaction Store, initiatorUser, update *User) (bool, error) { +func handleOwnerRoleTransfer(ctx context.Context, transaction store.Store, initiatorUser, update *User) (bool, error) { if initiatorUser.Role == UserRoleOwner && initiatorUser.Id != update.Id && update.Role == UserRoleOwner { newInitiatorUser := initiatorUser.Copy() - newInitiatorUser.Role = UserRoleAdmin + newInitiatorUser.Role = types.UserRoleAdmin - if err := transaction.SaveUser(ctx, LockingStrengthUpdate, newInitiatorUser); err != nil { + if err := transaction.SaveUser(ctx, store.LockingStrengthUpdate, newInitiatorUser); err != nil { return false, err } return true, nil @@ -902,8 +706,8 @@ func handleOwnerRoleTransfer(ctx context.Context, transaction Store, initiatorUs // getUserInfo retrieves the UserInfo for a given User and Account. // If the AccountManager has a non-nil idpManager and the User is not a service user, // it will attempt to look up the UserData from the cache. -func (am *DefaultAccountManager) getUserInfo(ctx context.Context, transaction Store, user *User, accountID string) (*UserInfo, error) { - settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID) +func (am *DefaultAccountManager) getUserInfo(ctx context.Context, transaction store.Store, user *types.User, accountID string) (*types.UserInfo, error) { + settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, err } @@ -919,23 +723,23 @@ func (am *DefaultAccountManager) getUserInfo(ctx context.Context, transaction St } // validateUserUpdate validates the update operation for a user. 
-func validateUserUpdate(groupsMap map[string]*nbgroup.Group, initiatorUser, oldUser, update *User) error { +func validateUserUpdate(groupsMap map[string]*types.Group, initiatorUser, oldUser, update *types.User) error { if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked { return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves") } if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && update.Role != initiatorUser.Role { return status.Errorf(status.PermissionDenied, "admins can't change their role") } - if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.Role != oldUser.Role { + if initiatorUser.Role == types.UserRoleAdmin && oldUser.Role == types.UserRoleOwner && update.Role != oldUser.Role { return status.Errorf(status.PermissionDenied, "only owners can remove owner role from their user") } - if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.IsBlocked() && !oldUser.IsBlocked() { + if initiatorUser.Role == types.UserRoleAdmin && oldUser.Role == types.UserRoleOwner && update.IsBlocked() && !oldUser.IsBlocked() { return status.Errorf(status.PermissionDenied, "unable to block owner user") } - if initiatorUser.Role == UserRoleAdmin && update.Role == UserRoleOwner && update.Role != oldUser.Role { + if initiatorUser.Role == types.UserRoleAdmin && update.Role == types.UserRoleOwner && update.Role != oldUser.Role { return status.Errorf(status.PermissionDenied, "only owners can add owner role to other users") } - if oldUser.IsServiceUser && update.Role == UserRoleOwner { + if oldUser.IsServiceUser && update.Role == types.UserRoleOwner { return status.Errorf(status.PermissionDenied, "can't update a service user with owner role") } @@ -954,7 +758,7 @@ func validateUserUpdate(groupsMap map[string]*nbgroup.Group, initiatorUser, oldU } // GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist -func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, userID, domain string) (*Account, error) { +func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, userID, domain string) (*types.Account, error) { start := time.Now() unlock := am.Store.AcquireGlobalLock(ctx) defer unlock() @@ -981,7 +785,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, u userObj := account.Users[userID] - if lowerDomain != "" && account.Domain != lowerDomain && userObj.Role == UserRoleOwner { + if lowerDomain != "" && account.Domain != lowerDomain && userObj.Role == types.UserRoleOwner { account.Domain = lowerDomain err = am.Store.SaveAccount(ctx, account) if err != nil { @@ -994,13 +798,13 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, u // GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return // based on provided user role. 
-func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error) { - accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID) +func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*types.UserInfo, error) { + accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, err } - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -1026,7 +830,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun continue } if !user.IsServiceUser { - users[user.Id] = userLoggedInOnce(!user.LastLogin.IsZero()) + users[user.Id] = userLoggedInOnce(!user.GetLastLogin().IsZero()) } } queriedUsers, err = am.lookupCache(ctx, users, accountID) @@ -1038,9 +842,9 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun queriedUsers = append(queriedUsers, usersFromIntegration...) } - userInfos := make([]*UserInfo, 0) + userInfos := make([]*types.UserInfo, 0) - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, err } @@ -1067,7 +871,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun continue } - var info *UserInfo + var info *types.UserInfo if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains { info, err = localUser.ToUserInfo(queriedUser, settings) if err != nil { @@ -1087,16 +891,16 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun } } - info = &UserInfo{ + info = &types.UserInfo{ ID: localUser.Id, Email: "", Name: name, Role: string(localUser.Role), AutoGroups: localUser.AutoGroups, - Status: string(UserStatusActive), + Status: string(types.UserStatusActive), IsServiceUser: localUser.IsServiceUser, NonDeletable: localUser.NonDeletable, - Permissions: UserPermissions{DashboardView: dashboardViewPermissions}, + Permissions: types.UserPermissions{DashboardView: dashboardViewPermissions}, } } userInfos = append(userInfos, info) @@ -1118,7 +922,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou peerIDs = append(peerIDs, peer.ID) peer.MarkLoginExpired(true) - if err := am.Store.SavePeerStatus(ctx, LockingStrengthUpdate, accountID, peer.ID, *peer.Status); err != nil { + if err := am.Store.SavePeerStatus(ctx, store.LockingStrengthUpdate, accountID, peer.ID, *peer.Status); err != nil { return err } am.StoreEvent( @@ -1131,7 +935,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou if len(peerIDs) != 0 { // this will trigger peer disconnect from the management service am.peersUpdateManager.CloseChannels(ctx, peerIDs) - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil } @@ -1177,7 +981,7 @@ func (am *DefaultAccountManager) getEmailAndNameOfTargetUser(ctx context.Context // If an error occurs while deleting the user, the function skips it and continues deleting other users. // Errors are collected and returned at the end. 
func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error { - initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID) + initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) if err != nil { return err } @@ -1195,19 +999,19 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account continue } - targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID) + targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) if err != nil { allErrors = errors.Join(allErrors, err) continue } - if targetUser.Role == UserRoleOwner { + if targetUser.Role == types.UserRoleOwner { allErrors = errors.Join(allErrors, fmt.Errorf("unable to delete a user: %s with owner role", targetUserID)) continue } // disable deleting integration user if the initiator is not admin service user - if targetUser.Issued == UserIssuedIntegration && !initiatorUser.IsServiceUser { + if targetUser.Issued == types.UserIssuedIntegration && !initiatorUser.IsServiceUser { allErrors = errors.Join(allErrors, errors.New("only integration service user can delete this user")) continue } @@ -1224,7 +1028,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account } if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return allErrors @@ -1257,13 +1061,13 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI var updateAccountPeers bool var targetUser *User - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - targetUser, err = transaction.GetUserByUserID(ctx, LockingStrengthShare, targetUserID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + targetUser, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) if err != nil { return fmt.Errorf("failed to get user to delete: %w", err) } - userPeers, err := transaction.GetUserPeers(ctx, LockingStrengthShare, accountID, targetUserID) + userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, targetUserID) if err != nil { return fmt.Errorf("failed to get user peers: %w", err) } @@ -1276,7 +1080,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI } } - if err = transaction.DeleteUser(ctx, LockingStrengthUpdate, accountID, targetUserID); err != nil { + if err = transaction.DeleteUser(ctx, store.LockingStrengthUpdate, accountID, targetUserID); err != nil { return fmt.Errorf("failed to delete user: %s %w", targetUserID, err) } @@ -1296,8 +1100,8 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI } // updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them. 
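Editor's note: DeleteRegularUsers above collects per-user failures with errors.Join instead of aborting the batch on the first error. The pattern in isolation (names here are illustrative):

package main

import (
	"errors"
	"fmt"
)

// deleteAll keeps going after individual failures and returns them joined;
// errors.Join ignores nil values, so a fully successful run returns nil.
func deleteAll(ids []string, deleteOne func(string) error) error {
	var allErrors error
	for _, id := range ids {
		if err := deleteOne(id); err != nil {
			allErrors = errors.Join(allErrors, fmt.Errorf("user %s: %w", id, err))
		}
	}
	return allErrors
}

func main() {
	err := deleteAll([]string{"user3", "user5"}, func(id string) error {
		if id == "user5" {
			return errors.New("unable to delete a user with owner role")
		}
		return nil
	})
	fmt.Println(err) // user user5: unable to delete a user with owner role
}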
-func (am *DefaultAccountManager) updateUserPeersInGroups(accountGroups map[string]*nbgroup.Group, peers []*nbpeer.Peer, groupsToAdd, - groupsToRemove []string) (groupsToUpdate []*nbgroup.Group, err error) { +func (am *DefaultAccountManager) updateUserPeersInGroups(accountGroups map[string]*types.Group, peers []*nbpeer.Peer, groupsToAdd, + groupsToRemove []string) (groupsToUpdate []*types.Group, err error) { if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { return @@ -1330,7 +1134,7 @@ func (am *DefaultAccountManager) updateUserPeersInGroups(accountGroups map[strin } // addUserPeersToGroup adds the user's peers to the group. -func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) { +func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *types.Group) { groupPeers := make(map[string]struct{}, len(group.Peers)) for _, pid := range group.Peers { groupPeers[pid] = struct{}{} @@ -1347,7 +1151,7 @@ func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) } // removeUserPeersFromGroup removes user's peers from the group. -func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) { +func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *types.Group) { // skip removing peers from group All if group.Name == "All" { return @@ -1372,19 +1176,19 @@ func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserDa return nil, false } -func validateUserInvite(invite *UserInfo) error { +func validateUserInvite(invite *types.UserInfo) error { if invite == nil { return fmt.Errorf("provided user update is nil") } - invitedRole := StrRoleToUserRole(invite.Role) + invitedRole := types.StrRoleToUserRole(invite.Role) switch { case invite.Name == "": return status.Errorf(status.InvalidArgument, "name can't be empty") case invite.Email == "": return status.Errorf(status.InvalidArgument, "email can't be empty") - case invitedRole == UserRoleOwner: + case invitedRole == types.UserRoleOwner: return status.Errorf(status.InvalidArgument, "can't invite a user with owner role") default: } diff --git a/management/server/user_test.go b/management/server/user_test.go index 2f8c1bf70..e9889e56b 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -10,8 +10,12 @@ import ( "github.com/eko/gocache/v3/cache" cacheStore "github.com/eko/gocache/v3/store" "github.com/google/go-cmp/cmp" - nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/util" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + gocache "github.com/patrickmn/go-cache" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -41,11 +45,15 @@ const ( ) func TestUser_CreatePAT_ForSameUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -82,14 +90,18 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) { } func 
TestUser_CreatePAT_ForDifferentUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockTargetUserId] = &User{ + account.Users[mockTargetUserId] = &types.User{ Id: mockTargetUserId, IsServiceUser: false, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -104,14 +116,18 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) { } func TestUser_CreatePAT_ForServiceUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockTargetUserId] = &User{ + account.Users[mockTargetUserId] = &types.User{ Id: mockTargetUserId, IsServiceUser: true, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -130,11 +146,15 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) { } func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -149,11 +169,15 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) { } func TestUser_CreatePAT_WithEmptyName(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -168,19 +192,23 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) { } func TestUser_DeletePAT(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockUserID] = &User{ + account.Users[mockUserID] = &types.User{ Id: mockUserID, - PATs: map[string]*PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ mockTokenID1: { ID: mockTokenID1, HashedToken: mockToken1, }, }, } - err := store.SaveAccount(context.Background(), account) + err = 
store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -204,20 +232,24 @@ func TestUser_DeletePAT(t *testing.T) { } func TestUser_GetPAT(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockUserID] = &User{ + account.Users[mockUserID] = &types.User{ Id: mockUserID, AccountID: mockAccountID, - PATs: map[string]*PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ mockTokenID1: { ID: mockTokenID1, HashedToken: mockToken1, }, }, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -237,13 +269,17 @@ func TestUser_GetPAT(t *testing.T) { } func TestUser_GetAllPATs(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockUserID] = &User{ + account.Users[mockUserID] = &types.User{ Id: mockUserID, AccountID: mockAccountID, - PATs: map[string]*PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ mockTokenID1: { ID: mockTokenID1, HashedToken: mockToken1, @@ -254,7 +290,7 @@ func TestUser_GetAllPATs(t *testing.T) { }, }, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -274,26 +310,26 @@ func TestUser_GetAllPATs(t *testing.T) { func TestUser_Copy(t *testing.T) { // this is an imaginary case which will never be in DB this way - user := User{ + user := types.User{ Id: "userId", AccountID: "accountId", Role: "role", IsServiceUser: true, ServiceUserName: "servicename", AutoGroups: []string{"group1", "group2"}, - PATs: map[string]*PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ "pat1": { ID: "pat1", Name: "First PAT", HashedToken: "SoMeHaShEdToKeN", - ExpirationDate: time.Now().AddDate(0, 0, 7), + ExpirationDate: util.ToPtr(time.Now().AddDate(0, 0, 7)), CreatedBy: "userId", CreatedAt: time.Now(), - LastUsed: time.Now(), + LastUsed: util.ToPtr(time.Now()), }, }, Blocked: false, - LastLogin: time.Now().UTC(), + LastLogin: util.ToPtr(time.Now().UTC()), CreatedAt: time.Now().UTC(), Issued: "test", IntegrationReference: integration_reference.IntegrationReference{ @@ -340,11 +376,15 @@ func validateStruct(s interface{}) (err error) { } func TestUser_CreateServiceUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -366,26 +406,30 @@ func 
TestUser_CreateServiceUser(t *testing.T) { assert.NotNil(t, account.Users[user.ID]) assert.True(t, account.Users[user.ID].IsServiceUser) assert.Equal(t, mockServiceUserName, account.Users[user.ID].ServiceUserName) - assert.Equal(t, UserRole(mockRole), account.Users[user.ID].Role) + assert.Equal(t, types.UserRole(mockRole), account.Users[user.ID].Role) assert.Equal(t, []string{"group1", "group2"}, account.Users[user.ID].AutoGroups) - assert.Equal(t, map[string]*PersonalAccessToken{}, account.Users[user.ID].PATs) + assert.Equal(t, map[string]*types.PersonalAccessToken{}, account.Users[user.ID].PATs) assert.Zero(t, user.Email) assert.True(t, user.IsServiceUser) assert.Equal(t, "active", user.Status) - _, err = am.createServiceUser(context.Background(), mockAccountID, mockUserID, UserRoleOwner, mockServiceUserName, false, nil) + _, err = am.createServiceUser(context.Background(), mockAccountID, mockUserID, types.UserRoleOwner, mockServiceUserName, false, nil) if err == nil { t.Fatal("should return error when creating service user with owner role") } } func TestUser_CreateUser_ServiceUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -395,7 +439,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - user, err := am.CreateUser(context.Background(), mockAccountID, mockUserID, &UserInfo{ + user, err := am.CreateUser(context.Background(), mockAccountID, mockUserID, &types.UserInfo{ Name: mockServiceUserName, Role: mockRole, IsServiceUser: true, @@ -413,7 +457,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) { assert.Equal(t, 2, len(account.Users)) assert.True(t, account.Users[user.ID].IsServiceUser) assert.Equal(t, mockServiceUserName, account.Users[user.ID].ServiceUserName) - assert.Equal(t, UserRole(mockRole), account.Users[user.ID].Role) + assert.Equal(t, types.UserRole(mockRole), account.Users[user.ID].Role) assert.Equal(t, []string{"group1", "group2"}, account.Users[user.ID].AutoGroups) assert.Equal(t, mockServiceUserName, user.Name) @@ -423,11 +467,15 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) { } func TestUser_CreateUser_RegularUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -437,7 +485,7 @@ func TestUser_CreateUser_RegularUser(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - _, err = am.CreateUser(context.Background(), mockAccountID, mockUserID, &UserInfo{ + _, err = am.CreateUser(context.Background(), mockAccountID, mockUserID, &types.UserInfo{ Name: mockServiceUserName, Role: mockRole, IsServiceUser: false, @@ -448,11 +496,15 @@ func 
TestUser_CreateUser_RegularUser(t *testing.T) { } func TestUser_InviteNewUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -495,7 +547,7 @@ func TestUser_InviteNewUser(t *testing.T) { am.idpManager = &idpMock // test if new invite with regular role works - _, err = am.inviteNewUser(context.Background(), mockAccountID, mockUserID, &UserInfo{ + _, err = am.inviteNewUser(context.Background(), mockAccountID, mockUserID, &types.UserInfo{ Name: mockServiceUserName, Role: mockRole, Email: "test@teste.com", @@ -506,9 +558,9 @@ func TestUser_InviteNewUser(t *testing.T) { assert.NoErrorf(t, err, "Invite user should not throw error") // test if new invite with owner role fails - _, err = am.inviteNewUser(context.Background(), mockAccountID, mockUserID, &UserInfo{ + _, err = am.inviteNewUser(context.Background(), mockAccountID, mockUserID, &types.UserInfo{ Name: mockServiceUserName, - Role: string(UserRoleOwner), + Role: string(types.UserRoleOwner), Email: "test2@teste.com", IsServiceUser: false, AutoGroups: []string{"group1", "group2"}, @@ -520,13 +572,13 @@ func TestUser_InviteNewUser(t *testing.T) { func TestUser_DeleteUser_ServiceUser(t *testing.T) { tests := []struct { name string - serviceUser *User + serviceUser *types.User assertErrFunc assert.ErrorAssertionFunc assertErrMessage string }{ { name: "Can delete service user", - serviceUser: &User{ + serviceUser: &types.User{ Id: mockServiceUserID, IsServiceUser: true, ServiceUserName: mockServiceUserName, @@ -535,7 +587,7 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) { }, { name: "Cannot delete non-deletable service user", - serviceUser: &User{ + serviceUser: &types.User{ Id: mockServiceUserID, IsServiceUser: true, ServiceUserName: mockServiceUserName, @@ -548,11 +600,16 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - store := newStore(t) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") account.Users[mockServiceUserID] = tt.serviceUser - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -580,11 +637,15 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) { } func TestUser_DeleteUser_SelfDelete(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -601,38 +662,42 @@ func 
TestUser_DeleteUser_SelfDelete(t *testing.T) { } func TestUser_DeleteUser_regularUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") targetId := "user2" - account.Users[targetId] = &User{ + account.Users[targetId] = &types.User{ Id: targetId, IsServiceUser: true, ServiceUserName: "user2username", } targetId = "user3" - account.Users[targetId] = &User{ + account.Users[targetId] = &types.User{ Id: targetId, IsServiceUser: false, - Issued: UserIssuedAPI, + Issued: types.UserIssuedAPI, } targetId = "user4" - account.Users[targetId] = &User{ + account.Users[targetId] = &types.User{ Id: targetId, IsServiceUser: false, - Issued: UserIssuedIntegration, + Issued: types.UserIssuedIntegration, } targetId = "user5" - account.Users[targetId] = &User{ + account.Users[targetId] = &types.User{ Id: targetId, IsServiceUser: false, - Issued: UserIssuedAPI, - Role: UserRoleOwner, + Issued: types.UserIssuedAPI, + Role: types.UserRoleOwner, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -683,60 +748,64 @@ func TestUser_DeleteUser_regularUser(t *testing.T) { } func TestUser_DeleteUser_RegularUsers(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") targetId := "user2" - account.Users[targetId] = &User{ + account.Users[targetId] = &types.User{ Id: targetId, IsServiceUser: true, ServiceUserName: "user2username", } targetId = "user3" - account.Users[targetId] = &User{ + account.Users[targetId] = &types.User{ Id: targetId, IsServiceUser: false, - Issued: UserIssuedAPI, + Issued: types.UserIssuedAPI, } targetId = "user4" - account.Users[targetId] = &User{ + account.Users[targetId] = &types.User{ Id: targetId, IsServiceUser: false, - Issued: UserIssuedIntegration, + Issued: types.UserIssuedIntegration, } targetId = "user5" - account.Users[targetId] = &User{ + account.Users[targetId] = &types.User{ Id: targetId, IsServiceUser: false, - Issued: UserIssuedAPI, - Role: UserRoleOwner, + Issued: types.UserIssuedAPI, + Role: types.UserRoleOwner, } - account.Users["user6"] = &User{ + account.Users["user6"] = &types.User{ Id: "user6", IsServiceUser: false, - Issued: UserIssuedAPI, + Issued: types.UserIssuedAPI, } - account.Users["user7"] = &User{ + account.Users["user7"] = &types.User{ Id: "user7", IsServiceUser: false, - Issued: UserIssuedAPI, + Issued: types.UserIssuedAPI, } - account.Users["user8"] = &User{ + account.Users["user8"] = &types.User{ Id: "user8", IsServiceUser: false, - Issued: UserIssuedAPI, - Role: UserRoleAdmin, + Issued: types.UserIssuedAPI, + Role: types.UserRoleAdmin, } - account.Users["user9"] = &User{ + account.Users["user9"] = &types.User{ Id: "user9", IsServiceUser: false, - Issued: UserIssuedAPI, - Role: UserRoleAdmin, + Issued: types.UserIssuedAPI, + Role: types.UserRoleAdmin, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), 
account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -834,11 +903,15 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) { } func TestDefaultAccountManager_GetUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -863,13 +936,17 @@ func TestDefaultAccountManager_GetUser(t *testing.T) { } func TestDefaultAccountManager_ListUsers(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users["normal_user1"] = NewRegularUser("normal_user1") - account.Users["normal_user2"] = NewRegularUser("normal_user2") + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) - err := store.SaveAccount(context.Background(), account) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") + account.Users["normal_user1"] = types.NewRegularUser("normal_user1") + account.Users["normal_user2"] = types.NewRegularUser("normal_user2") + + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -901,43 +978,43 @@ func TestDefaultAccountManager_ListUsers(t *testing.T) { func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) { testCases := []struct { name string - role UserRole + role types.UserRole limitedViewSettings bool expectedDashboardPermissions string }{ { name: "Regular user, no limited view settings", - role: UserRoleUser, + role: types.UserRoleUser, limitedViewSettings: false, expectedDashboardPermissions: "limited", }, { name: "Admin user, no limited view settings", - role: UserRoleAdmin, + role: types.UserRoleAdmin, limitedViewSettings: false, expectedDashboardPermissions: "full", }, { name: "Owner, no limited view settings", - role: UserRoleOwner, + role: types.UserRoleOwner, limitedViewSettings: false, expectedDashboardPermissions: "full", }, { name: "Regular user, limited view settings", - role: UserRoleUser, + role: types.UserRoleUser, limitedViewSettings: true, expectedDashboardPermissions: "blocked", }, { name: "Admin user, limited view settings", - role: UserRoleAdmin, + role: types.UserRoleAdmin, limitedViewSettings: true, expectedDashboardPermissions: "full", }, { name: "Owner, limited view settings", - role: UserRoleOwner, + role: types.UserRoleOwner, limitedViewSettings: true, expectedDashboardPermissions: "full", }, @@ -945,13 +1022,18 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - store := newStore(t) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, 
false, false, "", []string{}, UserIssuedAPI) + account.Users["normal_user1"] = types.NewUser("normal_user1", testCase.role, false, false, "", []string{}, types.UserIssuedAPI) account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings delete(account.Users, mockUserID) - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -976,13 +1058,17 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) { } func TestDefaultAccountManager_ExternalCache(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - externalUser := &User{ + externalUser := &types.User{ Id: "externalUser", - Role: UserRoleUser, - Issued: UserIssuedIntegration, + Role: types.UserRoleUser, + Issued: types.UserIssuedIntegration, IntegrationReference: integration_reference.IntegrationReference{ ID: 1, IntegrationType: "external", @@ -990,7 +1076,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) { } account.Users[externalUser.Id] = externalUser - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -1020,7 +1106,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) { infos, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockUserID) assert.NoError(t, err) assert.Equal(t, 2, len(infos)) - var user *UserInfo + var user *types.UserInfo for _, info := range infos { if info.ID == externalUser.Id { user = info @@ -1032,24 +1118,28 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) { func TestUser_IsAdmin(t *testing.T) { - user := NewAdminUser(mockUserID) + user := types.NewAdminUser(mockUserID) assert.True(t, user.HasAdminPower()) - user = NewRegularUser(mockUserID) + user = types.NewRegularUser(mockUserID) assert.False(t, user.HasAdminPower()) } func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockServiceUserID] = &User{ + account.Users[mockServiceUserID] = &types.User{ Id: mockServiceUserID, Role: "user", IsServiceUser: true, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -1068,17 +1158,20 @@ func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) { } func TestUser_GetUsersFromAccount_ForUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockServiceUserID] = &User{ + account.Users[mockServiceUserID] = &types.User{ Id: 
mockServiceUserID, Role: "user", IsServiceUser: true, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -1112,25 +1205,25 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { tt := []struct { name string initiatorID string - update *User + update *types.User expectedErr bool }{ { name: "Should_Fail_To_Update_Admin_Role", expectedErr: true, initiatorID: adminUserID, - update: &User{ + update: &types.User{ Id: adminUserID, - Role: UserRoleUser, + Role: types.UserRoleUser, Blocked: false, }, }, { name: "Should_Fail_When_Admin_Blocks_Themselves", expectedErr: true, initiatorID: adminUserID, - update: &User{ + update: &types.User{ Id: adminUserID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, Blocked: true, }, }, @@ -1138,9 +1231,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Fail_To_Update_Non_Existing_User", expectedErr: true, initiatorID: adminUserID, - update: &User{ + update: &types.User{ Id: userID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, Blocked: true, }, }, @@ -1148,9 +1241,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Fail_To_Update_When_Initiator_Is_Not_An_Admin", expectedErr: true, initiatorID: regularUserID, - update: &User{ + update: &types.User{ Id: adminUserID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, Blocked: true, }, }, @@ -1158,9 +1251,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Update_User", expectedErr: false, initiatorID: adminUserID, - update: &User{ + update: &types.User{ Id: regularUserID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, Blocked: true, }, }, @@ -1168,9 +1261,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Transfer_Owner_Role_To_User", expectedErr: false, initiatorID: ownerUserID, - update: &User{ + update: &types.User{ Id: adminUserID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, Blocked: false, }, }, @@ -1178,9 +1271,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Fail_To_Transfer_Owner_Role_To_Service_User", expectedErr: true, initiatorID: ownerUserID, - update: &User{ + update: &types.User{ Id: serviceUserID, - Role: UserRoleOwner, + Role: types.UserRoleOwner, Blocked: false, }, }, @@ -1188,9 +1281,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Fail_To_Update_Owner_User_Role_By_Admin", expectedErr: true, initiatorID: adminUserID, - update: &User{ + update: &types.User{ Id: ownerUserID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, Blocked: false, }, }, @@ -1198,9 +1291,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Fail_To_Update_Owner_User_Role_By_User", expectedErr: true, initiatorID: regularUserID, - update: &User{ + update: &types.User{ Id: ownerUserID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, Blocked: false, }, }, @@ -1208,9 +1301,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Fail_To_Update_Owner_User_Role_By_Service_User", expectedErr: true, initiatorID: serviceUserID, - update: &User{ + update: &types.User{ Id: ownerUserID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, Blocked: false, }, }, @@ -1218,9 +1311,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Fail_To_Update_Owner_Role_By_Admin", expectedErr: true, initiatorID: adminUserID, - update: &User{ + update: &types.User{ Id: regularUserID, - 
Role: UserRoleOwner, + Role: types.UserRoleOwner, Blocked: false, }, }, @@ -1228,9 +1321,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Fail_To_Block_Owner_Role_By_Admin", expectedErr: true, initiatorID: adminUserID, - update: &User{ + update: &types.User{ Id: ownerUserID, - Role: UserRoleOwner, + Role: types.UserRoleOwner, Blocked: true, }, }, @@ -1246,9 +1339,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { } // create other users - account.Users[regularUserID] = NewRegularUser(regularUserID) - account.Users[adminUserID] = NewAdminUser(adminUserID) - account.Users[serviceUserID] = &User{IsServiceUser: true, Id: serviceUserID, Role: UserRoleAdmin, ServiceUserName: "service"} + account.Users[regularUserID] = types.NewRegularUser(regularUserID) + account.Users[adminUserID] = types.NewAdminUser(adminUserID) + account.Users[serviceUserID] = &types.User{IsServiceUser: true, Id: serviceUserID, Role: types.UserRoleAdmin, ServiceUserName: "service"} err = manager.Store.SaveAccount(context.Background(), account) if err != nil { t.Fatal(err) @@ -1272,22 +1365,22 @@ func TestUserAccountPeersUpdate(t *testing.T) { // account groups propagation is enabled manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, }) require.NoError(t, err) - policy := &Policy{ + policy := &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, } @@ -1307,11 +1400,11 @@ func TestUserAccountPeersUpdate(t *testing.T) { close(done) }() - _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{ + _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &types.User{ Id: "regularUser1", AccountID: account.Id, - Role: UserRoleUser, - Issued: UserIssuedAPI, + Role: types.UserRoleUser, + Issued: types.UserIssuedAPI, }, true) require.NoError(t, err) @@ -1330,11 +1423,11 @@ func TestUserAccountPeersUpdate(t *testing.T) { close(done) }() - _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{ + _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &types.User{ Id: "regularUser1", AccountID: account.Id, - Role: UserRoleUser, - Issued: UserIssuedAPI, + Role: types.UserRoleUser, + Issued: types.UserIssuedAPI, }, false) require.NoError(t, err) @@ -1364,11 +1457,11 @@ func TestUserAccountPeersUpdate(t *testing.T) { }) // create a user and add new peer with the user - _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{ + _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &types.User{ Id: "regularUser2", AccountID: account.Id, - Role: UserRoleAdmin, - Issued: UserIssuedAPI, + Role: types.UserRoleAdmin, + Issued: types.UserIssuedAPI, }, true) require.NoError(t, err) @@ -1390,11 +1483,11 @@ func TestUserAccountPeersUpdate(t *testing.T) { close(done) }() - _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{ + _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &types.User{ Id: "regularUser2", AccountID: account.Id, - Role: UserRoleAdmin, - Issued: UserIssuedAPI, + 
Role: types.UserRoleAdmin, + Issued: types.UserIssuedAPI, }, false) require.NoError(t, err) diff --git a/management/server/users/manager.go b/management/server/users/manager.go new file mode 100644 index 000000000..718eb6190 --- /dev/null +++ b/management/server/users/manager.go @@ -0,0 +1,49 @@ +package users + +import ( + "context" + "errors" + + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" +) + +type Manager interface { + GetUser(ctx context.Context, userID string) (*types.User, error) +} + +type managerImpl struct { + store store.Store +} + +type managerMock struct { +} + +func NewManager(store store.Store) Manager { + return &managerImpl{ + store: store, + } +} + +func (m *managerImpl) GetUser(ctx context.Context, userID string) (*types.User, error) { + return m.store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) +} + +func NewManagerMock() Manager { + return &managerMock{} +} + +func (m *managerMock) GetUser(ctx context.Context, userID string) (*types.User, error) { + switch userID { + case "adminUser": + return &types.User{Id: userID, Role: types.UserRoleAdmin}, nil + case "regularUser": + return &types.User{Id: userID, Role: types.UserRoleUser}, nil + case "ownerUser": + return &types.User{Id: userID, Role: types.UserRoleOwner}, nil + case "billingUser": + return &types.User{Id: userID, Role: types.UserRoleBillingAdmin}, nil + default: + return nil, errors.New("user not found") + } +} diff --git a/management/server/util/util.go b/management/server/util/util.go new file mode 100644 index 000000000..d85b55f02 --- /dev/null +++ b/management/server/util/util.go @@ -0,0 +1,21 @@ +package util + +// Difference returns the elements in `a` that aren't in `b`. +func Difference(a, b []string) []string { + mb := make(map[string]struct{}, len(b)) + for _, x := range b { + mb[x] = struct{}{} + } + var diff []string + for _, x := range a { + if _, found := mb[x]; !found { + diff = append(diff, x) + } + } + return diff +} + +// ToPtr returns a pointer to the given value. +func ToPtr[T any](value T) *T { + return &value +} diff --git a/relay/client/client.go b/relay/client/client.go index 154c1787f..db5252f50 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -140,7 +140,7 @@ type Client struct { instanceURL *RelayAddr muInstanceURL sync.Mutex - onDisconnectListener func() + onDisconnectListener func(string) onConnectedListener func() listenerMutex sync.Mutex } @@ -233,7 +233,7 @@ func (c *Client) ServerInstanceURL() (string, error) { } // SetOnDisconnectListener sets a function that will be called when the connection to the relay server is closed. 
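Editor's note: the new management/server/util package above adds two small helpers used throughout this change: ToPtr, which turns a value into a pointer (needed now that fields such as LastLogin and a PAT's ExpirationDate are pointers), and Difference for string-slice set subtraction. A short usage sketch; the group-diff use case is illustrative, not necessarily how the PR applies it:

package main

import (
	"fmt"
	"time"

	"github.com/netbirdio/netbird/management/server/util"
)

func main() {
	// ToPtr replaces ad-hoc "tmp := v; &tmp" code for the new pointer fields.
	lastLogin := util.ToPtr(time.Now().UTC())
	fmt.Println(lastLogin.IsZero()) // false

	// Difference returns the elements of the first slice missing from the
	// second, e.g. to work out which auto-groups were added or removed.
	fmt.Println(util.Difference([]string{"group1", "group2"}, []string{"group2"})) // [group1]
}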
-func (c *Client) SetOnDisconnectListener(fn func()) { +func (c *Client) SetOnDisconnectListener(fn func(string)) { c.listenerMutex.Lock() defer c.listenerMutex.Unlock() c.onDisconnectListener = fn @@ -554,7 +554,7 @@ func (c *Client) notifyDisconnected() { if c.onDisconnectListener == nil { return } - go c.onDisconnectListener() + go c.onDisconnectListener(c.connectionURL) } func (c *Client) notifyConnected() { diff --git a/relay/client/client_test.go b/relay/client/client_test.go index ef28203e9..7ddfba4c6 100644 --- a/relay/client/client_test.go +++ b/relay/client/client_test.go @@ -551,7 +551,7 @@ func TestCloseByServer(t *testing.T) { } disconnected := make(chan struct{}) - relayClient.SetOnDisconnectListener(func() { + relayClient.SetOnDisconnectListener(func(_ string) { log.Infof("client disconnected") close(disconnected) }) diff --git a/relay/client/guard.go b/relay/client/guard.go index d6b6b0da5..b971363a8 100644 --- a/relay/client/guard.go +++ b/relay/client/guard.go @@ -4,65 +4,120 @@ import ( "context" "time" + "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" ) var ( - reconnectingTimeout = 5 * time.Second + reconnectingTimeout = 60 * time.Second ) // Guard manage the reconnection tries to the Relay server in case of disconnection event. type Guard struct { - ctx context.Context - relayClient *Client + // OnNewRelayClient is a channel that is used to notify the relay client about a new relay client instance. + OnNewRelayClient chan *Client + serverPicker *ServerPicker } // NewGuard creates a new guard for the relay client. -func NewGuard(context context.Context, relayClient *Client) *Guard { +func NewGuard(sp *ServerPicker) *Guard { g := &Guard{ - ctx: context, - relayClient: relayClient, + OnNewRelayClient: make(chan *Client, 1), + serverPicker: sp, } return g } -// OnDisconnected is called when the relay client is disconnected from the relay server. It will trigger the reconnection +// StartReconnectTrys is called when the relay client is disconnected from the relay server. +// It attempts to reconnect to the relay server. The function first tries a quick reconnect +// to the same server that was used before, if the server URL is still valid. If the quick +// reconnect fails, it starts a ticker to periodically attempt server picking until it +// succeeds or the context is done. +// +// Parameters: +// - ctx: The context to control the lifecycle of the reconnection attempts. +// - relayClient: The relay client instance that was disconnected. // todo prevent multiple reconnection instances. 
In the current usage it should not happen, but it is better to prevent -func (g *Guard) OnDisconnected() { - if g.quickReconnect() { +func (g *Guard) StartReconnectTrys(ctx context.Context, relayClient *Client) { + if relayClient == nil { + goto RETRY + } + if g.isServerURLStillValid(relayClient) && g.quickReconnect(ctx, relayClient) { return } - ticker := time.NewTicker(reconnectingTimeout) +RETRY: + ticker := exponentTicker(ctx) defer ticker.Stop() for { select { case <-ticker.C: - err := g.relayClient.Connect() - if err != nil { - log.Errorf("failed to reconnect to relay server: %s", err) + if err := g.retry(ctx); err != nil { + log.Errorf("failed to pick new Relay server: %s", err) continue } return - case <-g.ctx.Done(): + case <-ctx.Done(): return } } } -func (g *Guard) quickReconnect() bool { - ctx, cancel := context.WithTimeout(g.ctx, 1500*time.Millisecond) +func (g *Guard) retry(ctx context.Context) error { + log.Infof("try to pick up a new Relay server") + relayClient, err := g.serverPicker.PickServer(ctx) + if err != nil { + return err + } + + // prevent to work with a deprecated Relay client instance + g.drainRelayClientChan() + + g.OnNewRelayClient <- relayClient + return nil +} + +func (g *Guard) quickReconnect(parentCtx context.Context, rc *Client) bool { + ctx, cancel := context.WithTimeout(parentCtx, 1500*time.Millisecond) defer cancel() <-ctx.Done() - if g.ctx.Err() != nil { + if parentCtx.Err() != nil { return false } + log.Infof("try to reconnect to Relay server: %s", rc.connectionURL) - if err := g.relayClient.Connect(); err != nil { + if err := rc.Connect(); err != nil { log.Errorf("failed to reconnect to relay server: %s", err) return false } return true } + +func (g *Guard) drainRelayClientChan() { + select { + case <-g.OnNewRelayClient: + default: + } +} + +func (g *Guard) isServerURLStillValid(rc *Client) bool { + for _, url := range g.serverPicker.ServerURLs.Load().([]string) { + if url == rc.connectionURL { + return true + } + } + return false +} + +func exponentTicker(ctx context.Context) *backoff.Ticker { + bo := backoff.WithContext(&backoff.ExponentialBackOff{ + InitialInterval: 2 * time.Second, + Multiplier: 2, + MaxInterval: reconnectingTimeout, + Clock: backoff.SystemClock, + }, ctx) + + return backoff.NewTicker(bo) +} diff --git a/relay/client/manager.go b/relay/client/manager.go index b14a7701b..d847bb879 100644 --- a/relay/client/manager.go +++ b/relay/client/manager.go @@ -57,12 +57,15 @@ type ManagerService interface { // relay servers will be closed if there is no active connection. Periodically the manager will check if there is any // unused relay connection and close it. type Manager struct { - ctx context.Context - serverURLs []string - peerID string - tokenStore *relayAuth.TokenStore + ctx context.Context + peerID string + running bool + tokenStore *relayAuth.TokenStore + serverPicker *ServerPicker - relayClient *Client + relayClient *Client + // the guard logic can overwrite the relayClient variable, this mutex protect the usage of the variable + relayClientMu sync.Mutex reconnectGuard *Guard relayClients map[string]*RelayTrack @@ -76,48 +79,54 @@ type Manager struct { // NewManager creates a new manager instance. // The serverURL address can be empty. In this case, the manager will not serve. 
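Editor's note: guard.go now paces reconnect attempts with a backoff ticker from github.com/cenkalti/backoff/v4 instead of the old fixed 5-second time.Ticker. A standalone sketch of that ticker pattern; the intervals and the retry body are illustrative only:

package main

import (
	"context"
	"fmt"
	"time"

	"github.com/cenkalti/backoff/v4"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	// Start at 2s and double up to a 60s cap, mirroring the shape of
	// exponentTicker; NewExponentialBackOff fills in sane defaults first.
	bo := backoff.NewExponentialBackOff()
	bo.InitialInterval = 2 * time.Second
	bo.Multiplier = 2
	bo.MaxInterval = 60 * time.Second
	bo.MaxElapsedTime = 0 // keep trying until the context is done

	ticker := backoff.NewTicker(backoff.WithContext(bo, ctx))
	defer ticker.Stop()

	// The channel closes once the context is done, ending the loop.
	for range ticker.C {
		fmt.Println("attempting to pick a Relay server...")
		// On success the real code hands the new client to OnNewRelayClient
		// and returns; here we simply keep ticking until the context expires.
	}
	fmt.Println("context done, giving up")
}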
func NewManager(ctx context.Context, serverURLs []string, peerID string) *Manager { - return &Manager{ - ctx: ctx, - serverURLs: serverURLs, - peerID: peerID, - tokenStore: &relayAuth.TokenStore{}, + tokenStore := &relayAuth.TokenStore{} + + m := &Manager{ + ctx: ctx, + peerID: peerID, + tokenStore: tokenStore, + serverPicker: &ServerPicker{ + TokenStore: tokenStore, + PeerID: peerID, + }, relayClients: make(map[string]*RelayTrack), onDisconnectedListeners: make(map[string]*list.List), } + m.serverPicker.ServerURLs.Store(serverURLs) + m.reconnectGuard = NewGuard(m.serverPicker) + return m } -// Serve starts the manager. It will establish a connection to the relay server and start the relay cleanup loop for -// the unused relay connections. The manager will automatically reconnect to the relay server in case of disconnection. +// Serve starts the manager, attempting to establish a connection with the relay server. +// If the connection fails, it will keep trying to reconnect in the background. +// Additionally, it starts a cleanup loop to remove unused relay connections. +// The manager will automatically reconnect to the relay server in case of disconnection. func (m *Manager) Serve() error { - if m.relayClient != nil { + if m.running { return fmt.Errorf("manager already serving") } - log.Debugf("starting relay client manager with %v relay servers", m.serverURLs) + m.running = true + log.Debugf("starting relay client manager with %v relay servers", m.serverPicker.ServerURLs.Load()) - sp := ServerPicker{ - TokenStore: m.tokenStore, - PeerID: m.peerID, - } - - client, err := sp.PickServer(m.ctx, m.serverURLs) + client, err := m.serverPicker.PickServer(m.ctx) if err != nil { - return err + go m.reconnectGuard.StartReconnectTrys(m.ctx, nil) + } else { + m.storeClient(client) } - m.relayClient = client - m.reconnectGuard = NewGuard(m.ctx, m.relayClient) - m.relayClient.SetOnConnectedListener(m.onServerConnected) - m.relayClient.SetOnDisconnectListener(func() { - m.onServerDisconnected(client.connectionURL) - }) - m.startCleanupLoop() - return nil + go m.listenGuardEvent(m.ctx) + go m.startCleanupLoop() + return err } // OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be // established via the relay server. If the peer is on a different relay server, the manager will establish a new // connection to the relay server. It returns back with a net.Conn what represent the remote peer connection. func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) { + m.relayClientMu.Lock() + defer m.relayClientMu.Unlock() + if m.relayClient == nil { return nil, ErrRelayClientNotConnected } @@ -146,6 +155,9 @@ func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) { // Ready returns true if the home Relay client is connected to the relay server. func (m *Manager) Ready() bool { + m.relayClientMu.Lock() + defer m.relayClientMu.Unlock() + if m.relayClient == nil { return false } @@ -159,6 +171,13 @@ func (m *Manager) SetOnReconnectedListener(f func()) { // AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection // closed. 
func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error { + m.relayClientMu.Lock() + defer m.relayClientMu.Unlock() + + if m.relayClient == nil { + return ErrRelayClientNotConnected + } + foreign, err := m.isForeignServer(serverAddress) if err != nil { return err @@ -177,6 +196,9 @@ func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServ // RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is // lost. This address will be sent to the target peer to choose the common relay server for the communication. func (m *Manager) RelayInstanceAddress() (string, error) { + m.relayClientMu.Lock() + defer m.relayClientMu.Unlock() + if m.relayClient == nil { return "", ErrRelayClientNotConnected } @@ -185,13 +207,18 @@ func (m *Manager) RelayInstanceAddress() (string, error) { // ServerURLs returns the addresses of the relay servers. func (m *Manager) ServerURLs() []string { - return m.serverURLs + return m.serverPicker.ServerURLs.Load().([]string) } // HasRelayAddress returns true if the manager is serving. With this method can check if the peer can communicate with // Relay service. func (m *Manager) HasRelayAddress() bool { - return len(m.serverURLs) > 0 + return len(m.serverPicker.ServerURLs.Load().([]string)) > 0 +} + +func (m *Manager) UpdateServerURLs(serverURLs []string) { + log.Infof("update relay server URLs: %v", serverURLs) + m.serverPicker.ServerURLs.Store(serverURLs) } // UpdateToken updates the token in the token store. @@ -245,9 +272,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) { return nil, err } // if connection closed then delete the relay client from the list - relayClient.SetOnDisconnectListener(func() { - m.onServerDisconnected(serverAddress) - }) + relayClient.SetOnDisconnectListener(m.onServerDisconnected) rt.relayClient = relayClient rt.Unlock() @@ -265,14 +290,37 @@ func (m *Manager) onServerConnected() { go m.onReconnectedListenerFn() } +// onServerDisconnected start to reconnection for home server only func (m *Manager) onServerDisconnected(serverAddress string) { + m.relayClientMu.Lock() if serverAddress == m.relayClient.connectionURL { - go m.reconnectGuard.OnDisconnected() + go m.reconnectGuard.StartReconnectTrys(m.ctx, m.relayClient) } + m.relayClientMu.Unlock() m.notifyOnDisconnectListeners(serverAddress) } +func (m *Manager) listenGuardEvent(ctx context.Context) { + for { + select { + case rc := <-m.reconnectGuard.OnNewRelayClient: + m.storeClient(rc) + case <-ctx.Done(): + return + } + } +} + +func (m *Manager) storeClient(client *Client) { + m.relayClientMu.Lock() + defer m.relayClientMu.Unlock() + + m.relayClient = client + m.relayClient.SetOnConnectedListener(m.onServerConnected) + m.relayClient.SetOnDisconnectListener(m.onServerDisconnected) +} + func (m *Manager) isForeignServer(address string) (bool, error) { rAddr, err := m.relayClient.ServerInstanceURL() if err != nil { @@ -282,22 +330,16 @@ func (m *Manager) isForeignServer(address string) (bool, error) { } func (m *Manager) startCleanupLoop() { - if m.ctx.Err() != nil { - return - } - ticker := time.NewTicker(relayCleanupInterval) - go func() { - defer ticker.Stop() - for { - select { - case <-m.ctx.Done(): - return - case <-ticker.C: - m.cleanUpUnusedRelays() - } + defer ticker.Stop() + for { + select { + case <-m.ctx.Done(): + return + case <-ticker.C: + m.cleanUpUnusedRelays() } - }() + } } func (m *Manager) cleanUpUnusedRelays() { 
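Editor's note: two concurrency details in the reworked Manager are worth calling out. First, the active relay client can now be replaced at runtime by the guard, hence relayClientMu around every access and the storeClient/listenGuardEvent pair. Second, the server URL list is shared with the picker through an atomic.Value so UpdateServerURLs can swap it without additional locking. A condensed standalone sketch of both patterns, using simplified stand-in types rather than the real Manager:

package main

import (
	"context"
	"fmt"
	"sync"
	"sync/atomic"
)

type client struct{ url string }

type manager struct {
	mu         sync.Mutex
	active     *client
	serverURLs atomic.Value // holds []string
	onNew      chan *client // fed by the reconnect guard
}

// storeClient swaps the active client under the mutex, as Manager.storeClient does.
func (m *manager) storeClient(c *client) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.active = c
}

// listenGuardEvent keeps applying replacement clients until the context ends.
func (m *manager) listenGuardEvent(ctx context.Context) {
	for {
		select {
		case c := <-m.onNew:
			m.storeClient(c)
		case <-ctx.Done():
			return
		}
	}
}

func main() {
	m := &manager{onNew: make(chan *client)}
	m.serverURLs.Store([]string{"rel://one", "rel://two"})

	ctx, cancel := context.WithCancel(context.Background())
	var wg sync.WaitGroup
	wg.Add(1)
	go func() { defer wg.Done(); m.listenGuardEvent(ctx) }()

	m.onNew <- &client{url: "rel://two"}                // guard picked a new server
	m.serverURLs.Store([]string{"rel://three"})         // UpdateServerURLs equivalent

	cancel()
	wg.Wait()

	m.mu.Lock()
	fmt.Println(m.active.url, m.serverURLs.Load().([]string)) // rel://two [rel://three]
	m.mu.Unlock()
}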
diff --git a/relay/client/picker.go b/relay/client/picker.go index 13b0547aa..eb5062dbb 100644 --- a/relay/client/picker.go +++ b/relay/client/picker.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "sync/atomic" "time" log "github.com/sirupsen/logrus" @@ -12,10 +13,13 @@ import ( ) const ( - connectionTimeout = 30 * time.Second maxConcurrentServers = 7 ) +var ( + connectionTimeout = 30 * time.Second +) + type connResult struct { RelayClient *Client Url string @@ -24,20 +28,22 @@ type connResult struct { type ServerPicker struct { TokenStore *auth.TokenStore + ServerURLs atomic.Value PeerID string } -func (sp *ServerPicker) PickServer(parentCtx context.Context, urls []string) (*Client, error) { +func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) { ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout) defer cancel() - totalServers := len(urls) + totalServers := len(sp.ServerURLs.Load().([]string)) connResultChan := make(chan connResult, totalServers) successChan := make(chan connResult, 1) concurrentLimiter := make(chan struct{}, maxConcurrentServers) - for _, url := range urls { + log.Debugf("pick server from list: %v", sp.ServerURLs.Load().([]string)) + for _, url := range sp.ServerURLs.Load().([]string) { // todo check if we have a successful connection so we do not need to connect to other servers concurrentLimiter <- struct{}{} go func(url string) { @@ -78,7 +84,7 @@ func (sp *ServerPicker) processConnResults(resultChan chan connResult, successCh for numOfResults := 0; numOfResults < cap(resultChan); numOfResults++ { cr := <-resultChan if cr.Err != nil { - log.Debugf("failed to connect to Relay server: %s: %v", cr.Url, cr.Err) + log.Tracef("failed to connect to Relay server: %s: %v", cr.Url, cr.Err) continue } log.Infof("connected to Relay server: %s", cr.Url) diff --git a/relay/client/picker_test.go b/relay/client/picker_test.go index 4800e05ba..28167c5ce 100644 --- a/relay/client/picker_test.go +++ b/relay/client/picker_test.go @@ -4,19 +4,23 @@ import ( "context" "errors" "testing" + "time" ) func TestServerPicker_UnavailableServers(t *testing.T) { + connectionTimeout = 5 * time.Second + sp := ServerPicker{ TokenStore: nil, PeerID: "test", } + sp.ServerURLs.Store([]string{"rel://dummy1", "rel://dummy2"}) ctx, cancel := context.WithTimeout(context.Background(), connectionTimeout+1) defer cancel() go func() { - _, err := sp.PickServer(ctx, []string{"rel://dummy1", "rel://dummy2"}) + _, err := sp.PickServer(ctx) if err == nil { t.Error(err) } diff --git a/relay/server/listener/ws/listener.go b/relay/server/listener/ws/listener.go index 1ad57d27a..5c62c0826 100644 --- a/relay/server/listener/ws/listener.go +++ b/relay/server/listener/ws/listener.go @@ -96,5 +96,5 @@ func remoteAddr(r *http.Request) string { if r.Header.Get("X-Real-Ip") == "" || r.Header.Get("X-Real-Port") == "" { return r.RemoteAddr } - return fmt.Sprintf("%s:%s", r.Header.Get("X-Real-Ip"), r.Header.Get("X-Real-Port")) + return net.JoinHostPort(r.Header.Get("X-Real-Ip"), r.Header.Get("X-Real-Port")) } diff --git a/relay/server/peer.go b/relay/server/peer.go index c909c35d5..f65fb786a 100644 --- a/relay/server/peer.go +++ b/relay/server/peer.go @@ -16,6 +16,8 @@ import ( const ( bufferSize = 8820 + + errCloseConn = "failed to close connection to peer: %s" ) // Peer represents a peer connection @@ -46,6 +48,12 @@ func NewPeer(metrics *metrics.Metrics, id []byte, conn net.Conn, store *Store) * // It manages the protocol (healthcheck, transport, close). 
Read the message and determine the message type and handle // the message accordingly. func (p *Peer) Work() { + defer func() { + if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + p.log.Errorf(errCloseConn, err) + } + }() + ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -97,7 +105,7 @@ func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc * case messages.MsgTypeClose: p.log.Infof("peer exited gracefully") if err := p.conn.Close(); err != nil { - log.Errorf("failed to close connection to peer: %s", err) + log.Errorf(errCloseConn, err) } default: p.log.Warnf("received unexpected message type: %s", msgType) @@ -121,9 +129,8 @@ func (p *Peer) CloseGracefully(ctx context.Context) { p.log.Errorf("failed to send close message to peer: %s", p.String()) } - err = p.conn.Close() - if err != nil { - p.log.Errorf("failed to close connection to peer: %s", err) + if err := p.conn.Close(); err != nil { + p.log.Errorf(errCloseConn, err) } } @@ -132,7 +139,7 @@ func (p *Peer) Close() { defer p.connMu.Unlock() if err := p.conn.Close(); err != nil { - p.log.Errorf("failed to close connection to peer: %s", err) + p.log.Errorf(errCloseConn, err) } } diff --git a/release_files/install.sh b/release_files/install.sh index b0fec2733..bb917c39a 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -239,7 +239,12 @@ install_netbird() { dnf) add_rpm_repo ${SUDO} dnf -y install dnf-plugin-config-manager - ${SUDO} dnf config-manager --add-repo /etc/yum.repos.d/netbird.repo + if [[ "$(dnf --version | head -n1 | cut -d. -f1)" > "4" ]]; + then + ${SUDO} dnf config-manager addrepo --from-repofile=/etc/yum.repos.d/netbird.repo + else + ${SUDO} dnf config-manager --add-repo /etc/yum.repos.d/netbird.repo + fi ${SUDO} dnf -y install netbird if ! 
$SKIP_UI_APP; then diff --git a/route/route.go b/route/route.go index e23801e6e..ad2aaba89 100644 --- a/route/route.go +++ b/route/route.go @@ -4,6 +4,7 @@ import ( "fmt" "net/netip" "slices" + "strings" log "github.com/sirupsen/logrus" @@ -88,18 +89,19 @@ type Route struct { // AccountID is a reference to Account that this object belongs AccountID string `gorm:"index"` // Network and Domains are mutually exclusive - Network netip.Prefix `gorm:"serializer:json"` - Domains domain.List `gorm:"serializer:json"` - KeepRoute bool - NetID NetID - Description string - Peer string - PeerGroups []string `gorm:"serializer:json"` - NetworkType NetworkType - Masquerade bool - Metric int - Enabled bool - Groups []string `gorm:"serializer:json"` + Network netip.Prefix `gorm:"serializer:json"` + Domains domain.List `gorm:"serializer:json"` + KeepRoute bool + NetID NetID + Description string + Peer string + PeerID string `gorm:"-"` + PeerGroups []string `gorm:"serializer:json"` + NetworkType NetworkType + Masquerade bool + Metric int + Enabled bool + Groups []string `gorm:"serializer:json"` AccessControlGroups []string `gorm:"serializer:json"` } @@ -111,19 +113,20 @@ func (r *Route) EventMeta() map[string]any { // Copy copies a route object func (r *Route) Copy() *Route { route := &Route{ - ID: r.ID, - Description: r.Description, - NetID: r.NetID, - Network: r.Network, - Domains: slices.Clone(r.Domains), - KeepRoute: r.KeepRoute, - NetworkType: r.NetworkType, - Peer: r.Peer, - PeerGroups: slices.Clone(r.PeerGroups), - Metric: r.Metric, - Masquerade: r.Masquerade, - Enabled: r.Enabled, - Groups: slices.Clone(r.Groups), + ID: r.ID, + Description: r.Description, + NetID: r.NetID, + Network: r.Network, + Domains: slices.Clone(r.Domains), + KeepRoute: r.KeepRoute, + NetworkType: r.NetworkType, + Peer: r.Peer, + PeerID: r.PeerID, + PeerGroups: slices.Clone(r.PeerGroups), + Metric: r.Metric, + Masquerade: r.Masquerade, + Enabled: r.Enabled, + Groups: slices.Clone(r.Groups), AccessControlGroups: slices.Clone(r.AccessControlGroups), } return route @@ -145,11 +148,12 @@ func (r *Route) IsEqual(other *Route) bool { other.KeepRoute == r.KeepRoute && other.NetworkType == r.NetworkType && other.Peer == r.Peer && + other.PeerID == r.PeerID && other.Metric == r.Metric && other.Masquerade == r.Masquerade && other.Enabled == r.Enabled && slices.Equal(r.Groups, other.Groups) && - slices.Equal(r.PeerGroups, other.PeerGroups)&& + slices.Equal(r.PeerGroups, other.PeerGroups) && slices.Equal(r.AccessControlGroups, other.AccessControlGroups) } @@ -170,6 +174,11 @@ func (r *Route) GetHAUniqueID() HAUniqueID { return HAUniqueID(fmt.Sprintf("%s%s%s", r.NetID, haSeparator, r.Network.String())) } +// GetResourceID returns the Networks Resource ID from a route ID +func (r *Route) GetResourceID() string { + return strings.Split(string(r.ID), ":")[0] +} + // ParseNetwork Parses a network prefix string and returns a netip.Prefix object and if is invalid, IPv4 or IPv6 func ParseNetwork(networkString string) (NetworkType, netip.Prefix, error) { prefix, err := netip.ParsePrefix(networkString) diff --git a/util/file.go b/util/file.go index ecaecd222..f7de7ede2 100644 --- a/util/file.go +++ b/util/file.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "os" @@ -14,8 +15,21 @@ import ( log "github.com/sirupsen/logrus" ) +func WriteBytesWithRestrictedPermission(ctx context.Context, file string, bs []byte) error { + configDir, configFileName, err := prepareConfigFileDir(file) + if err != nil { + return 
fmt.Errorf("prepare config file dir: %w", err) + } + + if err = EnforcePermission(file); err != nil { + return fmt.Errorf("enforce permission: %w", err) + } + + return writeBytes(ctx, file, err, configDir, configFileName, bs) +} + // WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory -func WriteJsonWithRestrictedPermission(file string, obj interface{}) error { +func WriteJsonWithRestrictedPermission(ctx context.Context, file string, obj interface{}) error { configDir, configFileName, err := prepareConfigFileDir(file) if err != nil { return err @@ -26,18 +40,18 @@ func WriteJsonWithRestrictedPermission(file string, obj interface{}) error { return err } - return writeJson(file, obj, configDir, configFileName) + return writeJson(ctx, file, obj, configDir, configFileName) } // WriteJson writes JSON config object to a file creating parent directories if required // The output JSON is pretty-formatted -func WriteJson(file string, obj interface{}) error { +func WriteJson(ctx context.Context, file string, obj interface{}) error { configDir, configFileName, err := prepareConfigFileDir(file) if err != nil { return err } - return writeJson(file, obj, configDir, configFileName) + return writeJson(ctx, file, obj, configDir, configFileName) } // DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file @@ -79,24 +93,47 @@ func DirectWriteJson(ctx context.Context, file string, obj interface{}) error { return nil } -func writeJson(file string, obj interface{}, configDir string, configFileName string) error { +func writeJson(ctx context.Context, file string, obj interface{}, configDir string, configFileName string) error { + // Check context before expensive operations + if ctx.Err() != nil { + return fmt.Errorf("write json start: %w", ctx.Err()) + } // make it pretty bs, err := json.MarshalIndent(obj, "", " ") if err != nil { - return err + return fmt.Errorf("marshal: %w", err) + } + + return writeBytes(ctx, file, err, configDir, configFileName, bs) +} + +func writeBytes(ctx context.Context, file string, err error, configDir string, configFileName string, bs []byte) error { + if ctx.Err() != nil { + return fmt.Errorf("write bytes start: %w", ctx.Err()) } tempFile, err := os.CreateTemp(configDir, ".*"+configFileName) if err != nil { - return err + return fmt.Errorf("create temp: %w", err) } tempFileName := tempFile.Name() - // closing file ops as windows doesn't allow to move it - err = tempFile.Close() + + if deadline, ok := ctx.Deadline(); ok { + if err := tempFile.SetDeadline(deadline); err != nil && !errors.Is(err, os.ErrNoDeadline) { + log.Warnf("failed to set deadline: %v", err) + } + } + + _, err = tempFile.Write(bs) if err != nil { - return err + _ = tempFile.Close() + return fmt.Errorf("write: %w", err) + } + + if err = tempFile.Close(); err != nil { + return fmt.Errorf("close %s: %w", tempFileName, err) } defer func() { @@ -106,14 +143,13 @@ func writeJson(file string, obj interface{}, configDir string, configFileName st } }() - err = os.WriteFile(tempFileName, bs, 0600) - if err != nil { - return err + // Check context again + if ctx.Err() != nil { + return fmt.Errorf("after temp file: %w", ctx.Err()) } - err = os.Rename(tempFileName, file) - if err != nil { - return err + if err = os.Rename(tempFileName, file); err != nil { + return fmt.Errorf("move %s to %s: %w", tempFileName, file, err) } return nil diff --git a/util/file_test.go b/util/file_test.go index 
566d8eda6..f8c9dfabb 100644 --- a/util/file_test.go +++ b/util/file_test.go @@ -1,6 +1,7 @@ package util import ( + "context" "crypto/md5" "encoding/hex" "io" @@ -39,7 +40,7 @@ func TestConfigJSON(t *testing.T) { t.Run(tt.name, func(t *testing.T) { tmpDir := t.TempDir() - err := WriteJson(tmpDir+"/testconfig.json", tt.config) + err := WriteJson(context.Background(), tmpDir+"/testconfig.json", tt.config) require.NoError(t, err) read, err := ReadJson(tmpDir+"/testconfig.json", &TestConfig{}) @@ -73,7 +74,7 @@ func TestCopyFileContents(t *testing.T) { src := tmpDir + "/copytest_src" dst := tmpDir + "/copytest_dst" - err := WriteJson(src, tt.srcContent) + err := WriteJson(context.Background(), src, tt.srcContent) require.NoError(t, err) err = CopyFileContents(src, dst) @@ -127,7 +128,7 @@ func TestHandleConfigFileWithoutFullPath(t *testing.T) { _ = os.Remove(cfgFile) }() - err := WriteJson(cfgFile, tt.config) + err := WriteJson(context.Background(), cfgFile, tt.config) require.NoError(t, err) read, err := ReadJson(cfgFile, &TestConfig{}) diff --git a/util/grpc/dialer.go b/util/grpc/dialer.go index 57ab8fd55..4fbffe342 100644 --- a/util/grpc/dialer.go +++ b/util/grpc/dialer.go @@ -3,6 +3,9 @@ package grpc import ( "context" "crypto/tls" + "fmt" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "net" "os/user" "runtime" @@ -23,20 +26,22 @@ func WithCustomDialer() grpc.DialOption { if runtime.GOOS == "linux" { currentUser, err := user.Current() if err != nil { - log.Fatalf("failed to get current user: %v", err) + return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err) } // the custom dialer requires root permissions which are not required for use cases run as non-root if currentUser.Uid != "0" { + log.Debug("Not running as root, using standard dialer") dialer := &net.Dialer{} return dialer.DialContext(ctx, "tcp", addr) } } + log.Debug("Using nbnet.NewDialer()") conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) if err != nil { log.Errorf("Failed to dial: %s", err) - return nil, err + return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err) } return conn, nil }) diff --git a/util/net/conn.go b/util/net/conn.go new file mode 100644 index 000000000..26693f841 --- /dev/null +++ b/util/net/conn.go @@ -0,0 +1,31 @@ +//go:build !ios + +package net + +import ( + "net" + + log "github.com/sirupsen/logrus" +) + +// Conn wraps a net.Conn to override the Close method +type Conn struct { + net.Conn + ID ConnectionID +} + +// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection +func (c *Conn) Close() error { + err := c.Conn.Close() + + dialerCloseHooksMutex.RLock() + defer dialerCloseHooksMutex.RUnlock() + + for _, hook := range dialerCloseHooks { + if err := hook(c.ID, &c.Conn); err != nil { + log.Errorf("Error executing dialer close hook: %v", err) + } + } + + return err +} diff --git a/util/net/dial.go b/util/net/dial.go new file mode 100644 index 000000000..595311492 --- /dev/null +++ b/util/net/dial.go @@ -0,0 +1,58 @@ +//go:build !ios + +package net + +import ( + "fmt" + "net" + + log "github.com/sirupsen/logrus" +) + +func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { + if CustomRoutingDisabled() { + return net.DialUDP(network, laddr, raddr) + } + + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) + } + + udpConn, ok 
:= conn.(*Conn).Conn.(*net.UDPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn) + } + + return udpConn, nil +} + +func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { + if CustomRoutingDisabled() { + return net.DialTCP(network, laddr, raddr) + } + + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) + } + + tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn) + } + + return tcpConn, nil +} diff --git a/util/net/dialer_ios.go b/util/net/dial_ios.go similarity index 100% rename from util/net/dialer_ios.go rename to util/net/dial_ios.go diff --git a/util/net/dialer_android.go b/util/net/dialer_android.go deleted file mode 100644 index 4cbded536..000000000 --- a/util/net/dialer_android.go +++ /dev/null @@ -1,25 +0,0 @@ -package net - -import ( - "syscall" - - log "github.com/sirupsen/logrus" -) - -func (d *Dialer) init() { - d.Dialer.Control = func(_, _ string, c syscall.RawConn) error { - err := c.Control(func(fd uintptr) { - androidProtectSocketLock.Lock() - f := androidProtectSocket - androidProtectSocketLock.Unlock() - if f == nil { - return - } - ok := f(int32(fd)) - if !ok { - log.Errorf("failed to protect socket: %d", fd) - } - }) - return err - } -} diff --git a/util/net/dialer_nonios.go b/util/net/dialer_dial.go similarity index 62% rename from util/net/dialer_nonios.go rename to util/net/dialer_dial.go index 4032a75c0..1659b6220 100644 --- a/util/net/dialer_nonios.go +++ b/util/net/dialer_dial.go @@ -69,7 +69,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net. 
conn, err := d.Dialer.DialContext(ctx, network, address) if err != nil { - return nil, fmt.Errorf("dial: %w", err) + return nil, fmt.Errorf("d.Dialer.DialContext: %w", err) } // Wrap the connection in Conn to handle Close with hooks @@ -81,28 +81,6 @@ func (d *Dialer) Dial(network, address string) (net.Conn, error) { return d.DialContext(context.Background(), network, address) } -// Conn wraps a net.Conn to override the Close method -type Conn struct { - net.Conn - ID ConnectionID -} - -// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection -func (c *Conn) Close() error { - err := c.Conn.Close() - - dialerCloseHooksMutex.RLock() - defer dialerCloseHooksMutex.RUnlock() - - for _, hook := range dialerCloseHooks { - if err := hook(c.ID, &c.Conn); err != nil { - log.Errorf("Error executing dialer close hook: %v", err) - } - } - - return err -} - func callDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error { host, _, err := net.SplitHostPort(address) if err != nil { @@ -127,51 +105,3 @@ func callDialerHooks(ctx context.Context, connID ConnectionID, address string, r return result.ErrorOrNil() } - -func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { - if CustomRoutingDisabled() { - return net.DialUDP(network, laddr, raddr) - } - - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.Dial(network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) - } - - udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn) - } - - return udpConn, nil -} - -func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { - if CustomRoutingDisabled() { - return net.DialTCP(network, laddr, raddr) - } - - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.Dial(network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) - } - - tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn) - } - - return tcpConn, nil -} diff --git a/util/net/dialer_init_android.go b/util/net/dialer_init_android.go new file mode 100644 index 000000000..63b903348 --- /dev/null +++ b/util/net/dialer_init_android.go @@ -0,0 +1,5 @@ +package net + +func (d *Dialer) init() { + d.Dialer.Control = ControlProtectSocket +} diff --git a/util/net/dialer_linux.go b/util/net/dialer_init_linux.go similarity index 88% rename from util/net/dialer_linux.go rename to util/net/dialer_init_linux.go index aed5c59a3..d801e6080 100644 --- a/util/net/dialer_linux.go +++ b/util/net/dialer_init_linux.go @@ -7,6 +7,6 @@ import "syscall" // init configures the net.Dialer Control function to set the fwmark on the socket func (d *Dialer) init() { d.Dialer.Control = func(_, _ string, c syscall.RawConn) error { - return SetRawSocketMark(c) + return setRawSocketMark(c) } } diff --git a/util/net/dialer_nonlinux.go b/util/net/dialer_init_nonlinux.go similarity index 58% rename from util/net/dialer_nonlinux.go rename to util/net/dialer_init_nonlinux.go index c838441bd..8c57ebbaa 100644 --- a/util/net/dialer_nonlinux.go +++ 
b/util/net/dialer_init_nonlinux.go @@ -3,4 +3,5 @@ package net func (d *Dialer) init() { + // implemented on Linux and Android only } diff --git a/util/net/env.go b/util/net/env.go new file mode 100644 index 000000000..099da39b7 --- /dev/null +++ b/util/net/env.go @@ -0,0 +1,29 @@ +package net + +import ( + "os" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/netstack" +) + +const ( + envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING" + envSkipSocketMark = "NB_SKIP_SOCKET_MARK" +) + +func CustomRoutingDisabled() bool { + if netstack.IsEnabled() { + return true + } + return os.Getenv(envDisableCustomRouting) == "true" +} + +func SkipSocketMark() bool { + if skipSocketMark := os.Getenv(envSkipSocketMark); skipSocketMark == "true" { + log.Infof("%s is set to true, skipping SO_MARK", envSkipSocketMark) + return true + } + return false +} diff --git a/util/net/listen.go b/util/net/listen.go new file mode 100644 index 000000000..3ae8a9435 --- /dev/null +++ b/util/net/listen.go @@ -0,0 +1,37 @@ +//go:build !ios + +package net + +import ( + "context" + "fmt" + "net" + "sync" + + "github.com/pion/transport/v3" + log "github.com/sirupsen/logrus" +) + +// ListenUDP listens on the network address and returns a transport.UDPConn +// which includes support for write and close hooks. +func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) { + if CustomRoutingDisabled() { + return net.ListenUDP(network, laddr) + } + + conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) + if err != nil { + return nil, fmt.Errorf("listen UDP: %w", err) + } + + packetConn := conn.(*PacketConn) + udpConn, ok := packetConn.PacketConn.(*net.UDPConn) + if !ok { + if err := packetConn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn) + } + + return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil +} diff --git a/util/net/listener_ios.go b/util/net/listen_ios.go similarity index 100% rename from util/net/listener_ios.go rename to util/net/listen_ios.go diff --git a/util/net/listener_android.go b/util/net/listener_android.go deleted file mode 100644 index d4167ad53..000000000 --- a/util/net/listener_android.go +++ /dev/null @@ -1,26 +0,0 @@ -package net - -import ( - "syscall" - - log "github.com/sirupsen/logrus" -) - -// init configures the net.ListenerConfig Control function to set the fwmark on the socket -func (l *ListenerConfig) init() { - l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error { - err := c.Control(func(fd uintptr) { - androidProtectSocketLock.Lock() - f := androidProtectSocket - androidProtectSocketLock.Unlock() - if f == nil { - return - } - ok := f(int32(fd)) - if !ok { - log.Errorf("failed to protect listener socket: %d", fd) - } - }) - return err - } -} diff --git a/util/net/listener_init_android.go b/util/net/listener_init_android.go new file mode 100644 index 000000000..f7bfa1dab --- /dev/null +++ b/util/net/listener_init_android.go @@ -0,0 +1,6 @@ +package net + +// init configures the net.ListenerConfig Control function to set the fwmark on the socket +func (l *ListenerConfig) init() { + l.ListenConfig.Control = ControlProtectSocket +} diff --git a/util/net/listener_linux.go b/util/net/listener_init_linux.go similarity index 89% rename from util/net/listener_linux.go rename to util/net/listener_init_linux.go index 8d332160a..e32d5d894 100644 --- 
a/util/net/listener_linux.go +++ b/util/net/listener_init_linux.go @@ -9,6 +9,6 @@ import ( // init configures the net.ListenerConfig Control function to set the fwmark on the socket func (l *ListenerConfig) init() { l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error { - return SetRawSocketMark(c) + return setRawSocketMark(c) } } diff --git a/util/net/listener_nonlinux.go b/util/net/listener_init_nonlinux.go similarity index 61% rename from util/net/listener_nonlinux.go rename to util/net/listener_init_nonlinux.go index 14a6be49d..80f6f7f1a 100644 --- a/util/net/listener_nonlinux.go +++ b/util/net/listener_init_nonlinux.go @@ -3,4 +3,5 @@ package net func (l *ListenerConfig) init() { + // implemented on Linux and Android only } diff --git a/util/net/listener_nonios.go b/util/net/listener_listen.go similarity index 84% rename from util/net/listener_nonios.go rename to util/net/listener_listen.go index ae4be3494..efffba40e 100644 --- a/util/net/listener_nonios.go +++ b/util/net/listener_listen.go @@ -8,7 +8,6 @@ import ( "net" "sync" - "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" ) @@ -146,27 +145,3 @@ func closeConn(id ConnectionID, conn net.PacketConn) error { return err } - -// ListenUDP listens on the network address and returns a transport.UDPConn -// which includes support for write and close hooks. -func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) { - if CustomRoutingDisabled() { - return net.ListenUDP(network, laddr) - } - - conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) - if err != nil { - return nil, fmt.Errorf("listen UDP: %w", err) - } - - packetConn := conn.(*PacketConn) - udpConn, ok := packetConn.PacketConn.(*net.UDPConn) - if !ok { - if err := packetConn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn) - } - - return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil -} diff --git a/util/net/net.go b/util/net/net.go index 5448eb85a..403aa87e7 100644 --- a/util/net/net.go +++ b/util/net/net.go @@ -2,9 +2,6 @@ package net import ( "net" - "os" - - "github.com/netbirdio/netbird/client/iface/netstack" "github.com/google/uuid" ) @@ -16,8 +13,6 @@ const ( PreroutingFwmarkRedirected = 0x1BD01 PreroutingFwmarkMasquerade = 0x1BD11 PreroutingFwmarkMasqueradeReturn = 0x1BD12 - - envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING" ) // ConnectionID provides a globally unique identifier for network connections. 
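Side note on the util/net reorganisation above: the NB_DISABLE_CUSTOM_ROUTING and NB_SKIP_SOCKET_MARK checks are consolidated into util/net/env.go. A simplified, standalone sketch of those helpers follows (the netstack check from the real file is omitted here for brevity):

```go
package main

import (
	"fmt"
	"os"
)

const (
	envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
	envSkipSocketMark       = "NB_SKIP_SOCKET_MARK"
)

// customRoutingDisabled reports whether the user opted out of custom routing.
func customRoutingDisabled() bool {
	return os.Getenv(envDisableCustomRouting) == "true"
}

// skipSocketMark reports whether SO_MARK should be skipped on new sockets.
func skipSocketMark() bool {
	return os.Getenv(envSkipSocketMark) == "true"
}

func main() {
	fmt.Println("custom routing disabled:", customRoutingDisabled())
	fmt.Println("skip socket mark:", skipSocketMark())
}
```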
@@ -31,10 +26,3 @@ type RemoveHookFunc func(connID ConnectionID) error func GenerateConnID() ConnectionID { return ConnectionID(uuid.NewString()) } - -func CustomRoutingDisabled() bool { - if netstack.IsEnabled() { - return true - } - return os.Getenv(envDisableCustomRouting) == "true" -} diff --git a/util/net/net_linux.go b/util/net/net_linux.go index 954545eb5..fc486ebd4 100644 --- a/util/net/net_linux.go +++ b/util/net/net_linux.go @@ -5,23 +5,41 @@ package net import ( "fmt" "syscall" + + log "github.com/sirupsen/logrus" ) // SetSocketMark sets the SO_MARK option on the given socket connection func SetSocketMark(conn syscall.Conn) error { + if isSocketMarkDisabled() { + return nil + } + sysconn, err := conn.SyscallConn() if err != nil { return fmt.Errorf("get raw conn: %w", err) } - return SetRawSocketMark(sysconn) + return setRawSocketMark(sysconn) } -func SetRawSocketMark(conn syscall.RawConn) error { +// SetSocketOpt sets the SO_MARK option on the given file descriptor +func SetSocketOpt(fd int) error { + if isSocketMarkDisabled() { + return nil + } + + return setSocketOptInt(fd) +} + +func setRawSocketMark(conn syscall.RawConn) error { var setErr error err := conn.Control(func(fd uintptr) { - setErr = SetSocketOpt(int(fd)) + if isSocketMarkDisabled() { + return + } + setErr = setSocketOptInt(int(fd)) }) if err != nil { return fmt.Errorf("control: %w", err) @@ -34,10 +52,18 @@ func SetRawSocketMark(conn syscall.RawConn) error { return nil } -func SetSocketOpt(fd int) error { - if CustomRoutingDisabled() { - return nil - } - +func setSocketOptInt(fd int) error { return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark) } + +func isSocketMarkDisabled() bool { + if CustomRoutingDisabled() { + log.Infof("Custom routing is disabled, skipping SO_MARK") + return true + } + + if SkipSocketMark() { + return true + } + return false +} diff --git a/util/net/protectsocket_android.go b/util/net/protectsocket_android.go index 64fb45aa4..febed8a1e 100644 --- a/util/net/protectsocket_android.go +++ b/util/net/protectsocket_android.go @@ -1,14 +1,42 @@ package net -import "sync" +import ( + "fmt" + "sync" + "syscall" +) var ( androidProtectSocketLock sync.Mutex androidProtectSocket func(fd int32) bool ) -func SetAndroidProtectSocketFn(f func(fd int32) bool) { +func SetAndroidProtectSocketFn(fn func(fd int32) bool) { androidProtectSocketLock.Lock() - androidProtectSocket = f + androidProtectSocket = fn androidProtectSocketLock.Unlock() } + +// ControlProtectSocket is a Control function that sets the fwmark on the socket +func ControlProtectSocket(_, _ string, c syscall.RawConn) error { + var aErr error + err := c.Control(func(fd uintptr) { + androidProtectSocketLock.Lock() + defer androidProtectSocketLock.Unlock() + + if androidProtectSocket == nil { + aErr = fmt.Errorf("socket protection function not set") + return + } + + if !androidProtectSocket(int32(fd)) { + aErr = fmt.Errorf("failed to protect socket via Android") + } + }) + + if err != nil { + return err + } + + return aErr +} diff --git a/util/semaphore-group/semaphore_group.go b/util/semaphore-group/semaphore_group.go new file mode 100644 index 000000000..ad74e1bfc --- /dev/null +++ b/util/semaphore-group/semaphore_group.go @@ -0,0 +1,48 @@ +package semaphoregroup + +import ( + "context" + "sync" +) + +// SemaphoreGroup is a custom type that combines sync.WaitGroup and a semaphore. 
+type SemaphoreGroup struct { + waitGroup sync.WaitGroup + semaphore chan struct{} +} + +// NewSemaphoreGroup creates a new SemaphoreGroup with the specified semaphore limit. +func NewSemaphoreGroup(limit int) *SemaphoreGroup { + return &SemaphoreGroup{ + semaphore: make(chan struct{}, limit), + } +} + +// Add increments the internal WaitGroup counter and acquires a semaphore slot. +func (sg *SemaphoreGroup) Add(ctx context.Context) { + sg.waitGroup.Add(1) + + // Acquire semaphore slot + select { + case <-ctx.Done(): + return + case sg.semaphore <- struct{}{}: + } +} + +// Done decrements the internal WaitGroup counter and releases a semaphore slot. +func (sg *SemaphoreGroup) Done(ctx context.Context) { + sg.waitGroup.Done() + + // Release semaphore slot + select { + case <-ctx.Done(): + return + case <-sg.semaphore: + } +} + +// Wait waits until the internal WaitGroup counter is zero. +func (sg *SemaphoreGroup) Wait() { + sg.waitGroup.Wait() +} diff --git a/util/semaphore-group/semaphore_group_test.go b/util/semaphore-group/semaphore_group_test.go new file mode 100644 index 000000000..d4491cf77 --- /dev/null +++ b/util/semaphore-group/semaphore_group_test.go @@ -0,0 +1,66 @@ +package semaphoregroup + +import ( + "context" + "testing" + "time" +) + +func TestSemaphoreGroup(t *testing.T) { + semGroup := NewSemaphoreGroup(2) + + for i := 0; i < 5; i++ { + semGroup.Add(context.Background()) + go func(id int) { + defer semGroup.Done(context.Background()) + + got := len(semGroup.semaphore) + if got == 0 { + t.Errorf("Expected semaphore length > 0 , got 0") + } + + time.Sleep(time.Millisecond) + t.Logf("Goroutine %d is running\n", id) + }(i) + } + + semGroup.Wait() + + want := 0 + got := len(semGroup.semaphore) + if got != want { + t.Errorf("Expected semaphore length %d, got %d", want, got) + } +} + +func TestSemaphoreGroupContext(t *testing.T) { + semGroup := NewSemaphoreGroup(1) + semGroup.Add(context.Background()) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + t.Cleanup(cancel) + rChan := make(chan struct{}) + + go func() { + semGroup.Add(ctx) + rChan <- struct{}{} + }() + select { + case <-rChan: + case <-time.NewTimer(2 * time.Second).C: + t.Error("Adding to semaphore group should not block when context is not done") + } + + semGroup.Done(context.Background()) + + ctxDone, cancelDone := context.WithTimeout(context.Background(), 1*time.Second) + t.Cleanup(cancelDone) + go func() { + semGroup.Done(ctxDone) + rChan <- struct{}{} + }() + select { + case <-rChan: + case <-time.NewTimer(2 * time.Second).C: + t.Error("Releasing from semaphore group should not block when context is not done") + } +}
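For context, the new semaphoregroup package added above behaves like a WaitGroup with a concurrency bound. A usage sketch, assuming the import path github.com/netbirdio/netbird/util/semaphore-group implied by the module path used elsewhere in this diff:

```go
package main

import (
	"context"
	"fmt"
	"time"

	semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
)

func main() {
	ctx := context.Background()
	sg := semaphoregroup.NewSemaphoreGroup(2) // at most 2 workers in flight

	for i := 0; i < 5; i++ {
		sg.Add(ctx) // blocks until a slot is free (or ctx is done)
		go func(id int) {
			defer sg.Done(ctx) // releases the slot and the wait counter
			time.Sleep(10 * time.Millisecond)
			fmt.Println("worker", id, "finished")
		}(i)
	}

	sg.Wait() // returns once every worker has called Done
}
```

Passing a context to Add and Done means callers stop blocking on slot acquisition or release once the context is cancelled, for example during shutdown.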