Mirror of https://github.com/netbirdio/netbird.git, synced 2025-01-31 18:39:31 +01:00

Commit c52e4f9102
Merge branch 'users-get-account-refactoring' into routes-get-account-refactoring

# Conflicts:
#	management/server/group.go
#	management/server/route.go
#	route/route.go

@@ -1,4 +1,4 @@
-FROM golang:1.21-bullseye
+FROM golang:1.23-bullseye

RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
    && apt-get -y install --no-install-recommends\

@@ -7,7 +7,7 @@
  "features": {
    "ghcr.io/devcontainers/features/docker-in-docker:2": {},
    "ghcr.io/devcontainers/features/go:1": {
-     "version": "1.21"
+     "version": "1.23"
    }
  },
  "workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}",

.github/workflows/golang-test-darwin.yml (vendored, 6 changes)

@@ -21,6 +21,7 @@ jobs:
        uses: actions/setup-go@v5
        with:
          go-version: "1.23.x"
+          cache: false
      - name: Checkout code
        uses: actions/checkout@v4

@@ -28,8 +29,9 @@ jobs:
        uses: actions/cache@v4
        with:
          path: ~/go/pkg/mod
-          key: macos-go-${{ hashFiles('**/go.sum') }}
+          key: macos-gotest-${{ hashFiles('**/go.sum') }}
          restore-keys: |
+            macos-gotest-
            macos-go-

      - name: Install libpcap
@@ -42,4 +44,4 @@ jobs:
        run: git --no-pager diff --exit-code

      - name: Test
-        run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
+        run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)

.github/workflows/golang-test-linux.yml (vendored, 360 changes)

@@ -11,31 +11,115 @@ concurrency:
  cancel-in-progress: true

jobs:
+  build-cache:
+    runs-on: ubuntu-22.04
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Install Go
+        uses: actions/setup-go@v5
+        with:
+          go-version: "1.23.x"
+          cache: false
+
+      - name: Get Go environment
+        run: |
+          echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
+          echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
+
+      - name: Cache Go modules
+        uses: actions/cache@v4
+        id: cache
+        with:
+          path: |
+            ${{ env.cache }}
+            ${{ env.modcache }}
+          key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
+          restore-keys: |
+            ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
+
+      - name: Install dependencies
+        if: steps.cache.outputs.cache-hit != 'true'
+        run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
+
+      - name: Install 32-bit libpcap
+        if: steps.cache.outputs.cache-hit != 'true'
+        run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
+
+      - name: Build client
+        if: steps.cache.outputs.cache-hit != 'true'
+        working-directory: client
+        run: CGO_ENABLED=1 go build .
+
+      - name: Build client 386
+        if: steps.cache.outputs.cache-hit != 'true'
+        working-directory: client
+        run: CGO_ENABLED=1 GOARCH=386 go build -o client-386 .
+
+      - name: Build management
+        if: steps.cache.outputs.cache-hit != 'true'
+        working-directory: management
+        run: CGO_ENABLED=1 go build .
+
+      - name: Build management 386
+        if: steps.cache.outputs.cache-hit != 'true'
+        working-directory: management
+        run: CGO_ENABLED=1 GOARCH=386 go build -o management-386 .
+
+      - name: Build signal
+        if: steps.cache.outputs.cache-hit != 'true'
+        working-directory: signal
+        run: CGO_ENABLED=1 go build .
+
+      - name: Build signal 386
+        if: steps.cache.outputs.cache-hit != 'true'
+        working-directory: signal
+        run: CGO_ENABLED=1 GOARCH=386 go build -o signal-386 .
+
+      - name: Build relay
+        if: steps.cache.outputs.cache-hit != 'true'
+        working-directory: relay
+        run: CGO_ENABLED=1 go build .
+
+      - name: Build relay 386
+        if: steps.cache.outputs.cache-hit != 'true'
+        working-directory: relay
+        run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
+
  test:
+    needs: [build-cache]
    strategy:
      fail-fast: false
      matrix:
        arch: [ '386','amd64' ]
-        store: [ 'sqlite', 'postgres']
    runs-on: ubuntu-22.04
    steps:
      - name: Install Go
        uses: actions/setup-go@v5
        with:
          go-version: "1.23.x"
+          cache: false

-      - name: Cache Go modules
-        uses: actions/cache@v4
-        with:
-          path: ~/go/pkg/mod
-          key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
-          restore-keys: |
-            ${{ runner.os }}-go-
-
      - name: Checkout code
        uses: actions/checkout@v4

+      - name: Get Go environment
+        run: |
+          echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
+          echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
+
+      - name: Cache Go modules
+        uses: actions/cache/restore@v4
+        with:
+          path: |
+            ${{ env.cache }}
+            ${{ env.modcache }}
+          key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
+          restore-keys: |
+            ${{ runner.os }}-gotest-cache-
+
      - name: Install dependencies
        run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev

@@ -50,27 +134,265 @@ jobs:
        run: git --no-pager diff --exit-code

      - name: Test
-        run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./...
+        run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management)

+  test_management:
+    needs: [ build-cache ]
+    strategy:
+      fail-fast: false
+      matrix:
+        arch: [ '386','amd64' ]
+        store: [ 'sqlite', 'postgres', 'mysql' ]
+    runs-on: ubuntu-22.04
+    steps:
+      - name: Install Go
+        uses: actions/setup-go@v5
+        with:
+          go-version: "1.23.x"
+          cache: false
+
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Get Go environment
+        run: |
+          echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
+          echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
+
+      - name: Cache Go modules
+        uses: actions/cache/restore@v4
+        with:
+          path: |
+            ${{ env.cache }}
+            ${{ env.modcache }}
+          key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
+          restore-keys: |
+            ${{ runner.os }}-gotest-cache-
+
+      - name: Install dependencies
+        run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
+
+      - name: Install 32-bit libpcap
+        if: matrix.arch == '386'
+        run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
+
+      - name: Install modules
+        run: go mod tidy
+
+      - name: check git status
+        run: git --no-pager diff --exit-code
+
+      - name: Login to Docker hub
+        if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
+        uses: docker/login-action@v1
+        with:
+          username: ${{ secrets.DOCKER_USER }}
+          password: ${{ secrets.DOCKER_TOKEN }}
+
+      - name: download mysql image
+        if: matrix.store == 'mysql'
+        run: docker pull mlsmaycon/warmed-mysql:8
+
+      - name: Test
+        run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management)
+
+  benchmark:
+    needs: [ build-cache ]
+    strategy:
+      fail-fast: false
+      matrix:
+        arch: [ '386','amd64' ]
+        store: [ 'sqlite', 'postgres', 'mysql' ]
+    runs-on: ubuntu-22.04
+    steps:
+      - name: Install Go
+        uses: actions/setup-go@v5
+        with:
+          go-version: "1.23.x"
+          cache: false
+
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Get Go environment
+        run: |
+          echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
+          echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
+
+      - name: Cache Go modules
+        uses: actions/cache/restore@v4
+        with:
+          path: |
+            ${{ env.cache }}
+            ${{ env.modcache }}
+          key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
+          restore-keys: |
+            ${{ runner.os }}-gotest-cache-
+
+      - name: Install dependencies
+        run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
+
+      - name: Install 32-bit libpcap
+        if: matrix.arch == '386'
+        run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
+
+      - name: Install modules
+        run: go mod tidy
+
+      - name: check git status
+        run: git --no-pager diff --exit-code
+
+      - name: Login to Docker hub
+        if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
+        uses: docker/login-action@v1
+        with:
+          username: ${{ secrets.DOCKER_USER }}
+          password: ${{ secrets.DOCKER_TOKEN }}
+
+      - name: download mysql image
+        if: matrix.store == 'mysql'
+        run: docker pull mlsmaycon/warmed-mysql:8
+
+      - name: Test
+        run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m ./...
+
+  api_benchmark:
+    needs: [ build-cache ]
+    strategy:
+      fail-fast: false
+      matrix:
+        arch: [ '386','amd64' ]
+        store: [ 'sqlite', 'postgres' ]
+    runs-on: ubuntu-22.04
+    steps:
+      - name: Install Go
+        uses: actions/setup-go@v5
+        with:
+          go-version: "1.23.x"
+          cache: false
+
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Get Go environment
+        run: |
+          echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
+          echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
+
+      - name: Cache Go modules
+        uses: actions/cache/restore@v4
+        with:
+          path: |
+            ${{ env.cache }}
+            ${{ env.modcache }}
+          key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
+          restore-keys: |
+            ${{ runner.os }}-gotest-cache-
+
+      - name: Install dependencies
+        run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
+
+      - name: Install 32-bit libpcap
+        if: matrix.arch == '386'
+        run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
+
+      - name: Install modules
+        run: go mod tidy
+
+      - name: check git status
+        run: git --no-pager diff --exit-code
+
+      - name: Login to Docker hub
+        if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
+        uses: docker/login-action@v1
+        with:
+          username: ${{ secrets.DOCKER_USER }}
+          password: ${{ secrets.DOCKER_TOKEN }}
+
+      - name: download mysql image
+        if: matrix.store == 'mysql'
+        run: docker pull mlsmaycon/warmed-mysql:8
+
+      - name: Test
+        run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -tags=benchmark -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 30m $(go list -tags=benchmark ./... | grep /management)
+
+  api_integration_test:
+    needs: [ build-cache ]
+    strategy:
+      fail-fast: false
+      matrix:
+        arch: [ '386','amd64' ]
+        store: [ 'sqlite', 'postgres']
+    runs-on: ubuntu-22.04
+    steps:
+      - name: Install Go
+        uses: actions/setup-go@v5
+        with:
+          go-version: "1.23.x"
+          cache: false
+
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Get Go environment
+        run: |
+          echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
+          echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
+
+      - name: Cache Go modules
+        uses: actions/cache/restore@v4
+        with:
+          path: |
+            ${{ env.cache }}
+            ${{ env.modcache }}
+          key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
+          restore-keys: |
+            ${{ runner.os }}-gotest-cache-
+
+      - name: Install dependencies
+        run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
+
+      - name: Install 32-bit libpcap
+        if: matrix.arch == '386'
+        run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
+
+      - name: Install modules
+        run: go mod tidy
+
+      - name: check git status
+        run: git --no-pager diff --exit-code
+
+      - name: Test
+        run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=integration -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 30m $(go list -tags=integration ./... | grep /management)
+
  test_client_on_docker:
+    needs: [ build-cache ]
    runs-on: ubuntu-20.04
    steps:
      - name: Install Go
        uses: actions/setup-go@v5
        with:
          go-version: "1.23.x"
+          cache: false
-      - name: Cache Go modules
-        uses: actions/cache@v4
-        with:
-          path: ~/go/pkg/mod
-          key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
-          restore-keys: |
-            ${{ runner.os }}-go-

      - name: Checkout code
        uses: actions/checkout@v4

+      - name: Get Go environment
+        run: |
+          echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
+          echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
+
+      - name: Cache Go modules
+        uses: actions/cache/restore@v4
+        with:
+          path: |
+            ${{ env.cache }}
+            ${{ env.modcache }}
+          key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
+          restore-keys: |
+            ${{ runner.os }}-gotest-cache-
+
      - name: Install dependencies
        run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev

.github/workflows/golang-test-windows.yml (vendored, 25 changes)

@@ -24,6 +24,23 @@ jobs:
        id: go
        with:
          go-version: "1.23.x"
+          cache: false
+
+      - name: Get Go environment
+        run: |
+          echo "cache=$(go env GOCACHE)" >> $env:GITHUB_ENV
+          echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
+
+      - name: Cache Go modules
+        uses: actions/cache@v4
+        with:
+          path: |
+            ${{ env.cache }}
+            ${{ env.modcache }}
+          key: ${{ runner.os }}-gotest-${{ hashFiles('**/go.sum') }}
+          restore-keys: |
+            ${{ runner.os }}-gotest-
+            ${{ runner.os }}-go-
+
      - name: Download wintun
        uses: carlosperate/download-file-action@v2

@@ -42,11 +59,13 @@ jobs:
      - run: choco install -y sysinternals --ignore-checksums
      - run: choco install -y mingw

-      - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=C:\Users\runneradmin\go\pkg\mod
+      - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
-      - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build
+      - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
+      - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
+      - run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV

      - name: test
-        run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ./... > test-out.txt 2>&1"
+        run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
      - name: test output
        if: ${{ always() }}
        run: Get-Content test-out.txt

.github/workflows/golangci-lint.yml (vendored, 6 changes)

@@ -19,7 +19,7 @@ jobs:
      - name: codespell
        uses: codespell-project/actions-codespell@v2
        with:
-          ignore_words_list: erro,clienta,hastable,iif,groupd
+          ignore_words_list: erro,clienta,hastable,iif,groupd,testin
          skip: go.mod,go.sum
          only_warn: 1
  golangci:

@@ -46,7 +46,7 @@ jobs:
        if: matrix.os == 'ubuntu-latest'
        run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
      - name: golangci-lint
-        uses: golangci/golangci-lint-action@v3
+        uses: golangci/golangci-lint-action@v4
        with:
          version: latest
-          args: --timeout=12m
+          args: --timeout=12m --out-format colored-line-number

.github/workflows/release.yml (vendored, 2 changes)

@@ -9,7 +9,7 @@ on:
  pull_request:

env:
-  SIGN_PIPE_VER: "v0.0.16"
+  SIGN_PIPE_VER: "v0.0.17"
  GORELEASER_VER: "v2.3.2"
  PRODUCT_NAME: "NetBird"
  COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"

.github/workflows/test-infrastructure-files.yml (vendored, 23 changes)

@@ -20,7 +20,7 @@ jobs:
    runs-on: ubuntu-latest
    strategy:
      matrix:
-        store: [ 'sqlite', 'postgres' ]
+        store: [ 'sqlite', 'postgres', 'mysql' ]
    services:
      postgres:
        image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }}

@@ -34,6 +34,19 @@ jobs:
          --health-timeout 5s
        ports:
          - 5432:5432
+      mysql:
+        image: ${{ (matrix.store == 'mysql') && 'mysql' || '' }}
+        env:
+          MYSQL_USER: netbird
+          MYSQL_PASSWORD: mysql
+          MYSQL_ROOT_PASSWORD: mysqlroot
+          MYSQL_DATABASE: netbird
+        options: >-
+          --health-cmd "mysqladmin ping --silent"
+          --health-interval 10s
+          --health-timeout 5s
+        ports:
+          - 3306:3306
    steps:
      - name: Set Database Connection String
        run: |

@@ -42,6 +55,11 @@ jobs:
          else
            echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN==" >> $GITHUB_ENV
          fi
+          if [ "${{ matrix.store }}" == "mysql" ]; then
+            echo "NETBIRD_STORE_ENGINE_MYSQL_DSN=netbird:mysql@tcp($(hostname -I | awk '{print $1}'):3306)/netbird" >> $GITHUB_ENV
+          else
+            echo "NETBIRD_STORE_ENGINE_MYSQL_DSN==" >> $GITHUB_ENV
+          fi

      - name: Install jq
        run: sudo apt-get install -y jq

@@ -84,6 +102,7 @@ jobs:
          CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
          CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
          NETBIRD_STORE_ENGINE_POSTGRES_DSN: ${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}
+          NETBIRD_STORE_ENGINE_MYSQL_DSN: ${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}
          CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false

      - name: check values

@@ -112,6 +131,7 @@ jobs:
          CI_NETBIRD_SIGNAL_PORT: 12345
          CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
          NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$'
+          NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
          CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
          CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"

@@ -149,6 +169,7 @@ jobs:
          grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
          grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000"
          grep "external-ip" turnserver.conf | grep $CI_NETBIRD_TURN_EXTERNAL_IP
+          grep "NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN" docker-compose.yml
          grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
          # check relay values
          grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml

@@ -179,6 +179,51 @@ dockers:
      - "--label=org.opencontainers.image.revision={{.FullCommit}}"
      - "--label=org.opencontainers.image.version={{.Version}}"
      - "--label=maintainer=dev@netbird.io"

+  - image_templates:
+      - netbirdio/netbird:{{ .Version }}-rootless-amd64
+    ids:
+      - netbird
+    goarch: amd64
+    use: buildx
+    dockerfile: client/Dockerfile-rootless
+    build_flag_templates:
+      - "--platform=linux/amd64"
+      - "--label=org.opencontainers.image.created={{.Date}}"
+      - "--label=org.opencontainers.image.title={{.ProjectName}}"
+      - "--label=org.opencontainers.image.version={{.Version}}"
+      - "--label=org.opencontainers.image.revision={{.FullCommit}}"
+      - "--label=maintainer=dev@netbird.io"
+  - image_templates:
+      - netbirdio/netbird:{{ .Version }}-rootless-arm64v8
+    ids:
+      - netbird
+    goarch: arm64
+    use: buildx
+    dockerfile: client/Dockerfile-rootless
+    build_flag_templates:
+      - "--platform=linux/arm64"
+      - "--label=org.opencontainers.image.created={{.Date}}"
+      - "--label=org.opencontainers.image.title={{.ProjectName}}"
+      - "--label=org.opencontainers.image.version={{.Version}}"
+      - "--label=org.opencontainers.image.revision={{.FullCommit}}"
+      - "--label=maintainer=dev@netbird.io"
+  - image_templates:
+      - netbirdio/netbird:{{ .Version }}-rootless-arm
+    ids:
+      - netbird
+    goarch: arm
+    goarm: 6
+    use: buildx
+    dockerfile: client/Dockerfile-rootless
+    build_flag_templates:
+      - "--platform=linux/arm"
+      - "--label=org.opencontainers.image.created={{.Date}}"
+      - "--label=org.opencontainers.image.title={{.ProjectName}}"
+      - "--label=org.opencontainers.image.version={{.Version}}"
+      - "--label=org.opencontainers.image.revision={{.FullCommit}}"
+      - "--label=maintainer=dev@netbird.io"
+
  - image_templates:
      - netbirdio/relay:{{ .Version }}-amd64
    ids:

@@ -377,6 +422,18 @@ docker_manifests:
      - netbirdio/netbird:{{ .Version }}-arm
      - netbirdio/netbird:{{ .Version }}-amd64

+  - name_template: netbirdio/netbird:{{ .Version }}-rootless
+    image_templates:
+      - netbirdio/netbird:{{ .Version }}-rootless-arm64v8
+      - netbirdio/netbird:{{ .Version }}-rootless-arm
+      - netbirdio/netbird:{{ .Version }}-rootless-amd64
+
+  - name_template: netbirdio/netbird:rootless-latest
+    image_templates:
+      - netbirdio/netbird:{{ .Version }}-rootless-arm64v8
+      - netbirdio/netbird:{{ .Version }}-rootless-arm
+      - netbirdio/netbird:{{ .Version }}-rootless-amd64
+
  - name_template: netbirdio/relay:{{ .Version }}
    image_templates:
      - netbirdio/relay:{{ .Version }}-arm64v8

README.md (15 changes)

@@ -1,10 +1,3 @@
-<p align="center">
-  <strong>:hatching_chick: New Release! Device Posture Checks.</strong>
-  <a href="https://docs.netbird.io/how-to/manage-posture-checks">
-    Learn more
-  </a>
-</p>
-<br/>
<div align="center">
<p align="center">
  <img width="234" src="docs/media/logo-full.png"/>

@@ -17,8 +10,12 @@
    <img src="https://img.shields.io/badge/license-BSD--3-blue" />
  </a>
  <br>
-  <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">
+  <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">
    <img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
+  </a>
+  <br>
+  <a href="https://gurubase.io/g/netbird">
+    <img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF"/>
  </a>
</p>
</div>

@@ -30,7 +27,7 @@
    <br/>
    See <a href="https://netbird.io/docs/">Documentation</a>
    <br/>
-    Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">Slack channel</a>
+    Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">Slack channel</a>
    <br/>

</strong>

@@ -1,4 +1,4 @@
-FROM alpine:3.20
+FROM alpine:3.21.0
RUN apk add --no-cache ca-certificates iptables ip6tables
ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"]

client/Dockerfile-rootless (new file, 16 lines)

@@ -0,0 +1,16 @@
+FROM alpine:3.21.0
+
+COPY netbird /usr/local/bin/netbird
+
+RUN apk add --no-cache ca-certificates \
+    && adduser -D -h /var/lib/netbird netbird
+WORKDIR /var/lib/netbird
+USER netbird:netbird
+
+ENV NB_FOREGROUND_MODE=true
+ENV NB_USE_NETSTACK_MODE=true
+ENV NB_CONFIG=config.json
+ENV NB_DAEMON_ADDR=unix://netbird.sock
+ENV NB_DISABLE_DNS=true
+
+ENTRYPOINT [ "/usr/local/bin/netbird", "up" ]

@@ -12,6 +12,8 @@ import (
    "strings"
)

+const anonTLD = ".domain"
+
type Anonymizer struct {
    ipAnonymizer     map[netip.Addr]netip.Addr
    domainAnonymizer map[string]string

@@ -19,6 +21,8 @@ type Anonymizer struct {
    currentAnonIPv6 netip.Addr
    startAnonIPv4   netip.Addr
    startAnonIPv6   netip.Addr
+
+    domainKeyRegex *regexp.Regexp
}

func DefaultAddresses() (netip.Addr, netip.Addr) {

@@ -34,6 +38,8 @@ func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer {
        currentAnonIPv6: startIPv6,
        startAnonIPv4:   startIPv4,
        startAnonIPv6:   startIPv6,
+
+        domainKeyRegex: regexp.MustCompile(`\bdomain=([^\s,:"]+)`),
    }
}

@@ -83,29 +89,39 @@ func (a *Anonymizer) AnonymizeIPString(ip string) string {
}

func (a *Anonymizer) AnonymizeDomain(domain string) string {
-    if strings.HasSuffix(domain, "netbird.io") ||
-        strings.HasSuffix(domain, "netbird.selfhosted") ||
-        strings.HasSuffix(domain, "netbird.cloud") ||
-        strings.HasSuffix(domain, "netbird.stage") ||
-        strings.HasSuffix(domain, ".domain") {
+    baseDomain := domain
+    hasDot := strings.HasSuffix(domain, ".")
+    if hasDot {
+        baseDomain = domain[:len(domain)-1]
+    }
+
+    if strings.HasSuffix(baseDomain, "netbird.io") ||
+        strings.HasSuffix(baseDomain, "netbird.selfhosted") ||
+        strings.HasSuffix(baseDomain, "netbird.cloud") ||
+        strings.HasSuffix(baseDomain, "netbird.stage") ||
+        strings.HasSuffix(baseDomain, anonTLD) {
        return domain
    }

-    parts := strings.Split(domain, ".")
+    parts := strings.Split(baseDomain, ".")
    if len(parts) < 2 {
        return domain
    }

-    baseDomain := parts[len(parts)-2] + "." + parts[len(parts)-1]
+    baseForLookup := parts[len(parts)-2] + "." + parts[len(parts)-1]

-    anonymized, ok := a.domainAnonymizer[baseDomain]
+    anonymized, ok := a.domainAnonymizer[baseForLookup]
    if !ok {
-        anonymizedBase := "anon-" + generateRandomString(5) + ".domain"
-        a.domainAnonymizer[baseDomain] = anonymizedBase
+        anonymizedBase := "anon-" + generateRandomString(5) + anonTLD
+        a.domainAnonymizer[baseForLookup] = anonymizedBase
        anonymized = anonymizedBase
    }

-    return strings.Replace(domain, baseDomain, anonymized, 1)
+    result := strings.Replace(baseDomain, baseForLookup, anonymized, 1)
+    if hasDot {
+        result += "."
+    }
+    return result
}

func (a *Anonymizer) AnonymizeURI(uri string) string {

@@ -152,27 +168,22 @@ func (a *Anonymizer) AnonymizeString(str string) string {
    return str
}

-// AnonymizeSchemeURI finds and anonymizes URIs with stun, stuns, turn, and turns schemes.
+// AnonymizeSchemeURI finds and anonymizes URIs with ws, wss, rel, rels, stun, stuns, turn, and turns schemes.
func (a *Anonymizer) AnonymizeSchemeURI(text string) string {
-    re := regexp.MustCompile(`(?i)\b(stuns?:|turns?:|https?://)\S+\b`)
+    re := regexp.MustCompile(`(?i)\b(wss?://|rels?://|stuns?:|turns?:|https?://)\S+\b`)

    return re.ReplaceAllStringFunc(text, a.AnonymizeURI)
}

-// AnonymizeDNSLogLine anonymizes domain names in DNS log entries by replacing them with a random string.
func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string {
-    domainPattern := `dns\.Question{Name:"([^"]+)",`
-    domainRegex := regexp.MustCompile(domainPattern)
-
-    return domainRegex.ReplaceAllStringFunc(logEntry, func(match string) string {
-        parts := strings.Split(match, `"`)
+    return a.domainKeyRegex.ReplaceAllStringFunc(logEntry, func(match string) string {
+        parts := strings.SplitN(match, "=", 2)
        if len(parts) >= 2 {
            domain := parts[1]
-            if strings.HasSuffix(domain, ".domain") {
+            if strings.HasSuffix(domain, anonTLD) {
                return match
            }
-            randomDomain := generateRandomString(10) + ".domain"
-            return strings.Replace(match, domain, randomDomain, 1)
+            return "domain=" + a.AnonymizeDomain(domain)
        }
        return match
    })

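To make the new AnonymizeDNSLogLine flow above easier to follow, here is a minimal, self-contained sketch of the same regex-driven rewrite of domain=<value> tokens. Only the regular expression is taken from the diff; the anonymizeValue helper is a hypothetical stand-in for the Anonymizer's AnonymizeDomain method.

package main

import (
	"fmt"
	"regexp"
	"strings"
)

// domainKeyRegex mirrors the pattern added in the diff: it captures the value that
// follows a "domain=" key up to whitespace, a comma, a colon, or a double quote.
var domainKeyRegex = regexp.MustCompile(`\bdomain=([^\s,:"]+)`)

// anonymizeValue is a placeholder for Anonymizer.AnonymizeDomain; it only keeps the
// trailing dot so the rewriting mechanics stay visible.
func anonymizeValue(domain string) string {
	if strings.HasSuffix(domain, ".") {
		return "anon-xxxxx.domain."
	}
	return "anon-xxxxx.domain"
}

func main() {
	line := "received DNS request for DNS forwarder: domain=example.com. status=pending"
	out := domainKeyRegex.ReplaceAllStringFunc(line, func(match string) string {
		parts := strings.SplitN(match, "=", 2) // "domain" and its captured value
		if len(parts) == 2 {
			return "domain=" + anonymizeValue(parts[1])
		}
		return match
	})
	fmt.Println(out) // received DNS request for DNS forwarder: domain=anon-xxxxx.domain. status=pending
}

Running the sketch shows the point of the change: only the value of each domain= key is rewritten, while the rest of the log message, including the key=value layout and any trailing dot, is preserved.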
@@ -46,11 +46,59 @@ func TestAnonymizeIP(t *testing.T) {

func TestAnonymizeDNSLogLine(t *testing.T) {
    anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
-    testLog := `2024-04-23T20:01:11+02:00 TRAC client/internal/dns/local.go:25: received question: dns.Question{Name:"example.com", Qtype:0x1c, Qclass:0x1}`
+    tests := []struct {
+        name     string
+        input    string
+        original string
+        expect   string
+    }{
+        {
+            name:     "Basic domain with trailing content",
+            input:    "received DNS request for DNS forwarder: domain=example.com: something happened with code=123",
+            original: "example.com",
+            expect:   `received DNS request for DNS forwarder: domain=anon-[a-zA-Z0-9]+\.domain: something happened with code=123`,
+        },
+        {
+            name:     "Domain with trailing dot",
+            input:    "domain=example.com. processing request with status=pending",
+            original: "example.com",
+            expect:   `domain=anon-[a-zA-Z0-9]+\.domain\. processing request with status=pending`,
+        },
+        {
+            name:     "Multiple domains in log",
+            input:    "forward domain=first.com status=ok, redirect to domain=second.com port=443",
+            original: "first.com", // testing just one is sufficient as AnonymizeDomain is tested separately
+            expect:   `forward domain=anon-[a-zA-Z0-9]+\.domain status=ok, redirect to domain=anon-[a-zA-Z0-9]+\.domain port=443`,
+        },
+        {
+            name:     "Already anonymized domain",
+            input:    "got request domain=anon-xyz123.domain from=client1 to=server2",
+            original: "", // nothing should be anonymized
+            expect:   `got request domain=anon-xyz123\.domain from=client1 to=server2`,
+        },
+        {
+            name:     "Subdomain with trailing dot",
+            input:    "domain=sub.example.com. next_hop=10.0.0.1 proto=udp",
+            original: "example.com",
+            expect:   `domain=sub\.anon-[a-zA-Z0-9]+\.domain\. next_hop=10\.0\.0\.1 proto=udp`,
+        },
+        {
+            name:     "Handler chain pattern log",
+            input:    "pattern: domain=example.com. original: domain=*.example.com. wildcard=true priority=100",
+            original: "example.com",
+            expect:   `pattern: domain=anon-[a-zA-Z0-9]+\.domain\. original: domain=\*\.anon-[a-zA-Z0-9]+\.domain\. wildcard=true priority=100`,
+        },
+    }

-    result := anonymizer.AnonymizeDNSLogLine(testLog)
-    require.NotEqual(t, testLog, result)
-    assert.NotContains(t, result, "example.com")
+    for _, tc := range tests {
+        t.Run(tc.name, func(t *testing.T) {
+            result := anonymizer.AnonymizeDNSLogLine(tc.input)
+            if tc.original != "" {
+                assert.NotContains(t, result, tc.original)
+            }
+            assert.Regexp(t, tc.expect, result)
+        })
+    }
}

func TestAnonymizeDomain(t *testing.T) {

@@ -67,18 +115,36 @@ func TestAnonymizeDomain(t *testing.T) {
            `^anon-[a-zA-Z0-9]+\.domain$`,
            true,
        },
+        {
+            "Domain with Trailing Dot",
+            "example.com.",
+            `^anon-[a-zA-Z0-9]+\.domain.$`,
+            true,
+        },
        {
            "Subdomain",
            "sub.example.com",
            `^sub\.anon-[a-zA-Z0-9]+\.domain$`,
            true,
        },
+        {
+            "Subdomain with Trailing Dot",
+            "sub.example.com.",
+            `^sub\.anon-[a-zA-Z0-9]+\.domain.$`,
+            true,
+        },
        {
            "Protected Domain",
            "netbird.io",
            `^netbird\.io$`,
            false,
        },
+        {
+            "Protected Domain with Trailing Dot",
+            "netbird.io.",
+            `^netbird\.io.$`,
+            false,
+        },
    }

    for _, tc := range tests {

@@ -140,8 +206,16 @@ func TestAnonymizeSchemeURI(t *testing.T) {
        expect string
    }{
        {"STUN URI in text", "Connection made via stun:example.com", `Connection made via stun:anon-[a-zA-Z0-9]+\.domain`},
+        {"STUNS URI in message", "Secure connection to stuns:example.com:443", `Secure connection to stuns:anon-[a-zA-Z0-9]+\.domain:443`},
        {"TURN URI in log", "Failed attempt turn:some.example.com:3478?transport=tcp: retrying", `Failed attempt turn:some.anon-[a-zA-Z0-9]+\.domain:3478\?transport=tcp: retrying`},
+        {"TURNS URI in message", "Secure connection to turns:example.com:5349", `Secure connection to turns:anon-[a-zA-Z0-9]+\.domain:5349`},
+        {"HTTP URI in text", "Visit http://example.com for more", `Visit http://anon-[a-zA-Z0-9]+\.domain for more`},
+        {"HTTPS URI in CAPS", "Visit HTTPS://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`},
        {"HTTPS URI in message", "Visit https://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`},
+        {"WS URI in log", "Connection established to ws://example.com:8080", `Connection established to ws://anon-[a-zA-Z0-9]+\.domain:8080`},
+        {"WSS URI in message", "Secure connection to wss://example.com", `Secure connection to wss://anon-[a-zA-Z0-9]+\.domain`},
+        {"Rel URI in text", "Relaying to rel://example.com", `Relaying to rel://anon-[a-zA-Z0-9]+\.domain`},
+        {"Rels URI in message", "Relaying to rels://example.com", `Relaying to rels://anon-[a-zA-Z0-9]+\.domain`},
    }

    for _, tc := range tests {

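The expectations above match labels of the form anon-[a-zA-Z0-9]+.domain, produced by a generateRandomString helper that is referenced but not included in this diff. A minimal stand-in with that shape might look like the sketch below; the function name mirrors the call sites, but the body is an assumption, not the repository's implementation.

package anonymize

import (
	"crypto/rand"
	"math/big"
)

const randomChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"

// generateRandomString returns an alphanumeric string of length n, suitable for
// building labels such as "anon-" + generateRandomString(5) + ".domain".
func generateRandomString(n int) string {
	b := make([]byte, n)
	for i := range b {
		idx, err := rand.Int(rand.Reader, big.NewInt(int64(len(randomChars))))
		if err != nil {
			// fall back to a fixed character if the system RNG is unavailable
			b[i] = 'a'
			continue
		}
		b[i] = randomChars[idx.Int64()]
	}
	return string(b)
}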
@@ -3,6 +3,7 @@ package cmd
import (
    "context"
    "fmt"
+    "strings"
    "time"

    log "github.com/sirupsen/logrus"

@@ -61,6 +62,15 @@ var forCmd = &cobra.Command{
    RunE: runForDuration,
}

+var persistenceCmd = &cobra.Command{
+    Use:     "persistence [on|off]",
+    Short:   "Set network map memory persistence",
+    Long:    `Configure whether the latest network map should persist in memory. When enabled, the last known network map will be kept in memory.`,
+    Example: " netbird debug persistence on",
+    Args:    cobra.ExactArgs(1),
+    RunE:    setNetworkMapPersistence,
+}
+
func debugBundle(cmd *cobra.Command, _ []string) error {
    conn, err := getClient(cmd)
    if err != nil {

@@ -171,6 +181,13 @@ func runForDuration(cmd *cobra.Command, args []string) error {

    time.Sleep(1 * time.Second)

+    // Enable network map persistence before bringing the service up
+    if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
+        Enabled: true,
+    }); err != nil {
+        return fmt.Errorf("failed to enable network map persistence: %v", status.Convert(err).Message())
+    }
+
    if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
        return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
    }

@@ -200,6 +217,13 @@ func runForDuration(cmd *cobra.Command, args []string) error {
        return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
    }

+    // Disable network map persistence after creating the debug bundle
+    if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
+        Enabled: false,
+    }); err != nil {
+        return fmt.Errorf("failed to disable network map persistence: %v", status.Convert(err).Message())
+    }
+
    if stateWasDown {
        if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
            return fmt.Errorf("failed to down: %v", status.Convert(err).Message())

@@ -219,6 +243,34 @@ func runForDuration(cmd *cobra.Command, args []string) error {
    return nil
}

+func setNetworkMapPersistence(cmd *cobra.Command, args []string) error {
+    conn, err := getClient(cmd)
+    if err != nil {
+        return err
+    }
+    defer func() {
+        if err := conn.Close(); err != nil {
+            log.Errorf(errCloseConnection, err)
+        }
+    }()
+
+    persistence := strings.ToLower(args[0])
+    if persistence != "on" && persistence != "off" {
+        return fmt.Errorf("invalid persistence value: %s. Use 'on' or 'off'", args[0])
+    }
+
+    client := proto.NewDaemonServiceClient(conn)
+    _, err = client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
+        Enabled: persistence == "on",
+    })
+    if err != nil {
+        return fmt.Errorf("failed to set network map persistence: %v", status.Convert(err).Message())
+    }
+
+    cmd.Printf("Network map persistence set to: %s\n", persistence)
+    return nil
+}
+
func getStatusOutput(cmd *cobra.Command) string {
    var statusOutputString string
    statusResp, err := getStatus(cmd.Context())

client/cmd/networks.go (new file, 173 lines)

@@ -0,0 +1,173 @@
+package cmd
+
+import (
+    "fmt"
+    "strings"
+
+    "github.com/spf13/cobra"
+    "google.golang.org/grpc/status"
+
+    "github.com/netbirdio/netbird/client/proto"
+)
+
+var appendFlag bool
+
+var networksCMD = &cobra.Command{
+    Use:     "networks",
+    Aliases: []string{"routes"},
+    Short:   "Manage networks",
+    Long:    `Commands to list, select, or deselect networks. Replaces the "routes" command.`,
+}
+
+var routesListCmd = &cobra.Command{
+    Use:     "list",
+    Aliases: []string{"ls"},
+    Short:   "List networks",
+    Example: " netbird networks list",
+    Long:    "List all available network routes.",
+    RunE:    networksList,
+}
+
+var routesSelectCmd = &cobra.Command{
+    Use:     "select network...|all",
+    Short:   "Select network",
+    Long:    "Select a list of networks by identifiers or 'all' to clear all selections and to accept all (including new) networks.\nDefault mode is replace, use -a to append to already selected networks.",
+    Example: " netbird networks select all\n netbird networks select route1 route2\n netbird routes select -a route3",
+    Args:    cobra.MinimumNArgs(1),
+    RunE:    networksSelect,
+}
+
+var routesDeselectCmd = &cobra.Command{
+    Use:     "deselect network...|all",
+    Short:   "Deselect networks",
+    Long:    "Deselect previously selected networks by identifiers or 'all' to disable accepting any networks.",
+    Example: " netbird networks deselect all\n netbird networks deselect route1 route2",
+    Args:    cobra.MinimumNArgs(1),
+    RunE:    networksDeselect,
+}
+
+func init() {
+    routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current network selection instead of replacing")
+}
+
+func networksList(cmd *cobra.Command, _ []string) error {
+    conn, err := getClient(cmd)
+    if err != nil {
+        return err
+    }
+    defer conn.Close()
+
+    client := proto.NewDaemonServiceClient(conn)
+    resp, err := client.ListNetworks(cmd.Context(), &proto.ListNetworksRequest{})
+    if err != nil {
+        return fmt.Errorf("failed to list network: %v", status.Convert(err).Message())
+    }
+
+    if len(resp.Routes) == 0 {
+        cmd.Println("No networks available.")
+        return nil
+    }
+
+    printNetworks(cmd, resp)
+
+    return nil
+}
+
+func printNetworks(cmd *cobra.Command, resp *proto.ListNetworksResponse) {
+    cmd.Println("Available Networks:")
+    for _, route := range resp.Routes {
+        printNetwork(cmd, route)
+    }
+}
+
+func printNetwork(cmd *cobra.Command, route *proto.Network) {
+    selectedStatus := getSelectedStatus(route)
+    domains := route.GetDomains()
+
+    if len(domains) > 0 {
+        printDomainRoute(cmd, route, domains, selectedStatus)
+    } else {
+        printNetworkRoute(cmd, route, selectedStatus)
+    }
+}
+
+func getSelectedStatus(route *proto.Network) string {
+    if route.GetSelected() {
+        return "Selected"
+    }
+    return "Not Selected"
+}
+
+func printDomainRoute(cmd *cobra.Command, route *proto.Network, domains []string, selectedStatus string) {
+    cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus)
+    resolvedIPs := route.GetResolvedIPs()
+
+    if len(resolvedIPs) > 0 {
+        printResolvedIPs(cmd, domains, resolvedIPs)
+    } else {
+        cmd.Printf(" Resolved IPs: -\n")
+    }
+}
+
+func printNetworkRoute(cmd *cobra.Command, route *proto.Network, selectedStatus string) {
+    cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetRange(), selectedStatus)
+}
+
+func printResolvedIPs(cmd *cobra.Command, _ []string, resolvedIPs map[string]*proto.IPList) {
+    cmd.Printf(" Resolved IPs:\n")
+    for resolvedDomain, ipList := range resolvedIPs {
+        cmd.Printf(" [%s]: %s\n", resolvedDomain, strings.Join(ipList.GetIps(), ", "))
+    }
+}
+
+func networksSelect(cmd *cobra.Command, args []string) error {
+    conn, err := getClient(cmd)
+    if err != nil {
+        return err
+    }
+    defer conn.Close()
+
+    client := proto.NewDaemonServiceClient(conn)
+    req := &proto.SelectNetworksRequest{
+        NetworkIDs: args,
+    }
+
+    if len(args) == 1 && args[0] == "all" {
+        req.All = true
+    } else if appendFlag {
+        req.Append = true
+    }
+
+    if _, err := client.SelectNetworks(cmd.Context(), req); err != nil {
+        return fmt.Errorf("failed to select networks: %v", status.Convert(err).Message())
+    }
+
+    cmd.Println("Networks selected successfully.")
+
+    return nil
+}
+
+func networksDeselect(cmd *cobra.Command, args []string) error {
+    conn, err := getClient(cmd)
+    if err != nil {
+        return err
+    }
+    defer conn.Close()
+
+    client := proto.NewDaemonServiceClient(conn)
+    req := &proto.SelectNetworksRequest{
+        NetworkIDs: args,
+    }
+
+    if len(args) == 1 && args[0] == "all" {
+        req.All = true
+    }
+
+    if _, err := client.DeselectNetworks(cmd.Context(), req); err != nil {
+        return fmt.Errorf("failed to deselect networks: %v", status.Convert(err).Message())
+    }
+
+    cmd.Println("Networks deselected successfully.")
+
+    return nil
+}

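Judging from the Printf format strings in networks.go above, a "netbird networks list" run would print something roughly like the sample below. The IDs, domains, and addresses are placeholders, and the exact spacing depends on the format strings in the actual source.

Available Networks:

 - ID: route1
 Domains: internal.example.org
 Status: Selected
 Resolved IPs:
 [internal.example.org]: 100.90.0.12, 100.90.0.13

 - ID: route2
 Network: 10.20.0.0/16
 Status: Not Selected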
33
client/cmd/pprof.go
Normal file
33
client/cmd/pprof.go
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
//go:build pprof
|
||||||
|
// +build pprof
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
_ "net/http/pprof"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
addr := pprofAddr()
|
||||||
|
go pprof(addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func pprofAddr() string {
|
||||||
|
listenAddr := os.Getenv("NB_PPROF_ADDR")
|
||||||
|
if listenAddr == "" {
|
||||||
|
return "localhost:6969"
|
||||||
|
}
|
||||||
|
|
||||||
|
return listenAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
func pprof(listenAddr string) {
|
||||||
|
log.Infof("listening pprof on: %s\n", listenAddr)
|
||||||
|
if err := http.ListenAndServe(listenAddr, nil); err != nil {
|
||||||
|
log.Fatalf("Failed to start pprof: %v", err)
|
||||||
|
}
|
||||||
|
}
|
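Aside — a minimal usage sketch, not part of this changeset: when the client is built with the pprof build tag, net/http/pprof registers its handlers on the default mux that pprof.go serves, so the endpoint (NB_PPROF_ADDR, defaulting to localhost:6969) can be probed like this:

package main

import (
	"fmt"
	"io"
	"net/http"
)

func main() {
	// Assumes a netbird client built with `-tags pprof` and the default
	// listen address used in pprof.go above.
	resp, err := http.Get("http://localhost:6969/debug/pprof/")
	if err != nil {
		fmt.Println("pprof endpoint not reachable:", err)
		return
	}
	defer resp.Body.Close()

	body, _ := io.ReadAll(resp.Body)
	fmt.Printf("status %d, %d bytes of pprof index\n", resp.StatusCode, len(body))
}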
@@ -142,19 +142,20 @@ func init() {
 	rootCmd.AddCommand(loginCmd)
 	rootCmd.AddCommand(versionCmd)
 	rootCmd.AddCommand(sshCmd)
-	rootCmd.AddCommand(routesCmd)
+	rootCmd.AddCommand(networksCMD)
 	rootCmd.AddCommand(debugCmd)

 	serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
 	serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service

-	routesCmd.AddCommand(routesListCmd)
-	routesCmd.AddCommand(routesSelectCmd, routesDeselectCmd)
+	networksCMD.AddCommand(routesListCmd)
+	networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)

 	debugCmd.AddCommand(debugBundleCmd)
 	debugCmd.AddCommand(logCmd)
 	logCmd.AddCommand(logLevelCmd)
 	debugCmd.AddCommand(forCmd)
+	debugCmd.AddCommand(persistenceCmd)

 	upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
 		`Sets external IPs maps between local addresses and interfaces.`+
@@ -1,174 +0,0 @@
-package cmd
-
-import (
-	"fmt"
-	"strings"
-
-	"github.com/spf13/cobra"
-	"google.golang.org/grpc/status"
-
-	"github.com/netbirdio/netbird/client/proto"
-)
-
-var appendFlag bool
-
-var routesCmd = &cobra.Command{
-	Use:   "routes",
-	Short: "Manage network routes",
-	Long:  `Commands to list, select, or deselect network routes.`,
-}
-
-var routesListCmd = &cobra.Command{
-	Use:     "list",
-	Aliases: []string{"ls"},
-	Short:   "List routes",
-	Example: " netbird routes list",
-	Long:    "List all available network routes.",
-	RunE:    routesList,
-}
-
-var routesSelectCmd = &cobra.Command{
-	Use:     "select route...|all",
-	Short:   "Select routes",
-	Long:    "Select a list of routes by identifiers or 'all' to clear all selections and to accept all (including new) routes.\nDefault mode is replace, use -a to append to already selected routes.",
-	Example: " netbird routes select all\n netbird routes select route1 route2\n netbird routes select -a route3",
-	Args:    cobra.MinimumNArgs(1),
-	RunE:    routesSelect,
-}
-
-var routesDeselectCmd = &cobra.Command{
-	Use:     "deselect route...|all",
-	Short:   "Deselect routes",
-	Long:    "Deselect previously selected routes by identifiers or 'all' to disable accepting any routes.",
-	Example: " netbird routes deselect all\n netbird routes deselect route1 route2",
-	Args:    cobra.MinimumNArgs(1),
-	RunE:    routesDeselect,
-}
-
-func init() {
-	routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current route selection instead of replacing")
-}
-
-func routesList(cmd *cobra.Command, _ []string) error {
-	conn, err := getClient(cmd)
-	if err != nil {
-		return err
-	}
-	defer conn.Close()
-
-	client := proto.NewDaemonServiceClient(conn)
-	resp, err := client.ListRoutes(cmd.Context(), &proto.ListRoutesRequest{})
-	if err != nil {
-		return fmt.Errorf("failed to list routes: %v", status.Convert(err).Message())
-	}
-
-	if len(resp.Routes) == 0 {
-		cmd.Println("No routes available.")
-		return nil
-	}
-
-	printRoutes(cmd, resp)
-
-	return nil
-}
-
-func printRoutes(cmd *cobra.Command, resp *proto.ListRoutesResponse) {
-	cmd.Println("Available Routes:")
-	for _, route := range resp.Routes {
-		printRoute(cmd, route)
-	}
-}
-
-func printRoute(cmd *cobra.Command, route *proto.Route) {
-	selectedStatus := getSelectedStatus(route)
-	domains := route.GetDomains()
-
-	if len(domains) > 0 {
-		printDomainRoute(cmd, route, domains, selectedStatus)
-	} else {
-		printNetworkRoute(cmd, route, selectedStatus)
-	}
-}
-
-func getSelectedStatus(route *proto.Route) string {
-	if route.GetSelected() {
-		return "Selected"
-	}
-	return "Not Selected"
-}
-
-func printDomainRoute(cmd *cobra.Command, route *proto.Route, domains []string, selectedStatus string) {
-	cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus)
-	resolvedIPs := route.GetResolvedIPs()
-
-	if len(resolvedIPs) > 0 {
-		printResolvedIPs(cmd, domains, resolvedIPs)
-	} else {
-		cmd.Printf(" Resolved IPs: -\n")
-	}
-}
-
-func printNetworkRoute(cmd *cobra.Command, route *proto.Route, selectedStatus string) {
-	cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus)
-}
-
-func printResolvedIPs(cmd *cobra.Command, domains []string, resolvedIPs map[string]*proto.IPList) {
-	cmd.Printf(" Resolved IPs:\n")
-	for _, domain := range domains {
-		if ipList, exists := resolvedIPs[domain]; exists {
-			cmd.Printf(" [%s]: %s\n", domain, strings.Join(ipList.GetIps(), ", "))
-		}
-	}
-}
-
-func routesSelect(cmd *cobra.Command, args []string) error {
-	conn, err := getClient(cmd)
-	if err != nil {
-		return err
-	}
-	defer conn.Close()
-
-	client := proto.NewDaemonServiceClient(conn)
-	req := &proto.SelectRoutesRequest{
-		RouteIDs: args,
-	}
-
-	if len(args) == 1 && args[0] == "all" {
-		req.All = true
-	} else if appendFlag {
-		req.Append = true
-	}
-
-	if _, err := client.SelectRoutes(cmd.Context(), req); err != nil {
-		return fmt.Errorf("failed to select routes: %v", status.Convert(err).Message())
-	}
-
-	cmd.Println("Routes selected successfully.")
-
-	return nil
-}
-
-func routesDeselect(cmd *cobra.Command, args []string) error {
-	conn, err := getClient(cmd)
-	if err != nil {
-		return err
-	}
-	defer conn.Close()
-
-	client := proto.NewDaemonServiceClient(conn)
-	req := &proto.SelectRoutesRequest{
-		RouteIDs: args,
-	}
-
-	if len(args) == 1 && args[0] == "all" {
-		req.All = true
-	}
-
-	if _, err := client.DeselectRoutes(cmd.Context(), req); err != nil {
-		return fmt.Errorf("failed to deselect routes: %v", status.Convert(err).Message())
-	}
-
-	cmd.Println("Routes deselected successfully.")
-
-	return nil
-}
client/cmd/state.go (new file)
@@ -0,0 +1,181 @@
package cmd

import (
	"fmt"

	log "github.com/sirupsen/logrus"
	"github.com/spf13/cobra"
	"google.golang.org/grpc/status"

	"github.com/netbirdio/netbird/client/proto"
)

var (
	allFlag bool
)

var stateCmd = &cobra.Command{
	Use:   "state",
	Short: "Manage daemon state",
	Long:  "Provides commands for managing and inspecting the Netbird daemon state.",
}

var stateListCmd = &cobra.Command{
	Use:     "list",
	Aliases: []string{"ls"},
	Short:   "List all stored states",
	Long:    "Lists all registered states with their status and basic information.",
	Example: " netbird state list",
	RunE:    stateList,
}

var stateCleanCmd = &cobra.Command{
	Use:   "clean [state-name]",
	Short: "Clean stored states",
	Long: `Clean specific state or all states. The daemon must not be running.
This will perform cleanup operations and remove the state.`,
	Example: ` netbird state clean dns_state
 netbird state clean --all`,
	RunE: stateClean,
	PreRunE: func(cmd *cobra.Command, args []string) error {
		// Check mutual exclusivity between --all flag and state-name argument
		if allFlag && len(args) > 0 {
			return fmt.Errorf("cannot specify both --all flag and state name")
		}
		if !allFlag && len(args) != 1 {
			return fmt.Errorf("requires a state name argument or --all flag")
		}
		return nil
	},
}

var stateDeleteCmd = &cobra.Command{
	Use:   "delete [state-name]",
	Short: "Delete stored states",
	Long: `Delete specific state or all states from storage. The daemon must not be running.
This will remove the state without performing any cleanup operations.`,
	Example: ` netbird state delete dns_state
 netbird state delete --all`,
	RunE: stateDelete,
	PreRunE: func(cmd *cobra.Command, args []string) error {
		// Check mutual exclusivity between --all flag and state-name argument
		if allFlag && len(args) > 0 {
			return fmt.Errorf("cannot specify both --all flag and state name")
		}
		if !allFlag && len(args) != 1 {
			return fmt.Errorf("requires a state name argument or --all flag")
		}
		return nil
	},
}

func init() {
	rootCmd.AddCommand(stateCmd)
	stateCmd.AddCommand(stateListCmd, stateCleanCmd, stateDeleteCmd)

	stateCleanCmd.Flags().BoolVarP(&allFlag, "all", "a", false, "Clean all states")
	stateDeleteCmd.Flags().BoolVarP(&allFlag, "all", "a", false, "Delete all states")
}

func stateList(cmd *cobra.Command, _ []string) error {
	conn, err := getClient(cmd)
	if err != nil {
		return err
	}
	defer func() {
		if err := conn.Close(); err != nil {
			log.Errorf(errCloseConnection, err)
		}
	}()

	client := proto.NewDaemonServiceClient(conn)
	resp, err := client.ListStates(cmd.Context(), &proto.ListStatesRequest{})
	if err != nil {
		return fmt.Errorf("failed to list states: %v", status.Convert(err).Message())
	}

	cmd.Printf("\nStored states:\n\n")
	for _, state := range resp.States {
		cmd.Printf("- %s\n", state.Name)
	}

	return nil
}

func stateClean(cmd *cobra.Command, args []string) error {
	var stateName string
	if !allFlag {
		stateName = args[0]
	}

	conn, err := getClient(cmd)
	if err != nil {
		return err
	}
	defer func() {
		if err := conn.Close(); err != nil {
			log.Errorf(errCloseConnection, err)
		}
	}()

	client := proto.NewDaemonServiceClient(conn)
	resp, err := client.CleanState(cmd.Context(), &proto.CleanStateRequest{
		StateName: stateName,
		All:       allFlag,
	})
	if err != nil {
		return fmt.Errorf("failed to clean state: %v", status.Convert(err).Message())
	}

	if resp.CleanedStates == 0 {
		cmd.Println("No states were cleaned")
		return nil
	}

	if allFlag {
		cmd.Printf("Successfully cleaned %d states\n", resp.CleanedStates)
	} else {
		cmd.Printf("Successfully cleaned state %q\n", stateName)
	}

	return nil
}

func stateDelete(cmd *cobra.Command, args []string) error {
	var stateName string
	if !allFlag {
		stateName = args[0]
	}

	conn, err := getClient(cmd)
	if err != nil {
		return err
	}
	defer func() {
		if err := conn.Close(); err != nil {
			log.Errorf(errCloseConnection, err)
		}
	}()

	client := proto.NewDaemonServiceClient(conn)
	resp, err := client.DeleteState(cmd.Context(), &proto.DeleteStateRequest{
		StateName: stateName,
		All:       allFlag,
	})
	if err != nil {
		return fmt.Errorf("failed to delete state: %v", status.Convert(err).Message())
	}

	if resp.DeletedStates == 0 {
		cmd.Println("No states were deleted")
		return nil
	}

	if allFlag {
		cmd.Printf("Successfully deleted %d states\n", resp.DeletedStates)
	} else {
		cmd.Printf("Successfully deleted state %q\n", stateName)
	}

	return nil
}
@@ -40,6 +40,7 @@ type peerStateDetailOutput struct {
 	Latency          time.Duration `json:"latency" yaml:"latency"`
 	RosenpassEnabled bool          `json:"quantumResistance" yaml:"quantumResistance"`
 	Routes           []string      `json:"routes" yaml:"routes"`
+	Networks         []string      `json:"networks" yaml:"networks"`
 }

 type peersStateOutput struct {

@@ -98,6 +99,7 @@ type statusOutputOverview struct {
 	RosenpassEnabled    bool                       `json:"quantumResistance" yaml:"quantumResistance"`
 	RosenpassPermissive bool                       `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"`
 	Routes              []string                   `json:"routes" yaml:"routes"`
+	Networks            []string                   `json:"networks" yaml:"networks"`
 	NSServerGroups      []nsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"`
 }

@@ -282,7 +284,8 @@ func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverv
 		FQDN:                pbFullStatus.GetLocalPeerState().GetFqdn(),
 		RosenpassEnabled:    pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(),
 		RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(),
-		Routes:              pbFullStatus.GetLocalPeerState().GetRoutes(),
+		Routes:              pbFullStatus.GetLocalPeerState().GetNetworks(),
+		Networks:            pbFullStatus.GetLocalPeerState().GetNetworks(),
 		NSServerGroups:      mapNSGroups(pbFullStatus.GetDnsServers()),
 	}

@@ -390,7 +393,8 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
 			TransferSent:     transferSent,
 			Latency:          pbPeerState.GetLatency().AsDuration(),
 			RosenpassEnabled: pbPeerState.GetRosenpassEnabled(),
-			Routes:           pbPeerState.GetRoutes(),
+			Routes:           pbPeerState.GetNetworks(),
+			Networks:         pbPeerState.GetNetworks(),
 		}

 		peersStateDetail = append(peersStateDetail, peerState)

@@ -491,10 +495,10 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
 		relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
 	}

-	routes := "-"
-	if len(overview.Routes) > 0 {
-		sort.Strings(overview.Routes)
-		routes = strings.Join(overview.Routes, ", ")
+	networks := "-"
+	if len(overview.Networks) > 0 {
+		sort.Strings(overview.Networks)
+		networks = strings.Join(overview.Networks, ", ")
 	}

 	var dnsServersString string

@@ -556,6 +560,7 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
 		"Interface type: %s\n"+
 		"Quantum resistance: %s\n"+
 		"Routes: %s\n"+
+		"Networks: %s\n"+
 		"Peers count: %s\n",
 		fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
 		overview.DaemonVersion,

@@ -568,7 +573,8 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
 		interfaceIP,
 		interfaceTypeString,
 		rosenpassEnabledStatus,
-		routes,
+		networks,
+		networks,
 		peersCountString,
 	)
 	return summary

@@ -631,10 +637,10 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
 			}
 		}

-		routes := "-"
-		if len(peerState.Routes) > 0 {
-			sort.Strings(peerState.Routes)
-			routes = strings.Join(peerState.Routes, ", ")
+		networks := "-"
+		if len(peerState.Networks) > 0 {
+			sort.Strings(peerState.Networks)
+			networks = strings.Join(peerState.Networks, ", ")
 		}

 		peerString := fmt.Sprintf(

@@ -652,6 +658,7 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
 			" Transfer status (received/sent) %s/%s\n"+
 			" Quantum resistance: %s\n"+
 			" Routes: %s\n"+
+			" Networks: %s\n"+
 			" Latency: %s\n",
 			peerState.FQDN,
 			peerState.IP,

@@ -668,7 +675,8 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
 			toIEC(peerState.TransferReceived),
 			toIEC(peerState.TransferSent),
 			rosenpassEnabledStatus,
-			routes,
+			networks,
+			networks,
 			peerState.Latency.String(),
 		)

@@ -810,6 +818,14 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {

 	peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress)

+	for i, route := range peer.Networks {
+		peer.Networks[i] = a.AnonymizeIPString(route)
+	}
+
+	for i, route := range peer.Networks {
+		peer.Networks[i] = a.AnonymizeRoute(route)
+	}
+
 	for i, route := range peer.Routes {
 		peer.Routes[i] = a.AnonymizeIPString(route)
 	}

@@ -850,6 +866,10 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview)
 		}
 	}

+	for i, route := range overview.Networks {
+		overview.Networks[i] = a.AnonymizeRoute(route)
+	}
+
 	for i, route := range overview.Routes {
 		overview.Routes[i] = a.AnonymizeRoute(route)
 	}

@@ -44,7 +44,7 @@ var resp = &proto.StatusResponse{
 			LastWireguardHandshake: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 2, 0, time.UTC)),
 			BytesRx:                200,
 			BytesTx:                100,
-			Routes: []string{
+			Networks: []string{
 				"10.1.0.0/24",
 			},
 			Latency: durationpb.New(time.Duration(10000000)),

@@ -93,7 +93,7 @@ var resp = &proto.StatusResponse{
 			PubKey:          "Some-Pub-Key",
 			KernelInterface: true,
 			Fqdn:            "some-localhost.awesome-domain.com",
-			Routes: []string{
+			Networks: []string{
 				"10.10.0.0/24",
 			},
 		},

@@ -149,6 +149,9 @@ var overview = statusOutputOverview{
 			Routes: []string{
 				"10.1.0.0/24",
 			},
+			Networks: []string{
+				"10.1.0.0/24",
+			},
 			Latency: time.Duration(10000000),
 		},
 		{

@@ -230,6 +233,9 @@ var overview = statusOutputOverview{
 	Routes: []string{
 		"10.10.0.0/24",
 	},
+	Networks: []string{
+		"10.10.0.0/24",
+	},
 }

 func TestConversionFromFullStatusToOutputOverview(t *testing.T) {

@@ -295,6 +301,9 @@ func TestParsingToJSON(t *testing.T) {
         "quantumResistance": false,
         "routes": [
           "10.1.0.0/24"
+        ],
+        "networks": [
+          "10.1.0.0/24"
         ]
       },
       {

@@ -318,7 +327,8 @@ func TestParsingToJSON(t *testing.T) {
         "transferSent": 1000,
         "latency": 10000000,
         "quantumResistance": false,
-        "routes": null
+        "routes": null,
+        "networks": null
       }
     ]
   },

@@ -359,6 +369,9 @@ func TestParsingToJSON(t *testing.T) {
   "routes": [
     "10.10.0.0/24"
   ],
+  "networks": [
+    "10.10.0.0/24"
+  ],
   "dnsServers": [
     {
       "servers": [

@@ -418,6 +431,8 @@ func TestParsingToYAML(t *testing.T) {
   quantumResistance: false
   routes:
   - 10.1.0.0/24
+  networks:
+  - 10.1.0.0/24
 - fqdn: peer-2.awesome-domain.com
   netbirdIp: 192.168.178.102
   publicKey: Pubkey2

@@ -437,6 +452,7 @@ func TestParsingToYAML(t *testing.T) {
   latency: 10ms
   quantumResistance: false
   routes: []
+  networks: []
 cliVersion: development
 daemonVersion: 0.14.1
 management:

@@ -465,6 +481,8 @@ quantumResistance: false
 quantumResistancePermissive: false
 routes:
 - 10.10.0.0/24
+networks:
+- 10.10.0.0/24
 dnsServers:
 - servers:
   - 8.8.8.8:53

@@ -509,6 +527,7 @@ func TestParsingToDetail(t *testing.T) {
  Transfer status (received/sent) 200 B/100 B
  Quantum resistance: false
  Routes: 10.1.0.0/24
+ Networks: 10.1.0.0/24
  Latency: 10ms

 peer-2.awesome-domain.com:

@@ -525,6 +544,7 @@ func TestParsingToDetail(t *testing.T) {
  Transfer status (received/sent) 2.0 KiB/1000 B
  Quantum resistance: false
  Routes: -
+ Networks: -
  Latency: 10ms

 OS: %s/%s

@@ -543,6 +563,7 @@ NetBird IP: 192.168.178.100/16
 Interface type: Kernel
 Quantum resistance: false
 Routes: 10.10.0.0/24
+Networks: 10.10.0.0/24
 Peers count: 2/2 Connected
 `, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)

@@ -564,6 +585,7 @@ NetBird IP: 192.168.178.100/16
 Interface type: Kernel
 Quantum resistance: false
 Routes: 10.10.0.0/24
+Networks: 10.10.0.0/24
 Peers count: 2/2 Connected
 `
client/cmd/system.go (new file)
@@ -0,0 +1,31 @@
package cmd

// Flag constants for system configuration
const (
	disableClientRoutesFlag = "disable-client-routes"
	disableServerRoutesFlag = "disable-server-routes"
	disableDNSFlag          = "disable-dns"
	disableFirewallFlag     = "disable-firewall"
)

var (
	disableClientRoutes bool
	disableServerRoutes bool
	disableDNS          bool
	disableFirewall     bool
)

func init() {
	// Add system flags to upCmd
	upCmd.PersistentFlags().BoolVar(&disableClientRoutes, disableClientRoutesFlag, false,
		"Disable client routes. If enabled, the client won't process client routes received from the management service.")

	upCmd.PersistentFlags().BoolVar(&disableServerRoutes, disableServerRoutesFlag, false,
		"Disable server routes. If enabled, the client won't act as a router for server routes received from the management service.")

	upCmd.PersistentFlags().BoolVar(&disableDNS, disableDNSFlag, false,
		"Disable DNS. If enabled, the client won't configure DNS settings.")

	upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false,
		"Disable firewall configuration. If enabled, the client won't modify firewall rules.")
}
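Aside — a minimal sketch, not part of the diff: the up-command hunks further below gate these settings on cobra's Flag(...).Changed, which distinguishes "flag not passed" from "flag explicitly set to its default". A standalone illustration of that pattern (the command and flag name here are only for demonstration):

package main

import (
	"fmt"

	"github.com/spf13/cobra"
)

func main() {
	var disableDNS bool

	cmd := &cobra.Command{
		Use: "demo",
		RunE: func(cmd *cobra.Command, _ []string) error {
			// Only forward the setting when the user actually passed the flag,
			// mirroring the cmd.Flag(disableDNSFlag).Changed checks in the up command.
			if cmd.Flag("disable-dns").Changed {
				fmt.Println("override requested:", disableDNS)
			} else {
				fmt.Println("flag untouched, keeping daemon default")
			}
			return nil
		},
	}
	cmd.Flags().BoolVar(&disableDNS, "disable-dns", false, "Disable DNS configuration")

	cmd.SetArgs([]string{"--disable-dns=false"})
	_ = cmd.Execute() // prints "override requested: false"
}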
@@ -10,6 +10,8 @@ import (
 	"go.opentelemetry.io/otel"

 	"github.com/netbirdio/netbird/management/server/activity"
+	"github.com/netbirdio/netbird/management/server/settings"
+	"github.com/netbirdio/netbird/management/server/store"
 	"github.com/netbirdio/netbird/management/server/telemetry"

 	"github.com/netbirdio/netbird/util"

@@ -71,7 +73,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
 		t.Fatal(err)
 	}
 	s := grpc.NewServer()
-	store, cleanUp, err := mgmt.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir())
+	store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir())
 	if err != nil {
 		t.Fatal(err)
 	}

@@ -93,7 +95,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
 	}

 	secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
-	mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil)
+	mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -147,6 +147,19 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
 		ic.DNSRouteInterval = &dnsRouteInterval
 	}

+	if cmd.Flag(disableClientRoutesFlag).Changed {
+		ic.DisableClientRoutes = &disableClientRoutes
+	}
+	if cmd.Flag(disableServerRoutesFlag).Changed {
+		ic.DisableServerRoutes = &disableServerRoutes
+	}
+	if cmd.Flag(disableDNSFlag).Changed {
+		ic.DisableDNS = &disableDNS
+	}
+	if cmd.Flag(disableFirewallFlag).Changed {
+		ic.DisableFirewall = &disableFirewall
+	}
+
 	providedSetupKey, err := getSetupKey()
 	if err != nil {
 		return err

@@ -264,6 +277,19 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
 		loginRequest.DnsRouteInterval = durationpb.New(dnsRouteInterval)
 	}

+	if cmd.Flag(disableClientRoutesFlag).Changed {
+		loginRequest.DisableClientRoutes = &disableClientRoutes
+	}
+	if cmd.Flag(disableServerRoutesFlag).Changed {
+		loginRequest.DisableServerRoutes = &disableServerRoutes
+	}
+	if cmd.Flag(disableDNSFlag).Changed {
+		loginRequest.DisableDns = &disableDNS
+	}
+	if cmd.Flag(disableFirewallFlag).Changed {
+		loginRequest.DisableFirewall = &disableFirewall
+	}
+
 	var loginErr error

 	var loginResp *proto.LoginResponse
client/configs/configs.go (new file)
@@ -0,0 +1,24 @@
package configs

import (
	"os"
	"path/filepath"
	"runtime"
)

var StateDir string

func init() {
	StateDir = os.Getenv("NB_STATE_DIR")
	if StateDir != "" {
		return
	}
	switch runtime.GOOS {
	case "windows":
		StateDir = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird")
	case "darwin", "linux":
		StateDir = "/var/lib/netbird"
	case "freebsd", "openbsd", "netbsd", "dragonfly":
		StateDir = "/var/db/netbird"
	}
}
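Aside — a minimal sketch, not part of the diff: the resolved StateDir can be inspected from any package that imports the new configs package (the import path below is taken from the file location shown above); NB_STATE_DIR, when set before the process starts, overrides the per-OS default:

package main

import (
	"fmt"
	"os"

	"github.com/netbirdio/netbird/client/configs"
)

func main() {
	// With NB_STATE_DIR unset, StateDir falls back to the per-OS default chosen
	// in the configs init() above (e.g. /var/lib/netbird on Linux).
	fmt.Println("state dir:", configs.StateDir)
	fmt.Println("override :", os.Getenv("NB_STATE_DIR"))
}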
@@ -332,18 +332,12 @@ func (m *aclManager) createDefaultChains() error {

 // The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule.
 func (m *aclManager) seedInitialEntries() {
-
 	established := getConntrackEstablished()
-
 	m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
 	m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
 	m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))

-	m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"})
-	m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", chainNameOutputRules})
-	m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
-	m.appendToEntries("OUTPUT", append([]string{"-o", m.wgIface.Name()}, established...))
-
 	m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
 	m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName})
 	m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...))
@@ -83,9 +83,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
 	}

 	// persist early to ensure cleanup of chains
-	if err := stateManager.PersistState(context.Background()); err != nil {
-		log.Errorf("failed to persist state: %v", err)
-	}
+	go func() {
+		if err := stateManager.PersistState(context.Background()); err != nil {
+			log.Errorf("failed to persist state: %v", err)
+		}
+	}()

 	return nil
 }

@@ -195,7 +197,7 @@ func (m *Manager) AllowNetbird() error {
 	}

 	_, err := m.AddPeerFiltering(
-		net.ParseIP("0.0.0.0"),
+		net.IP{0, 0, 0, 0},
 		"all",
 		nil,
 		nil,

@@ -205,19 +207,9 @@ func (m *Manager) AllowNetbird() error {
 		"",
 	)
 	if err != nil {
-		return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
+		return fmt.Errorf("allow netbird interface traffic: %w", err)
 	}
-	_, err = m.AddPeerFiltering(
-		net.ParseIP("0.0.0.0"),
-		"all",
-		nil,
-		nil,
-		firewall.RuleDirectionOUT,
-		firewall.ActionAccept,
-		"",
-		"",
-	)
-	return err
+	return nil
 }

 // Flush doesn't need to be implemented for this manager
@@ -37,6 +37,11 @@ func (s *ipList) UnmarshalJSON(data []byte) error {
 		return err
 	}
 	s.ips = temp.IPs
+
+	if temp.IPs == nil {
+		temp.IPs = make(map[string]struct{})
+	}
+
 	return nil
 }

@@ -89,5 +94,10 @@ func (s *ipsetStore) UnmarshalJSON(data []byte) error {
 		return err
 	}
 	s.ipsets = temp.IPSets
+
+	if temp.IPSets == nil {
+		temp.IPSets = make(map[string]*ipList)
+	}
+
 	return nil
 }
@@ -5,7 +5,6 @@ import (
 	"encoding/binary"
 	"fmt"
 	"net"
-	"net/netip"
 	"strconv"
 	"strings"
 	"time"

@@ -28,7 +27,6 @@ const (

 	// filter chains contains the rules that jump to the rules chains
 	chainNameInputFilter   = "netbird-acl-input-filter"
-	chainNameOutputFilter  = "netbird-acl-output-filter"
 	chainNameForwardFilter = "netbird-acl-forward-filter"
 	chainNamePrerouting    = "netbird-rt-prerouting"

@@ -441,18 +439,6 @@ func (m *AclManager) createDefaultChains() (err error) {
 		return err
 	}

-	// netbird-acl-output-filter
-	// type filter hook output priority filter; policy accept;
-	chain = m.createFilterChainWithHook(chainNameOutputFilter, nftables.ChainHookOutput)
-	m.addFwdAllow(chain, expr.MetaKeyOIFNAME)
-	m.addJumpRule(chain, m.chainOutputRules.Name, expr.MetaKeyOIFNAME) // to netbird-acl-output-rules
-	m.addDropExpressions(chain, expr.MetaKeyOIFNAME)
-	err = m.rConn.Flush()
-	if err != nil {
-		log.Debugf("failed to create chain (%s): %s", chainNameOutputFilter, err)
-		return err
-	}
-
 	// netbird-acl-forward-filter
 	chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
 	m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd

@@ -619,45 +605,6 @@ func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.Met
 	return nil
 }

-func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
-	ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
-	dstOp := expr.CmpOpNeq
-	expressions := []expr.Any{
-		&expr.Meta{Key: iifname, Register: 1},
-		&expr.Cmp{
-			Op:       expr.CmpOpEq,
-			Register: 1,
-			Data:     ifname(m.wgIface.Name()),
-		},
-		&expr.Payload{
-			DestRegister: 2,
-			Base:         expr.PayloadBaseNetworkHeader,
-			Offset:       16,
-			Len:          4,
-		},
-		&expr.Bitwise{
-			SourceRegister: 2,
-			DestRegister:   2,
-			Len:            4,
-			Xor:            []byte{0x0, 0x0, 0x0, 0x0},
-			Mask:           m.wgIface.Address().Network.Mask,
-		},
-		&expr.Cmp{
-			Op:       dstOp,
-			Register: 2,
-			Data:     ip.Unmap().AsSlice(),
-		},
-		&expr.Verdict{
-			Kind: expr.VerdictAccept,
-		},
-	}
-	_ = m.rConn.AddRule(&nftables.Rule{
-		Table: chain.Table,
-		Chain: chain,
-		Exprs: expressions,
-	})
-}
-
 func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
 	expressions := []expr.Any{
 		&expr.Meta{Key: ifaceKey, Register: 1},
@@ -99,9 +99,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
 	}

 	// persist early
-	if err := stateManager.PersistState(context.Background()); err != nil {
-		log.Errorf("failed to persist state: %v", err)
-	}
+	go func() {
+		if err := stateManager.PersistState(context.Background()); err != nil {
+			log.Errorf("failed to persist state: %v", err)
+		}
+	}()

 	return nil
 }

@@ -197,7 +199,7 @@ func (m *Manager) AllowNetbird() error {

 	var chain *nftables.Chain
 	for _, c := range chains {
-		if c.Table.Name == tableNameFilter && c.Name == chainNameForward {
+		if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
 			chain = c
 			break
 		}

@@ -274,7 +276,7 @@ func (m *Manager) resetNetbirdInputRules() error {

 func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
 	for _, c := range chains {
-		if c.Table.Name == "filter" && c.Name == "INPUT" {
+		if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
 			rules, err := m.rConn.GetRules(c.Table, c)
 			if err != nil {
 				log.Errorf("get rules for chain %q: %v", c.Name, err)

@@ -349,7 +351,9 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
 				Register: 1,
 				Data:     ifname(m.wgIface.Name()),
 			},
-			&expr.Verdict{},
+			&expr.Verdict{
+				Kind: expr.VerdictAccept,
+			},
 		},
 		UserData: []byte(allowNetbirdInputRuleID),
 	}
@@ -1,9 +1,11 @@
 package nftables

 import (
+	"bytes"
 	"fmt"
 	"net"
 	"net/netip"
+	"os/exec"
 	"testing"
 	"time"

@@ -225,3 +227,105 @@ func TestNFtablesCreatePerformance(t *testing.T) {
 		})
 	}
 }
+
+func runIptablesSave(t *testing.T) (string, string) {
+	t.Helper()
+	var stdout, stderr bytes.Buffer
+	cmd := exec.Command("iptables-save")
+	cmd.Stdout = &stdout
+	cmd.Stderr = &stderr
+
+	err := cmd.Run()
+	require.NoError(t, err, "iptables-save failed to run")
+
+	return stdout.String(), stderr.String()
+}
+
+func verifyIptablesOutput(t *testing.T, stdout, stderr string) {
+	t.Helper()
+	// Check for any incompatibility warnings
+	require.NotContains(t,
+		stderr,
+		"incompatible",
+		"iptables-save produced compatibility warning. Full stderr: %s",
+		stderr,
+	)
+
+	// Verify standard tables are present
+	expectedTables := []string{
+		"*filter",
+		"*nat",
+		"*mangle",
+	}
+
+	for _, table := range expectedTables {
+		require.Contains(t,
+			stdout,
+			table,
+			"iptables-save output missing expected table: %s\nFull stdout: %s",
+			table,
+			stdout,
+		)
+	}
+}
+
+func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
+	if check() != NFTABLES {
+		t.Skip("nftables not supported on this system")
+	}
+
+	if _, err := exec.LookPath("iptables-save"); err != nil {
+		t.Skipf("iptables-save not available on this system: %v", err)
+	}
+
+	// First ensure iptables-nft tables exist by running iptables-save
+	stdout, stderr := runIptablesSave(t)
+	verifyIptablesOutput(t, stdout, stderr)
+
+	manager, err := Create(ifaceMock)
+	require.NoError(t, err, "failed to create manager")
+	require.NoError(t, manager.Init(nil))
+
+	t.Cleanup(func() {
+		err := manager.Reset(nil)
+		require.NoError(t, err, "failed to reset manager state")
+
+		// Verify iptables output after reset
+		stdout, stderr := runIptablesSave(t)
+		verifyIptablesOutput(t, stdout, stderr)
+	})
+
+	ip := net.ParseIP("100.96.0.1")
+	_, err = manager.AddPeerFiltering(
+		ip,
+		fw.ProtocolTCP,
+		nil,
+		&fw.Port{Values: []int{80}},
+		fw.RuleDirectionIN,
+		fw.ActionAccept,
+		"",
+		"test rule",
+	)
+	require.NoError(t, err, "failed to add peer filtering rule")
+
+	_, err = manager.AddRouteFiltering(
+		[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
+		netip.MustParsePrefix("10.1.0.0/24"),
+		fw.ProtocolTCP,
+		nil,
+		&fw.Port{Values: []int{443}},
+		fw.ActionAccept,
+	)
+	require.NoError(t, err, "failed to add route filtering rule")
+
+	pair := fw.RouterPair{
+		Source:      netip.MustParsePrefix("192.168.1.0/24"),
+		Destination: netip.MustParsePrefix("10.0.0.0/24"),
+		Masquerade:  true,
+	}
+	err = manager.AddNatRule(pair)
+	require.NoError(t, err, "failed to add NAT rule")
+
+	stdout, stderr = runIptablesSave(t)
+	verifyIptablesOutput(t, stdout, stderr)
+}
@@ -1 +0,0 @@
-package nftables
@@ -2,7 +2,10 @@

 package uspfilter

-import "github.com/netbirdio/netbird/client/internal/statemanager"
+import (
+	"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
+	"github.com/netbirdio/netbird/client/internal/statemanager"
+)

 // Reset firewall to the default state
 func (m *Manager) Reset(stateManager *statemanager.Manager) error {

@@ -12,6 +15,21 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
 	m.outgoingRules = make(map[string]RuleSet)
 	m.incomingRules = make(map[string]RuleSet)

+	if m.udpTracker != nil {
+		m.udpTracker.Close()
+		m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
+	}
+
+	if m.icmpTracker != nil {
+		m.icmpTracker.Close()
+		m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
+	}
+
+	if m.tcpTracker != nil {
+		m.tcpTracker.Close()
+		m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
+	}
+
 	if m.nativeFirewall != nil {
 		return m.nativeFirewall.Reset(stateManager)
 	}

@@ -7,6 +7,7 @@ import (

 	log "github.com/sirupsen/logrus"

+	"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
 	"github.com/netbirdio/netbird/client/internal/statemanager"
 )

@@ -26,6 +27,21 @@ func (m *Manager) Reset(*statemanager.Manager) error {
 	m.outgoingRules = make(map[string]RuleSet)
 	m.incomingRules = make(map[string]RuleSet)

+	if m.udpTracker != nil {
+		m.udpTracker.Close()
+		m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
+	}
+
+	if m.icmpTracker != nil {
+		m.icmpTracker.Close()
+		m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
+	}
+
+	if m.tcpTracker != nil {
+		m.tcpTracker.Close()
+		m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
+	}
+
 	if !isWindowsFirewallReachable() {
 		return nil
 	}
client/firewall/uspfilter/conntrack/common.go (new file)
@@ -0,0 +1,137 @@
// common.go
package conntrack

import (
	"net"
	"sync"
	"sync/atomic"
	"time"
)

// BaseConnTrack provides common fields and locking for all connection types
type BaseConnTrack struct {
	SourceIP    net.IP
	DestIP      net.IP
	SourcePort  uint16
	DestPort    uint16
	lastSeen    atomic.Int64 // Unix nano for atomic access
	established atomic.Bool
}

// these small methods will be inlined by the compiler

// UpdateLastSeen safely updates the last seen timestamp
func (b *BaseConnTrack) UpdateLastSeen() {
	b.lastSeen.Store(time.Now().UnixNano())
}

// IsEstablished safely checks if connection is established
func (b *BaseConnTrack) IsEstablished() bool {
	return b.established.Load()
}

// SetEstablished safely sets the established state
func (b *BaseConnTrack) SetEstablished(state bool) {
	b.established.Store(state)
}

// GetLastSeen safely gets the last seen timestamp
func (b *BaseConnTrack) GetLastSeen() time.Time {
	return time.Unix(0, b.lastSeen.Load())
}

// timeoutExceeded checks if the connection has exceeded the given timeout
func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool {
	lastSeen := time.Unix(0, b.lastSeen.Load())
	return time.Since(lastSeen) > timeout
}

// IPAddr is a fixed-size IP address to avoid allocations
type IPAddr [16]byte

// MakeIPAddr creates an IPAddr from net.IP
func MakeIPAddr(ip net.IP) (addr IPAddr) {
	// Optimization: check for v4 first as it's more common
	if ip4 := ip.To4(); ip4 != nil {
		copy(addr[12:], ip4)
	} else {
		copy(addr[:], ip.To16())
	}
	return addr
}

// ConnKey uniquely identifies a connection
type ConnKey struct {
	SrcIP   IPAddr
	DstIP   IPAddr
	SrcPort uint16
	DstPort uint16
}

// makeConnKey creates a connection key
func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey {
	return ConnKey{
		SrcIP:   MakeIPAddr(srcIP),
		DstIP:   MakeIPAddr(dstIP),
		SrcPort: srcPort,
		DstPort: dstPort,
	}
}

// ValidateIPs checks if IPs match without allocation
func ValidateIPs(connIP IPAddr, pktIP net.IP) bool {
	if ip4 := pktIP.To4(); ip4 != nil {
		// Compare IPv4 addresses (last 4 bytes)
		for i := 0; i < 4; i++ {
			if connIP[12+i] != ip4[i] {
				return false
			}
		}
		return true
	}
	// Compare full IPv6 addresses
	ip6 := pktIP.To16()
	for i := 0; i < 16; i++ {
		if connIP[i] != ip6[i] {
			return false
		}
	}
	return true
}

// PreallocatedIPs is a pool of IP byte slices to reduce allocations
type PreallocatedIPs struct {
	sync.Pool
}

// NewPreallocatedIPs creates a new IP pool
func NewPreallocatedIPs() *PreallocatedIPs {
	return &PreallocatedIPs{
		Pool: sync.Pool{
			New: func() interface{} {
				ip := make(net.IP, 16)
				return &ip
			},
		},
	}
}

// Get retrieves an IP from the pool
func (p *PreallocatedIPs) Get() net.IP {
	return *p.Pool.Get().(*net.IP)
}

// Put returns an IP to the pool
func (p *PreallocatedIPs) Put(ip net.IP) {
	p.Pool.Put(&ip)
}

// copyIP copies an IP address efficiently
func copyIP(dst, src net.IP) {
	if len(src) == 16 {
		copy(dst, src)
	} else {
		// Handle IPv4
		copy(dst[12:], src.To4())
	}
}
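Aside — a minimal sketch, not part of the diff: how the fixed-size key and validation helpers above can back an allocation-free connection table. Only exported identifiers from common.go are used, and the import path is taken from the file location shown in the header:

package main

import (
	"fmt"
	"net"

	"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
)

func main() {
	src := net.ParseIP("192.168.1.10")
	dst := net.ParseIP("10.0.0.1")

	// ConnKey is a comparable value type, so it can be used directly as a map key
	// without converting IPs to strings on every packet.
	key := conntrack.ConnKey{
		SrcIP:   conntrack.MakeIPAddr(src),
		DstIP:   conntrack.MakeIPAddr(dst),
		SrcPort: 51820,
		DstPort: 443,
	}

	table := map[conntrack.ConnKey]*conntrack.BaseConnTrack{}
	entry := &conntrack.BaseConnTrack{SourceIP: src, DestIP: dst, SourcePort: 51820, DestPort: 443}
	entry.UpdateLastSeen()
	entry.SetEstablished(true)
	table[key] = entry

	// On the reverse path, compare the stored fixed-size address with the packet
	// IP directly instead of allocating a new net.IP.
	fmt.Println(conntrack.ValidateIPs(key.DstIP, dst)) // true
	fmt.Println(table[key].IsEstablished())            // true
	fmt.Println(!table[key].GetLastSeen().IsZero())    // true
}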
115
client/firewall/uspfilter/conntrack/common_test.go
Normal file
115
client/firewall/uspfilter/conntrack/common_test.go
Normal file
@ -0,0 +1,115 @@
|
|||||||
|
package conntrack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkIPOperations(b *testing.B) {
|
||||||
|
b.Run("MakeIPAddr", func(b *testing.B) {
|
||||||
|
ip := net.ParseIP("192.168.1.1")
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = MakeIPAddr(ip)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("ValidateIPs", func(b *testing.B) {
|
||||||
|
ip1 := net.ParseIP("192.168.1.1")
|
||||||
|
ip2 := net.ParseIP("192.168.1.1")
|
||||||
|
addr := MakeIPAddr(ip1)
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = ValidateIPs(addr, ip2)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("IPPool", func(b *testing.B) {
|
||||||
|
pool := NewPreallocatedIPs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
ip := pool.Get()
|
||||||
|
pool.Put(ip)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
func BenchmarkAtomicOperations(b *testing.B) {
|
||||||
|
conn := &BaseConnTrack{}
|
||||||
|
b.Run("UpdateLastSeen", func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
conn.UpdateLastSeen()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("IsEstablished", func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = conn.IsEstablished()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("SetEstablished", func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
conn.SetEstablished(i%2 == 0)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("GetLastSeen", func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = conn.GetLastSeen()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Memory pressure tests
|
||||||
|
func BenchmarkMemoryPressure(b *testing.B) {
|
||||||
|
b.Run("TCPHighLoad", func(b *testing.B) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
// Generate different IPs
|
||||||
|
srcIPs := make([]net.IP, 100)
|
||||||
|
dstIPs := make([]net.IP, 100)
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
|
||||||
|
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
srcIdx := i % len(srcIPs)
|
||||||
|
dstIdx := (i + 1) % len(dstIPs)
|
||||||
|
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn)
|
||||||
|
|
||||||
|
// Simulate some valid inbound packets
|
||||||
|
if i%3 == 0 {
|
||||||
|
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("UDPHighLoad", func(b *testing.B) {
|
||||||
|
tracker := NewUDPTracker(DefaultUDPTimeout)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
// Generate different IPs
|
||||||
|
srcIPs := make([]net.IP, 100)
|
||||||
|
dstIPs := make([]net.IP, 100)
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
|
||||||
|
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
srcIdx := i % len(srcIPs)
|
||||||
|
dstIdx := (i + 1) % len(dstIPs)
|
||||||
|
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80)
|
||||||
|
|
||||||
|
// Simulate some valid inbound packets
|
||||||
|
if i%3 == 0 {
|
||||||
|
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
170 client/firewall/uspfilter/conntrack/icmp.go Normal file
@ -0,0 +1,170 @@
package conntrack

import (
	"net"
	"sync"
	"time"

	"github.com/google/gopacket/layers"
)

const (
	// DefaultICMPTimeout is the default timeout for ICMP connections
	DefaultICMPTimeout = 30 * time.Second
	// ICMPCleanupInterval is how often we check for stale ICMP connections
	ICMPCleanupInterval = 15 * time.Second
)

// ICMPConnKey uniquely identifies an ICMP connection
type ICMPConnKey struct {
	// Supports both IPv4 and IPv6
	SrcIP    [16]byte
	DstIP    [16]byte
	Sequence uint16 // ICMP sequence number
	ID       uint16 // ICMP identifier
}

// ICMPConnTrack represents an ICMP connection state
type ICMPConnTrack struct {
	BaseConnTrack
	Sequence uint16
	ID       uint16
}

// ICMPTracker manages ICMP connection states
type ICMPTracker struct {
	connections   map[ICMPConnKey]*ICMPConnTrack
	timeout       time.Duration
	cleanupTicker *time.Ticker
	mutex         sync.RWMutex
	done          chan struct{}
	ipPool        *PreallocatedIPs
}

// NewICMPTracker creates a new ICMP connection tracker
func NewICMPTracker(timeout time.Duration) *ICMPTracker {
	if timeout == 0 {
		timeout = DefaultICMPTimeout
	}

	tracker := &ICMPTracker{
		connections:   make(map[ICMPConnKey]*ICMPConnTrack),
		timeout:       timeout,
		cleanupTicker: time.NewTicker(ICMPCleanupInterval),
		done:          make(chan struct{}),
		ipPool:        NewPreallocatedIPs(),
	}

	go tracker.cleanupRoutine()
	return tracker
}

// TrackOutbound records an outbound ICMP Echo Request
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
	key := makeICMPKey(srcIP, dstIP, id, seq)
	now := time.Now().UnixNano()

	t.mutex.Lock()
	conn, exists := t.connections[key]
	if !exists {
		srcIPCopy := t.ipPool.Get()
		dstIPCopy := t.ipPool.Get()
		copyIP(srcIPCopy, srcIP)
		copyIP(dstIPCopy, dstIP)

		conn = &ICMPConnTrack{
			BaseConnTrack: BaseConnTrack{
				SourceIP: srcIPCopy,
				DestIP:   dstIPCopy,
			},
			ID:       id,
			Sequence: seq,
		}
		conn.lastSeen.Store(now)
		conn.established.Store(true)
		t.connections[key] = conn
	}
	t.mutex.Unlock()

	conn.lastSeen.Store(now)
}

// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
	switch icmpType {
	case uint8(layers.ICMPv4TypeDestinationUnreachable),
		uint8(layers.ICMPv4TypeTimeExceeded):
		return true
	case uint8(layers.ICMPv4TypeEchoReply):
		// continue processing
	default:
		return false
	}

	key := makeICMPKey(dstIP, srcIP, id, seq)

	t.mutex.RLock()
	conn, exists := t.connections[key]
	t.mutex.RUnlock()

	if !exists {
		return false
	}

	if conn.timeoutExceeded(t.timeout) {
		return false
	}

	return conn.IsEstablished() &&
		ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
		ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
		conn.ID == id &&
		conn.Sequence == seq
}

func (t *ICMPTracker) cleanupRoutine() {
	for {
		select {
		case <-t.cleanupTicker.C:
			t.cleanup()
		case <-t.done:
			return
		}
	}
}
func (t *ICMPTracker) cleanup() {
	t.mutex.Lock()
	defer t.mutex.Unlock()

	for key, conn := range t.connections {
		if conn.timeoutExceeded(t.timeout) {
			t.ipPool.Put(conn.SourceIP)
			t.ipPool.Put(conn.DestIP)
			delete(t.connections, key)
		}
	}
}

// Close stops the cleanup routine and releases resources
func (t *ICMPTracker) Close() {
	t.cleanupTicker.Stop()
	close(t.done)

	t.mutex.Lock()
	for _, conn := range t.connections {
		t.ipPool.Put(conn.SourceIP)
		t.ipPool.Put(conn.DestIP)
	}
	t.connections = nil
	t.mutex.Unlock()
}

// makeICMPKey creates an ICMP connection key
func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey {
	return ICMPConnKey{
		SrcIP:    MakeIPAddr(srcIP),
		DstIP:    MakeIPAddr(dstIP),
		ID:       id,
		Sequence: seq,
	}
}
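Putting the ICMP tracker together: an Echo Request is recorded on the way out, and the matching Echo Reply (same ID and sequence, endpoints swapped) is accepted on the way in, while Destination Unreachable and Time Exceeded errors are always let through. A short usage sketch with illustrative addresses (same package; the layers constant is already imported above):

// Sketch: track an outbound Echo Request and validate the matching reply.
func exampleICMPTrackerUsage() {
	tracker := NewICMPTracker(DefaultICMPTimeout)
	defer tracker.Close()

	src := net.ParseIP("100.64.0.1") // illustrative peer addresses
	dst := net.ParseIP("100.64.0.2")

	// Outbound: Echo Request with ID 42, sequence 1.
	tracker.TrackOutbound(src, dst, 42, 1)

	// Inbound reply: endpoints swapped, same ID/seq, type Echo Reply.
	ok := tracker.IsValidInbound(dst, src, 42, 1, uint8(layers.ICMPv4TypeEchoReply))
	_ = ok // true while the entry has not timed out
}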
39 client/firewall/uspfilter/conntrack/icmp_test.go Normal file
@ -0,0 +1,39 @@
package conntrack

import (
	"net"
	"testing"
)

func BenchmarkICMPTracker(b *testing.B) {
	b.Run("TrackOutbound", func(b *testing.B) {
		tracker := NewICMPTracker(DefaultICMPTimeout)
		defer tracker.Close()

		srcIP := net.ParseIP("192.168.1.1")
		dstIP := net.ParseIP("192.168.1.2")

		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535))
		}
	})

	b.Run("IsValidInbound", func(b *testing.B) {
		tracker := NewICMPTracker(DefaultICMPTimeout)
		defer tracker.Close()

		srcIP := net.ParseIP("192.168.1.1")
		dstIP := net.ParseIP("192.168.1.2")

		// Pre-populate some connections
		for i := 0; i < 1000; i++ {
			tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i))
		}

		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0)
		}
	})
}
352 client/firewall/uspfilter/conntrack/tcp.go Normal file
@ -0,0 +1,352 @@
package conntrack
|
||||||
|
|
||||||
|
// TODO: Send RST packets for invalid/timed-out connections
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// MSL (Maximum Segment Lifetime) is typically 2 minutes
|
||||||
|
MSL = 2 * time.Minute
|
||||||
|
// TimeWaitTimeout (TIME-WAIT) should last 2*MSL
|
||||||
|
TimeWaitTimeout = 2 * MSL
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
TCPSyn uint8 = 0x02
|
||||||
|
TCPAck uint8 = 0x10
|
||||||
|
TCPFin uint8 = 0x01
|
||||||
|
TCPRst uint8 = 0x04
|
||||||
|
TCPPush uint8 = 0x08
|
||||||
|
TCPUrg uint8 = 0x20
|
||||||
|
)
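These constants mirror the flag bits of a real TCP header, so a packet's flag byte can be composed and tested with plain bitwise operations. A tiny same-package illustration (the helper name is made up for the example):

// Sketch: composing and testing flag bits for a SYN-ACK segment.
func exampleTCPFlagBits() {
	flags := TCPSyn | TCPAck // 0x02 | 0x10 = 0x12, same value as in a TCP header

	synAck := flags&TCPSyn != 0 && flags&TCPAck != 0 // true
	rst := flags&TCPRst != 0                         // false
	_, _ = synAck, rst
}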
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DefaultTCPTimeout is the default timeout for established TCP connections
|
||||||
|
DefaultTCPTimeout = 3 * time.Hour
|
||||||
|
// TCPHandshakeTimeout is timeout for TCP handshake completion
|
||||||
|
TCPHandshakeTimeout = 60 * time.Second
|
||||||
|
// TCPCleanupInterval is how often we check for stale connections
|
||||||
|
TCPCleanupInterval = 5 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
// TCPState represents the state of a TCP connection
|
||||||
|
type TCPState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
TCPStateNew TCPState = iota
|
||||||
|
TCPStateSynSent
|
||||||
|
TCPStateSynReceived
|
||||||
|
TCPStateEstablished
|
||||||
|
TCPStateFinWait1
|
||||||
|
TCPStateFinWait2
|
||||||
|
TCPStateClosing
|
||||||
|
TCPStateTimeWait
|
||||||
|
TCPStateCloseWait
|
||||||
|
TCPStateLastAck
|
||||||
|
TCPStateClosed
|
||||||
|
)
|
||||||
|
|
||||||
|
// TCPConnKey uniquely identifies a TCP connection
|
||||||
|
type TCPConnKey struct {
|
||||||
|
SrcIP [16]byte
|
||||||
|
DstIP [16]byte
|
||||||
|
SrcPort uint16
|
||||||
|
DstPort uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
// TCPConnTrack represents a TCP connection state
|
||||||
|
type TCPConnTrack struct {
|
||||||
|
BaseConnTrack
|
||||||
|
State TCPState
|
||||||
|
sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// TCPTracker manages TCP connection states
|
||||||
|
type TCPTracker struct {
|
||||||
|
connections map[ConnKey]*TCPConnTrack
|
||||||
|
mutex sync.RWMutex
|
||||||
|
cleanupTicker *time.Ticker
|
||||||
|
done chan struct{}
|
||||||
|
timeout time.Duration
|
||||||
|
ipPool *PreallocatedIPs
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTCPTracker creates a new TCP connection tracker
|
||||||
|
func NewTCPTracker(timeout time.Duration) *TCPTracker {
|
||||||
|
tracker := &TCPTracker{
|
||||||
|
connections: make(map[ConnKey]*TCPConnTrack),
|
||||||
|
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
||||||
|
done: make(chan struct{}),
|
||||||
|
timeout: timeout,
|
||||||
|
ipPool: NewPreallocatedIPs(),
|
||||||
|
}
|
||||||
|
|
||||||
|
go tracker.cleanupRoutine()
|
||||||
|
return tracker
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackOutbound processes an outbound TCP packet and updates connection state
|
||||||
|
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
|
||||||
|
// Create key before lock
|
||||||
|
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||||
|
now := time.Now().UnixNano()
|
||||||
|
|
||||||
|
t.mutex.Lock()
|
||||||
|
conn, exists := t.connections[key]
|
||||||
|
if !exists {
|
||||||
|
// Use preallocated IPs
|
||||||
|
srcIPCopy := t.ipPool.Get()
|
||||||
|
dstIPCopy := t.ipPool.Get()
|
||||||
|
copyIP(srcIPCopy, srcIP)
|
||||||
|
copyIP(dstIPCopy, dstIP)
|
||||||
|
|
||||||
|
conn = &TCPConnTrack{
|
||||||
|
BaseConnTrack: BaseConnTrack{
|
||||||
|
SourceIP: srcIPCopy,
|
||||||
|
DestIP: dstIPCopy,
|
||||||
|
SourcePort: srcPort,
|
||||||
|
DestPort: dstPort,
|
||||||
|
},
|
||||||
|
State: TCPStateNew,
|
||||||
|
}
|
||||||
|
conn.lastSeen.Store(now)
|
||||||
|
conn.established.Store(false)
|
||||||
|
t.connections[key] = conn
|
||||||
|
}
|
||||||
|
t.mutex.Unlock()
|
||||||
|
|
||||||
|
// Lock individual connection for state update
|
||||||
|
conn.Lock()
|
||||||
|
t.updateState(conn, flags, true)
|
||||||
|
conn.Unlock()
|
||||||
|
conn.lastSeen.Store(now)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
||||||
|
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool {
|
||||||
|
if !isValidFlagCombination(flags) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
|
||||||
|
|
||||||
|
t.mutex.RLock()
|
||||||
|
conn, exists := t.connections[key]
|
||||||
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle RST packets
|
||||||
|
if flags&TCPRst != 0 {
|
||||||
|
conn.Lock()
|
||||||
|
if conn.IsEstablished() || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived {
|
||||||
|
conn.State = TCPStateClosed
|
||||||
|
conn.SetEstablished(false)
|
||||||
|
conn.Unlock()
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
conn.Unlock()
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.Lock()
|
||||||
|
t.updateState(conn, flags, false)
|
||||||
|
conn.UpdateLastSeen()
|
||||||
|
isEstablished := conn.IsEstablished()
|
||||||
|
isValidState := t.isValidStateForFlags(conn.State, flags)
|
||||||
|
conn.Unlock()
|
||||||
|
|
||||||
|
return isEstablished || isValidState
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateState updates the TCP connection state based on flags
|
||||||
|
func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound bool) {
|
||||||
|
// Handle RST flag specially - it always causes transition to closed
|
||||||
|
if flags&TCPRst != 0 {
|
||||||
|
conn.State = TCPStateClosed
|
||||||
|
conn.SetEstablished(false)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch conn.State {
|
||||||
|
case TCPStateNew:
|
||||||
|
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
|
||||||
|
conn.State = TCPStateSynSent
|
||||||
|
}
|
||||||
|
|
||||||
|
case TCPStateSynSent:
|
||||||
|
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
|
||||||
|
if isOutbound {
|
||||||
|
conn.State = TCPStateSynReceived
|
||||||
|
} else {
|
||||||
|
// Simultaneous open
|
||||||
|
conn.State = TCPStateEstablished
|
||||||
|
conn.SetEstablished(true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case TCPStateSynReceived:
|
||||||
|
if flags&TCPAck != 0 && flags&TCPSyn == 0 {
|
||||||
|
conn.State = TCPStateEstablished
|
||||||
|
conn.SetEstablished(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
case TCPStateEstablished:
|
||||||
|
if flags&TCPFin != 0 {
|
||||||
|
if isOutbound {
|
||||||
|
conn.State = TCPStateFinWait1
|
||||||
|
} else {
|
||||||
|
conn.State = TCPStateCloseWait
|
||||||
|
}
|
||||||
|
conn.SetEstablished(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
case TCPStateFinWait1:
|
||||||
|
switch {
|
||||||
|
case flags&TCPFin != 0 && flags&TCPAck != 0:
|
||||||
|
// Simultaneous close - both sides sent FIN
|
||||||
|
conn.State = TCPStateClosing
|
||||||
|
case flags&TCPFin != 0:
|
||||||
|
conn.State = TCPStateFinWait2
|
||||||
|
case flags&TCPAck != 0:
|
||||||
|
conn.State = TCPStateFinWait2
|
||||||
|
}
|
||||||
|
|
||||||
|
case TCPStateFinWait2:
|
||||||
|
if flags&TCPFin != 0 {
|
||||||
|
conn.State = TCPStateTimeWait
|
||||||
|
}
|
||||||
|
|
||||||
|
case TCPStateClosing:
|
||||||
|
if flags&TCPAck != 0 {
|
||||||
|
conn.State = TCPStateTimeWait
|
||||||
|
// Keep established = false from previous state
|
||||||
|
}
|
||||||
|
|
||||||
|
case TCPStateCloseWait:
|
||||||
|
if flags&TCPFin != 0 {
|
||||||
|
conn.State = TCPStateLastAck
|
||||||
|
}
|
||||||
|
|
||||||
|
case TCPStateLastAck:
|
||||||
|
if flags&TCPAck != 0 {
|
||||||
|
conn.State = TCPStateClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
case TCPStateTimeWait:
|
||||||
|
// Stay in TIME-WAIT for 2MSL before transitioning to closed
|
||||||
|
// This is handled by the cleanup routine
|
||||||
|
}
|
||||||
|
}
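From the tracker's point of view, a normal client-side handshake walks these states as follows: the outbound SYN moves the entry from NEW to SYN-SENT, the inbound SYN-ACK moves it straight to ESTABLISHED (SYN-RECEIVED is used when this side is the one sending the SYN-ACK), and the final outbound ACK leaves it established. A compressed sketch using the exported API, with illustrative endpoints:

// Sketch: the handshake steps as they drive updateState above.
func exampleTCPHandshakeTracking() {
	tracker := NewTCPTracker(DefaultTCPTimeout)
	defer tracker.Close()

	src, dst := net.ParseIP("100.64.0.1"), net.ParseIP("100.64.0.2") // illustrative
	var sport, dport uint16 = 34567, 443

	tracker.TrackOutbound(src, dst, sport, dport, TCPSyn)             // NEW -> SYN-SENT
	_ = tracker.IsValidInbound(dst, src, dport, sport, TCPSyn|TCPAck) // inbound SYN-ACK: SYN-SENT -> ESTABLISHED
	tracker.TrackOutbound(src, dst, sport, dport, TCPAck)             // final ACK, connection stays established
}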
|
||||||
|
|
||||||
|
// isValidStateForFlags checks if the TCP flags are valid for the current connection state
|
||||||
|
func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
|
||||||
|
if !isValidFlagCombination(flags) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
switch state {
|
||||||
|
case TCPStateNew:
|
||||||
|
return flags&TCPSyn != 0 && flags&TCPAck == 0
|
||||||
|
case TCPStateSynSent:
|
||||||
|
return flags&TCPSyn != 0 && flags&TCPAck != 0
|
||||||
|
case TCPStateSynReceived:
|
||||||
|
return flags&TCPAck != 0
|
||||||
|
case TCPStateEstablished:
|
||||||
|
if flags&TCPRst != 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return flags&TCPAck != 0
|
||||||
|
case TCPStateFinWait1:
|
||||||
|
return flags&TCPFin != 0 || flags&TCPAck != 0
|
||||||
|
case TCPStateFinWait2:
|
||||||
|
return flags&TCPFin != 0 || flags&TCPAck != 0
|
||||||
|
case TCPStateClosing:
|
||||||
|
// In CLOSING state, we should accept the final ACK
|
||||||
|
return flags&TCPAck != 0
|
||||||
|
case TCPStateTimeWait:
|
||||||
|
// In TIME_WAIT, we might see retransmissions
|
||||||
|
return flags&TCPAck != 0
|
||||||
|
case TCPStateCloseWait:
|
||||||
|
return flags&TCPFin != 0 || flags&TCPAck != 0
|
||||||
|
case TCPStateLastAck:
|
||||||
|
return flags&TCPAck != 0
|
||||||
|
case TCPStateClosed:
|
||||||
|
// Accept retransmitted ACKs in closed state
|
||||||
|
// This is important because the final ACK might be lost
|
||||||
|
// and the peer will retransmit their FIN-ACK
|
||||||
|
return flags&TCPAck != 0
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TCPTracker) cleanupRoutine() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-t.cleanupTicker.C:
|
||||||
|
t.cleanup()
|
||||||
|
case <-t.done:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TCPTracker) cleanup() {
|
||||||
|
t.mutex.Lock()
|
||||||
|
defer t.mutex.Unlock()
|
||||||
|
|
||||||
|
for key, conn := range t.connections {
|
||||||
|
var timeout time.Duration
|
||||||
|
switch {
|
||||||
|
case conn.State == TCPStateTimeWait:
|
||||||
|
timeout = TimeWaitTimeout
|
||||||
|
case conn.IsEstablished():
|
||||||
|
timeout = t.timeout
|
||||||
|
default:
|
||||||
|
timeout = TCPHandshakeTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
lastSeen := conn.GetLastSeen()
|
||||||
|
if time.Since(lastSeen) > timeout {
|
||||||
|
// Return IPs to pool
|
||||||
|
t.ipPool.Put(conn.SourceIP)
|
||||||
|
t.ipPool.Put(conn.DestIP)
|
||||||
|
delete(t.connections, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close stops the cleanup routine and releases resources
|
||||||
|
func (t *TCPTracker) Close() {
|
||||||
|
t.cleanupTicker.Stop()
|
||||||
|
close(t.done)
|
||||||
|
|
||||||
|
// Clean up all remaining IPs
|
||||||
|
t.mutex.Lock()
|
||||||
|
for _, conn := range t.connections {
|
||||||
|
t.ipPool.Put(conn.SourceIP)
|
||||||
|
t.ipPool.Put(conn.DestIP)
|
||||||
|
}
|
||||||
|
t.connections = nil
|
||||||
|
t.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func isValidFlagCombination(flags uint8) bool {
|
||||||
|
// Invalid: SYN+FIN
|
||||||
|
if flags&TCPSyn != 0 && flags&TCPFin != 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invalid: RST with SYN or FIN
|
||||||
|
if flags&TCPRst != 0 && (flags&TCPSyn != 0 || flags&TCPFin != 0) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
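isValidFlagCombination rejects flag mixes that no conforming stack emits, which lets the tracker discard obviously forged segments before touching any connection state. For instance, with the constants defined above:

// Sketch: combinations rejected or accepted before any state lookup.
func exampleInvalidFlagCombos() {
	_ = isValidFlagCombination(TCPSyn | TCPFin)  // false: SYN and FIN never travel together
	_ = isValidFlagCombination(TCPRst | TCPSyn)  // false: RST must not carry SYN or FIN
	_ = isValidFlagCombination(TCPSyn)           // true: plain SYN is fine
	_ = isValidFlagCombination(TCPAck | TCPPush) // true: ordinary data segment
}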
308 client/firewall/uspfilter/conntrack/tcp_test.go Normal file
@ -0,0 +1,308 @@
package conntrack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTCPStateMachine(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := net.ParseIP("100.64.0.1")
|
||||||
|
dstIP := net.ParseIP("100.64.0.2")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
t.Run("Security Tests", func(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
flags uint8
|
||||||
|
wantDrop bool
|
||||||
|
desc string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Block unsolicited SYN-ACK",
|
||||||
|
flags: TCPSyn | TCPAck,
|
||||||
|
wantDrop: true,
|
||||||
|
desc: "Should block SYN-ACK without prior SYN",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Block invalid SYN-FIN",
|
||||||
|
flags: TCPSyn | TCPFin,
|
||||||
|
wantDrop: true,
|
||||||
|
desc: "Should block invalid SYN-FIN combination",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Block unsolicited RST",
|
||||||
|
flags: TCPRst,
|
||||||
|
wantDrop: true,
|
||||||
|
desc: "Should block RST without connection",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Block unsolicited ACK",
|
||||||
|
flags: TCPAck,
|
||||||
|
wantDrop: true,
|
||||||
|
desc: "Should block ACK without connection",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Block data without connection",
|
||||||
|
flags: TCPAck | TCPPush,
|
||||||
|
wantDrop: true,
|
||||||
|
desc: "Should block data without established connection",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags)
|
||||||
|
require.Equal(t, !tt.wantDrop, isValid, tt.desc)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Connection Flow Tests", func(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
test func(*testing.T)
|
||||||
|
desc string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Normal Handshake",
|
||||||
|
test: func(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// Send initial SYN
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
||||||
|
|
||||||
|
// Receive SYN-ACK
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
||||||
|
require.True(t, valid, "SYN-ACK should be allowed")
|
||||||
|
|
||||||
|
// Send ACK
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||||
|
|
||||||
|
// Test data transfer
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck)
|
||||||
|
require.True(t, valid, "Data should be allowed after handshake")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Normal Close",
|
||||||
|
test: func(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// First establish connection
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Send FIN
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
|
||||||
|
|
||||||
|
// Receive ACK for FIN
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
|
||||||
|
require.True(t, valid, "ACK for FIN should be allowed")
|
||||||
|
|
||||||
|
// Receive FIN from other side
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
|
||||||
|
require.True(t, valid, "FIN should be allowed")
|
||||||
|
|
||||||
|
// Send final ACK
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RST During Connection",
|
||||||
|
test: func(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// First establish connection
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Receive RST
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
||||||
|
require.True(t, valid, "RST should be allowed for established connection")
|
||||||
|
|
||||||
|
// Connection is logically dead but we don't enforce blocking subsequent packets
|
||||||
|
// The connection will be cleaned up by timeout
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Simultaneous Close",
|
||||||
|
test: func(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// First establish connection
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Both sides send FIN+ACK
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
|
||||||
|
require.True(t, valid, "Simultaneous FIN should be allowed")
|
||||||
|
|
||||||
|
// Both sides send final ACK
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
|
||||||
|
require.True(t, valid, "Final ACKs should be allowed")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
tracker = NewTCPTracker(DefaultTCPTimeout)
|
||||||
|
tt.test(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRSTHandling(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := net.ParseIP("100.64.0.1")
|
||||||
|
dstIP := net.ParseIP("100.64.0.2")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setupState func()
|
||||||
|
sendRST func()
|
||||||
|
wantValid bool
|
||||||
|
desc string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "RST in established",
|
||||||
|
setupState: func() {
|
||||||
|
// Establish connection first
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||||
|
},
|
||||||
|
sendRST: func() {
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
||||||
|
},
|
||||||
|
wantValid: true,
|
||||||
|
desc: "Should accept RST for established connection",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RST without connection",
|
||||||
|
setupState: func() {},
|
||||||
|
sendRST: func() {
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
||||||
|
},
|
||||||
|
wantValid: false,
|
||||||
|
desc: "Should reject RST without connection",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tt.setupState()
|
||||||
|
tt.sendRST()
|
||||||
|
|
||||||
|
// Verify connection state is as expected
|
||||||
|
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
if tt.wantValid {
|
||||||
|
require.NotNil(t, conn)
|
||||||
|
require.Equal(t, TCPStateClosed, conn.State)
|
||||||
|
require.False(t, conn.IsEstablished())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper to establish a TCP connection
|
||||||
|
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
||||||
|
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
||||||
|
require.True(t, valid, "SYN-ACK should be allowed")
|
||||||
|
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkTCPTracker(b *testing.B) {
|
||||||
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
|
dstIP := net.ParseIP("192.168.1.2")
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("IsValidInbound", func(b *testing.B) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
|
dstIP := net.ParseIP("192.168.1.2")
|
||||||
|
|
||||||
|
// Pre-populate some connections
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("ConcurrentAccess", func(b *testing.B) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
|
dstIP := net.ParseIP("192.168.1.2")
|
||||||
|
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
i := 0
|
||||||
|
for pb.Next() {
|
||||||
|
if i%2 == 0 {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
|
||||||
|
} else {
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck)
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Benchmark connection cleanup
|
||||||
|
func BenchmarkCleanup(b *testing.B) {
|
||||||
|
b.Run("TCPCleanup", func(b *testing.B) {
|
||||||
|
tracker := NewTCPTracker(100 * time.Millisecond) // Short timeout for testing
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
// Pre-populate with expired connections
|
||||||
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
|
dstIP := net.ParseIP("192.168.1.2")
|
||||||
|
for i := 0; i < 10000; i++ {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for connections to expire
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
tracker.cleanup()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
158 client/firewall/uspfilter/conntrack/udp.go Normal file
@ -0,0 +1,158 @@
package conntrack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DefaultUDPTimeout is the default timeout for UDP connections
|
||||||
|
DefaultUDPTimeout = 30 * time.Second
|
||||||
|
// UDPCleanupInterval is how often we check for stale connections
|
||||||
|
UDPCleanupInterval = 15 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// UDPConnTrack represents a UDP connection state
|
||||||
|
type UDPConnTrack struct {
|
||||||
|
BaseConnTrack
|
||||||
|
}
|
||||||
|
|
||||||
|
// UDPTracker manages UDP connection states
|
||||||
|
type UDPTracker struct {
|
||||||
|
connections map[ConnKey]*UDPConnTrack
|
||||||
|
timeout time.Duration
|
||||||
|
cleanupTicker *time.Ticker
|
||||||
|
mutex sync.RWMutex
|
||||||
|
done chan struct{}
|
||||||
|
ipPool *PreallocatedIPs
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUDPTracker creates a new UDP connection tracker
|
||||||
|
func NewUDPTracker(timeout time.Duration) *UDPTracker {
|
||||||
|
if timeout == 0 {
|
||||||
|
timeout = DefaultUDPTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
tracker := &UDPTracker{
|
||||||
|
connections: make(map[ConnKey]*UDPConnTrack),
|
||||||
|
timeout: timeout,
|
||||||
|
cleanupTicker: time.NewTicker(UDPCleanupInterval),
|
||||||
|
done: make(chan struct{}),
|
||||||
|
ipPool: NewPreallocatedIPs(),
|
||||||
|
}
|
||||||
|
|
||||||
|
go tracker.cleanupRoutine()
|
||||||
|
return tracker
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackOutbound records an outbound UDP connection
|
||||||
|
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
|
||||||
|
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||||
|
now := time.Now().UnixNano()
|
||||||
|
|
||||||
|
t.mutex.Lock()
|
||||||
|
conn, exists := t.connections[key]
|
||||||
|
if !exists {
|
||||||
|
srcIPCopy := t.ipPool.Get()
|
||||||
|
dstIPCopy := t.ipPool.Get()
|
||||||
|
copyIP(srcIPCopy, srcIP)
|
||||||
|
copyIP(dstIPCopy, dstIP)
|
||||||
|
|
||||||
|
conn = &UDPConnTrack{
|
||||||
|
BaseConnTrack: BaseConnTrack{
|
||||||
|
SourceIP: srcIPCopy,
|
||||||
|
DestIP: dstIPCopy,
|
||||||
|
SourcePort: srcPort,
|
||||||
|
DestPort: dstPort,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn.lastSeen.Store(now)
|
||||||
|
conn.established.Store(true)
|
||||||
|
t.connections[key] = conn
|
||||||
|
}
|
||||||
|
t.mutex.Unlock()
|
||||||
|
|
||||||
|
conn.lastSeen.Store(now)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValidInbound checks if an inbound packet matches a tracked connection
|
||||||
|
func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool {
|
||||||
|
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
|
||||||
|
|
||||||
|
t.mutex.RLock()
|
||||||
|
conn, exists := t.connections[key]
|
||||||
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if conn.timeoutExceeded(t.timeout) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn.IsEstablished() &&
|
||||||
|
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
||||||
|
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
||||||
|
conn.DestPort == srcPort &&
|
||||||
|
conn.SourcePort == dstPort
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupRoutine periodically removes stale connections
|
||||||
|
func (t *UDPTracker) cleanupRoutine() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-t.cleanupTicker.C:
|
||||||
|
t.cleanup()
|
||||||
|
case <-t.done:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *UDPTracker) cleanup() {
|
||||||
|
t.mutex.Lock()
|
||||||
|
defer t.mutex.Unlock()
|
||||||
|
|
||||||
|
for key, conn := range t.connections {
|
||||||
|
if conn.timeoutExceeded(t.timeout) {
|
||||||
|
t.ipPool.Put(conn.SourceIP)
|
||||||
|
t.ipPool.Put(conn.DestIP)
|
||||||
|
delete(t.connections, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close stops the cleanup routine and releases resources
|
||||||
|
func (t *UDPTracker) Close() {
|
||||||
|
t.cleanupTicker.Stop()
|
||||||
|
close(t.done)
|
||||||
|
|
||||||
|
t.mutex.Lock()
|
||||||
|
for _, conn := range t.connections {
|
||||||
|
t.ipPool.Put(conn.SourceIP)
|
||||||
|
t.ipPool.Put(conn.DestIP)
|
||||||
|
}
|
||||||
|
t.connections = nil
|
||||||
|
t.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConnection safely retrieves a connection state
|
||||||
|
func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) (*UDPConnTrack, bool) {
|
||||||
|
t.mutex.RLock()
|
||||||
|
defer t.mutex.RUnlock()
|
||||||
|
|
||||||
|
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||||
|
conn, exists := t.connections[key]
|
||||||
|
if !exists {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Timeout returns the configured timeout duration for the tracker
|
||||||
|
func (t *UDPTracker) Timeout() time.Duration {
|
||||||
|
return t.timeout
|
||||||
|
}
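Typical use of the UDP tracker mirrors a DNS exchange: the outgoing query creates the entry, and the response is accepted only if it comes back from the queried server and port before the timeout expires. A short same-package sketch with illustrative addresses and ports:

// Sketch: record an outbound DNS query, then validate the response.
func exampleUDPTrackerUsage() {
	tracker := NewUDPTracker(DefaultUDPTimeout)
	defer tracker.Close()

	client := net.ParseIP("100.64.0.1") // illustrative addresses
	resolver := net.ParseIP("100.64.0.53")

	tracker.TrackOutbound(client, resolver, 51820, 53)

	// Response: endpoints and ports swapped relative to the query.
	allowed := tracker.IsValidInbound(resolver, client, 53, 51820)
	_ = allowed // true while the entry is within DefaultUDPTimeout
}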
243 client/firewall/uspfilter/conntrack/udp_test.go Normal file
@ -0,0 +1,243 @@
package conntrack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewUDPTracker(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
timeout time.Duration
|
||||||
|
wantTimeout time.Duration
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "with custom timeout",
|
||||||
|
timeout: 1 * time.Minute,
|
||||||
|
wantTimeout: 1 * time.Minute,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with zero timeout uses default",
|
||||||
|
timeout: 0,
|
||||||
|
wantTimeout: DefaultUDPTimeout,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tracker := NewUDPTracker(tt.timeout)
|
||||||
|
assert.NotNil(t, tracker)
|
||||||
|
assert.Equal(t, tt.wantTimeout, tracker.timeout)
|
||||||
|
assert.NotNil(t, tracker.connections)
|
||||||
|
assert.NotNil(t, tracker.cleanupTicker)
|
||||||
|
assert.NotNil(t, tracker.done)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUDPTracker_TrackOutbound(t *testing.T) {
|
||||||
|
tracker := NewUDPTracker(DefaultUDPTimeout)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := net.ParseIP("192.168.1.2")
|
||||||
|
dstIP := net.ParseIP("192.168.1.3")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(53)
|
||||||
|
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Verify connection was tracked
|
||||||
|
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||||
|
conn, exists := tracker.connections[key]
|
||||||
|
require.True(t, exists)
|
||||||
|
assert.True(t, conn.SourceIP.Equal(srcIP))
|
||||||
|
assert.True(t, conn.DestIP.Equal(dstIP))
|
||||||
|
assert.Equal(t, srcPort, conn.SourcePort)
|
||||||
|
assert.Equal(t, dstPort, conn.DestPort)
|
||||||
|
assert.True(t, conn.IsEstablished())
|
||||||
|
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||||
|
tracker := NewUDPTracker(1 * time.Second)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := net.ParseIP("192.168.1.2")
|
||||||
|
dstIP := net.ParseIP("192.168.1.3")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(53)
|
||||||
|
|
||||||
|
// Track outbound connection
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
srcIP net.IP
|
||||||
|
dstIP net.IP
|
||||||
|
srcPort uint16
|
||||||
|
dstPort uint16
|
||||||
|
sleep time.Duration
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid inbound response",
|
||||||
|
srcIP: dstIP, // Original destination is now source
|
||||||
|
dstIP: srcIP, // Original source is now destination
|
||||||
|
srcPort: dstPort, // Original destination port is now source
|
||||||
|
dstPort: srcPort, // Original source port is now destination
|
||||||
|
sleep: 0,
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid source IP",
|
||||||
|
srcIP: net.ParseIP("192.168.1.4"),
|
||||||
|
dstIP: srcIP,
|
||||||
|
srcPort: dstPort,
|
||||||
|
dstPort: srcPort,
|
||||||
|
sleep: 0,
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid destination IP",
|
||||||
|
srcIP: dstIP,
|
||||||
|
dstIP: net.ParseIP("192.168.1.4"),
|
||||||
|
srcPort: dstPort,
|
||||||
|
dstPort: srcPort,
|
||||||
|
sleep: 0,
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid source port",
|
||||||
|
srcIP: dstIP,
|
||||||
|
dstIP: srcIP,
|
||||||
|
srcPort: 54321,
|
||||||
|
dstPort: srcPort,
|
||||||
|
sleep: 0,
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid destination port",
|
||||||
|
srcIP: dstIP,
|
||||||
|
dstIP: srcIP,
|
||||||
|
srcPort: dstPort,
|
||||||
|
dstPort: 54321,
|
||||||
|
sleep: 0,
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "expired connection",
|
||||||
|
srcIP: dstIP,
|
||||||
|
dstIP: srcIP,
|
||||||
|
srcPort: dstPort,
|
||||||
|
dstPort: srcPort,
|
||||||
|
sleep: 2 * time.Second, // Longer than tracker timeout
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if tt.sleep > 0 {
|
||||||
|
time.Sleep(tt.sleep)
|
||||||
|
}
|
||||||
|
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort)
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUDPTracker_Cleanup(t *testing.T) {
|
||||||
|
// Use shorter intervals for testing
|
||||||
|
timeout := 50 * time.Millisecond
|
||||||
|
cleanupInterval := 25 * time.Millisecond
|
||||||
|
|
||||||
|
// Create tracker with custom cleanup interval
|
||||||
|
tracker := &UDPTracker{
|
||||||
|
connections: make(map[ConnKey]*UDPConnTrack),
|
||||||
|
timeout: timeout,
|
||||||
|
cleanupTicker: time.NewTicker(cleanupInterval),
|
||||||
|
done: make(chan struct{}),
|
||||||
|
ipPool: NewPreallocatedIPs(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start cleanup routine
|
||||||
|
go tracker.cleanupRoutine()
|
||||||
|
|
||||||
|
// Add some connections
|
||||||
|
connections := []struct {
|
||||||
|
srcIP net.IP
|
||||||
|
dstIP net.IP
|
||||||
|
srcPort uint16
|
||||||
|
dstPort uint16
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
srcIP: net.ParseIP("192.168.1.2"),
|
||||||
|
dstIP: net.ParseIP("192.168.1.3"),
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 53,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
srcIP: net.ParseIP("192.168.1.4"),
|
||||||
|
dstIP: net.ParseIP("192.168.1.5"),
|
||||||
|
srcPort: 12346,
|
||||||
|
dstPort: 53,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, conn := range connections {
|
||||||
|
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify initial connections
|
||||||
|
assert.Len(t, tracker.connections, 2)
|
||||||
|
|
||||||
|
// Wait for connection timeout and cleanup interval
|
||||||
|
time.Sleep(timeout + 2*cleanupInterval)
|
||||||
|
|
||||||
|
tracker.mutex.RLock()
|
||||||
|
connCount := len(tracker.connections)
|
||||||
|
tracker.mutex.RUnlock()
|
||||||
|
|
||||||
|
// Verify connections were cleaned up
|
||||||
|
assert.Equal(t, 0, connCount, "Expected all connections to be cleaned up")
|
||||||
|
|
||||||
|
// Properly close the tracker
|
||||||
|
tracker.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkUDPTracker(b *testing.B) {
|
||||||
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
|
tracker := NewUDPTracker(DefaultUDPTimeout)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
|
dstIP := net.ParseIP("192.168.1.2")
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("IsValidInbound", func(b *testing.B) {
|
||||||
|
tracker := NewUDPTracker(DefaultUDPTimeout)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
|
dstIP := net.ParseIP("192.168.1.2")
|
||||||
|
|
||||||
|
// Pre-populate some connections
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
@ -4,6 +4,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
@ -12,6 +14,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
@ -19,6 +22,8 @@ import (
|
|||||||
|
|
||||||
const layerTypeAll = 0
|
const layerTypeAll = 0
|
||||||
|
|
||||||
|
const EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
|
||||||
|
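The new NB_DISABLE_CONNTRACK switch is read with strconv.ParseBool in create() below, so any value ParseBool accepts ("1", "t", "true", ...) turns the stateful trackers off and leaves only rule evaluation. A hedged sketch of how a caller (for example a test) could exercise the stateless path; the body is illustrative only:

// Sketch: force the userspace firewall into stateless mode before creating it.
func exampleDisableConntrack() {
	// Any value accepted by strconv.ParseBool disables the trackers.
	os.Setenv(EnvDisableConntrack, "true")
	defer os.Unsetenv(EnvDisableConntrack)

	// ... construct the Manager as usual; udpTracker, icmpTracker and tcpTracker
	// stay nil and dropFilter skips the connection-state check ...
}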
|
||||||
var (
|
var (
|
||||||
errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall")
|
errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall")
|
||||||
)
|
)
|
||||||
@ -42,6 +47,11 @@ type Manager struct {
|
|||||||
nativeFirewall firewall.Manager
|
nativeFirewall firewall.Manager
|
||||||
|
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
|
|
||||||
|
stateful bool
|
||||||
|
udpTracker *conntrack.UDPTracker
|
||||||
|
icmpTracker *conntrack.ICMPTracker
|
||||||
|
tcpTracker *conntrack.TCPTracker
|
||||||
}
|
}
|
||||||
|
|
||||||
// decoder for packages
|
// decoder for packages
|
||||||
@ -73,6 +83,8 @@ func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager
|
|||||||
}
|
}
|
||||||
|
|
||||||
func create(iface IFaceMapper) (*Manager, error) {
|
func create(iface IFaceMapper) (*Manager, error) {
|
||||||
|
disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack))
|
||||||
|
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
decoders: sync.Pool{
|
decoders: sync.Pool{
|
||||||
New: func() any {
|
New: func() any {
|
||||||
@ -90,6 +102,16 @@ func create(iface IFaceMapper) (*Manager, error) {
|
|||||||
outgoingRules: make(map[string]RuleSet),
|
outgoingRules: make(map[string]RuleSet),
|
||||||
incomingRules: make(map[string]RuleSet),
|
incomingRules: make(map[string]RuleSet),
|
||||||
wgIface: iface,
|
wgIface: iface,
|
||||||
|
stateful: !disableConntrack,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only initialize trackers if stateful mode is enabled
|
||||||
|
if disableConntrack {
|
||||||
|
log.Info("conntrack is disabled")
|
||||||
|
} else {
|
||||||
|
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
|
||||||
|
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
|
||||||
|
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := iface.SetFilter(m); err != nil {
|
if err := iface.SetFilter(m); err != nil {
|
||||||
@ -239,7 +261,7 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
// SetLegacyManagement doesn't need to be implemented for this manager
|
// SetLegacyManagement doesn't need to be implemented for this manager
|
||||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||||
if m.nativeFirewall == nil {
|
if m.nativeFirewall == nil {
|
||||||
return errRouteNotSupported
|
return nil
|
||||||
}
|
}
|
||||||
return m.nativeFirewall.SetLegacyManagement(isLegacy)
|
return m.nativeFirewall.SetLegacyManagement(isLegacy)
|
||||||
}
|
}
|
||||||
@ -249,16 +271,16 @@ func (m *Manager) Flush() error { return nil }
|
|||||||
|
|
||||||
// DropOutgoing filter outgoing packets
|
// DropOutgoing filter outgoing packets
|
||||||
func (m *Manager) DropOutgoing(packetData []byte) bool {
|
func (m *Manager) DropOutgoing(packetData []byte) bool {
|
||||||
return m.dropFilter(packetData, m.outgoingRules, false)
|
return m.processOutgoingHooks(packetData)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropIncoming filter incoming packets
|
// DropIncoming filter incoming packets
|
||||||
func (m *Manager) DropIncoming(packetData []byte) bool {
|
func (m *Manager) DropIncoming(packetData []byte) bool {
|
||||||
return m.dropFilter(packetData, m.incomingRules, true)
|
return m.dropFilter(packetData, m.incomingRules)
|
||||||
}
|
}
|
||||||
|
|
||||||
// dropFilter implements same logic for booth direction of the traffic
|
// processOutgoingHooks processes UDP hooks for outgoing packets and tracks TCP/UDP/ICMP
|
||||||
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isIncomingPacket bool) bool {
|
func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
defer m.mutex.RUnlock()
|
defer m.mutex.RUnlock()
|
||||||
|
|
||||||
@ -266,61 +288,215 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isInco
|
|||||||
defer m.decoders.Put(d)
|
defer m.decoders.Put(d)
|
||||||
|
|
||||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||||
log.Tracef("couldn't decode layer, err: %s", err)
|
return false
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(d.decoded) < 2 {
|
if len(d.decoded) < 2 {
|
||||||
log.Tracef("not enough levels in network packet")
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
srcIP, dstIP := m.extractIPs(d)
|
||||||
|
if srcIP == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Always process UDP hooks
|
||||||
|
if d.decoded[1] == layers.LayerTypeUDP {
|
||||||
|
// Track UDP state only if enabled
|
||||||
|
if m.stateful {
|
||||||
|
m.trackUDPOutbound(d, srcIP, dstIP)
|
||||||
|
}
|
||||||
|
return m.checkUDPHooks(d, dstIP, packetData)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track other protocols only if stateful mode is enabled
|
||||||
|
if m.stateful {
|
||||||
|
switch d.decoded[1] {
|
||||||
|
case layers.LayerTypeTCP:
|
||||||
|
m.trackTCPOutbound(d, srcIP, dstIP)
|
||||||
|
case layers.LayerTypeICMPv4:
|
||||||
|
m.trackICMPOutbound(d, srcIP, dstIP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) {
|
||||||
|
switch d.decoded[0] {
|
||||||
|
case layers.LayerTypeIPv4:
|
||||||
|
return d.ip4.SrcIP, d.ip4.DstIP
|
||||||
|
case layers.LayerTypeIPv6:
|
||||||
|
return d.ip6.SrcIP, d.ip6.DstIP
|
||||||
|
default:
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) trackTCPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
||||||
|
flags := getTCPFlags(&d.tcp)
|
||||||
|
m.tcpTracker.TrackOutbound(
|
||||||
|
srcIP,
|
||||||
|
dstIP,
|
||||||
|
uint16(d.tcp.SrcPort),
|
||||||
|
uint16(d.tcp.DstPort),
|
||||||
|
flags,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getTCPFlags(tcp *layers.TCP) uint8 {
|
||||||
|
var flags uint8
|
||||||
|
if tcp.SYN {
|
||||||
|
flags |= conntrack.TCPSyn
|
||||||
|
}
|
||||||
|
if tcp.ACK {
|
||||||
|
flags |= conntrack.TCPAck
|
||||||
|
}
|
||||||
|
if tcp.FIN {
|
||||||
|
flags |= conntrack.TCPFin
|
||||||
|
}
|
||||||
|
if tcp.RST {
|
||||||
|
flags |= conntrack.TCPRst
|
||||||
|
}
|
||||||
|
if tcp.PSH {
|
||||||
|
flags |= conntrack.TCPPush
|
||||||
|
}
|
||||||
|
if tcp.URG {
|
||||||
|
flags |= conntrack.TCPUrg
|
||||||
|
}
|
||||||
|
return flags
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) trackUDPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
||||||
|
m.udpTracker.TrackOutbound(
|
||||||
|
srcIP,
|
||||||
|
dstIP,
|
||||||
|
uint16(d.udp.SrcPort),
|
||||||
|
uint16(d.udp.DstPort),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool {
|
||||||
|
for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} {
|
||||||
|
if rules, exists := m.outgoingRules[ipKey]; exists {
|
||||||
|
for _, rule := range rules {
|
||||||
|
if rule.udpHook != nil && (rule.dPort == 0 || rule.dPort == uint16(d.udp.DstPort)) {
|
||||||
|
return rule.udpHook(packetData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
||||||
|
if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest {
|
||||||
|
m.icmpTracker.TrackOutbound(
|
||||||
|
srcIP,
|
||||||
|
dstIP,
|
||||||
|
d.icmp4.Id,
|
||||||
|
d.icmp4.Seq,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// dropFilter implements filtering logic for incoming packets
|
||||||
|
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
|
||||||
|
// TODO: Disable router if --disable-server-router is set
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
defer m.mutex.RUnlock()
|
||||||
|
|
||||||
|
d := m.decoders.Get().(*decoder)
|
||||||
|
defer m.decoders.Put(d)
|
||||||
|
|
||||||
|
if !m.isValidPacket(d, packetData) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
ipLayer := d.decoded[0]
|
srcIP, dstIP := m.extractIPs(d)
|
||||||
|
if srcIP == nil {
|
||||||
switch ipLayer {
|
|
||||||
case layers.LayerTypeIPv4:
|
|
||||||
if !m.wgNetwork.Contains(d.ip4.SrcIP) || !m.wgNetwork.Contains(d.ip4.DstIP) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
case layers.LayerTypeIPv6:
|
|
||||||
if !m.wgNetwork.Contains(d.ip6.SrcIP) || !m.wgNetwork.Contains(d.ip6.DstIP) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
log.Errorf("unknown layer: %v", d.decoded[0])
|
log.Errorf("unknown layer: %v", d.decoded[0])
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
var ip net.IP
|
if !m.isWireguardTraffic(srcIP, dstIP) {
|
||||||
switch ipLayer {
|
return false
|
||||||
case layers.LayerTypeIPv4:
|
|
||||||
if isIncomingPacket {
|
|
||||||
ip = d.ip4.SrcIP
|
|
||||||
} else {
|
|
||||||
ip = d.ip4.DstIP
|
|
||||||
}
|
|
||||||
case layers.LayerTypeIPv6:
|
|
||||||
if isIncomingPacket {
|
|
||||||
ip = d.ip6.SrcIP
|
|
||||||
} else {
|
|
||||||
ip = d.ip6.DstIP
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
filter, ok := validateRule(ip, packetData, rules[ip.String()], d)
|
// Check connection state only if enabled
|
||||||
if ok {
|
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) {
|
||||||
return filter
|
return false
|
||||||
}
|
}
|
||||||
filter, ok = validateRule(ip, packetData, rules["0.0.0.0"], d)
|
|
||||||
if ok {
|
return m.applyRules(srcIP, packetData, rules, d)
|
||||||
return filter
|
}
|
||||||
|
|
||||||
|
func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
|
||||||
|
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||||
|
log.Tracef("couldn't decode layer, err: %s", err)
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
filter, ok = validateRule(ip, packetData, rules["::"], d)
|
|
||||||
if ok {
|
if len(d.decoded) < 2 {
|
||||||
|
log.Tracef("not enough levels in network packet")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) isWireguardTraffic(srcIP, dstIP net.IP) bool {
|
||||||
|
return m.wgNetwork.Contains(srcIP) && m.wgNetwork.Contains(dstIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool {
|
||||||
|
switch d.decoded[1] {
|
||||||
|
case layers.LayerTypeTCP:
|
||||||
|
return m.tcpTracker.IsValidInbound(
|
||||||
|
srcIP,
|
||||||
|
dstIP,
|
||||||
|
uint16(d.tcp.SrcPort),
|
||||||
|
uint16(d.tcp.DstPort),
|
||||||
|
getTCPFlags(&d.tcp),
|
||||||
|
)
|
||||||
|
|
||||||
|
case layers.LayerTypeUDP:
|
||||||
|
return m.udpTracker.IsValidInbound(
|
||||||
|
srcIP,
|
||||||
|
dstIP,
|
||||||
|
uint16(d.udp.SrcPort),
|
||||||
|
uint16(d.udp.DstPort),
|
||||||
|
)
|
||||||
|
|
||||||
|
case layers.LayerTypeICMPv4:
|
||||||
|
return m.icmpTracker.IsValidInbound(
|
||||||
|
srcIP,
|
||||||
|
dstIP,
|
||||||
|
d.icmp4.Id,
|
||||||
|
d.icmp4.Seq,
|
||||||
|
d.icmp4.TypeCode.Type(),
|
||||||
|
)
|
||||||
|
|
||||||
|
// TODO: ICMPv6
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) applyRules(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool {
|
||||||
|
if filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok {
|
||||||
return filter
|
return filter
|
||||||
}
|
}
|
||||||
|
|
||||||
// default policy is DROP ALL
|
if filter, ok := validateRule(srcIP, packetData, rules["0.0.0.0"], d); ok {
|
||||||
|
return filter
|
||||||
|
}
|
||||||
|
|
||||||
|
if filter, ok := validateRule(srcIP, packetData, rules["::"], d); ok {
|
||||||
|
return filter
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default policy: DROP ALL
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
998
client/firewall/uspfilter/uspfilter_bench_test.go
Normal file
998
client/firewall/uspfilter/uspfilter_bench_test.go
Normal file
@ -0,0 +1,998 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/gopacket"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
)
|
||||||
|
|
||||||
|
// generateRandomIPs generates n different random IPs in the 100.64.0.0/10 range
|
||||||
|
func generateRandomIPs(n int) []net.IP {
|
||||||
|
ips := make([]net.IP, n)
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
|
||||||
|
for i := 0; i < n; {
|
||||||
|
ip := make(net.IP, 4)
|
||||||
|
ip[0] = 100
|
||||||
|
ip[1] = byte(64 + rand.Intn(63)) // 64-126
|
||||||
|
ip[2] = byte(rand.Intn(256))
|
||||||
|
ip[3] = byte(1 + rand.Intn(254)) // avoid .0 and .255
|
||||||
|
|
||||||
|
key := ip.String()
|
||||||
|
if !seen[key] {
|
||||||
|
ips[i] = ip
|
||||||
|
seen[key] = true
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ips
|
||||||
|
}
|
||||||
|
|
||||||
|
func generatePacket(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort uint16, protocol layers.IPProtocol) []byte {
|
||||||
|
b.Helper()
|
||||||
|
|
||||||
|
ipv4 := &layers.IPv4{
|
||||||
|
TTL: 64,
|
||||||
|
Version: 4,
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
Protocol: protocol,
|
||||||
|
}
|
||||||
|
|
||||||
|
var transportLayer gopacket.SerializableLayer
|
||||||
|
switch protocol {
|
||||||
|
case layers.IPProtocolTCP:
|
||||||
|
tcp := &layers.TCP{
|
||||||
|
SrcPort: layers.TCPPort(srcPort),
|
||||||
|
DstPort: layers.TCPPort(dstPort),
|
||||||
|
SYN: true,
|
||||||
|
}
|
||||||
|
require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv4))
|
||||||
|
transportLayer = tcp
|
||||||
|
case layers.IPProtocolUDP:
|
||||||
|
udp := &layers.UDP{
|
||||||
|
SrcPort: layers.UDPPort(srcPort),
|
||||||
|
DstPort: layers.UDPPort(dstPort),
|
||||||
|
}
|
||||||
|
require.NoError(b, udp.SetNetworkLayerForChecksum(ipv4))
|
||||||
|
transportLayer = udp
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := gopacket.NewSerializeBuffer()
|
||||||
|
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||||
|
err := gopacket.SerializeLayers(buf, opts, ipv4, transportLayer, gopacket.Payload("test"))
|
||||||
|
require.NoError(b, err)
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkCoreFiltering focuses on the essential performance comparisons between
|
||||||
|
// stateful and stateless filtering approaches
|
||||||
|
func BenchmarkCoreFiltering(b *testing.B) {
|
||||||
|
scenarios := []struct {
|
||||||
|
name string
|
||||||
|
stateful bool
|
||||||
|
setupFunc func(*Manager)
|
||||||
|
desc string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "stateless_single_allow_all",
|
||||||
|
stateful: false,
|
||||||
|
setupFunc: func(m *Manager) {
|
||||||
|
// Single rule allowing all traffic
|
||||||
|
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil,
|
||||||
|
fw.RuleDirectionIN, fw.ActionAccept, "", "allow all")
|
||||||
|
require.NoError(b, err)
|
||||||
|
},
|
||||||
|
desc: "Baseline: Single 'allow all' rule without connection tracking",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stateful_no_rules",
|
||||||
|
stateful: true,
|
||||||
|
setupFunc: func(m *Manager) {
|
||||||
|
// No explicit rules - rely purely on connection tracking
|
||||||
|
},
|
||||||
|
desc: "Pure connection tracking without any rules",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stateless_explicit_return",
|
||||||
|
stateful: false,
|
||||||
|
setupFunc: func(m *Manager) {
|
||||||
|
// Add explicit rules matching return traffic pattern
|
||||||
|
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
|
||||||
|
ip := generateRandomIPs(1)[0]
|
||||||
|
_, err := m.AddPeerFiltering(ip, fw.ProtocolTCP,
|
||||||
|
&fw.Port{Values: []int{1024 + i}},
|
||||||
|
&fw.Port{Values: []int{80}},
|
||||||
|
fw.RuleDirectionIN, fw.ActionAccept, "", "explicit return")
|
||||||
|
require.NoError(b, err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
desc: "Explicit rules matching return traffic patterns without state",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stateful_with_established",
|
||||||
|
stateful: true,
|
||||||
|
setupFunc: func(m *Manager) {
|
||||||
|
// Add some basic rules but rely on state for established connections
|
||||||
|
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil,
|
||||||
|
fw.RuleDirectionIN, fw.ActionDrop, "", "default drop")
|
||||||
|
require.NoError(b, err)
|
||||||
|
},
|
||||||
|
desc: "Connection tracking with established connections",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test both TCP and UDP
|
||||||
|
protocols := []struct {
|
||||||
|
name string
|
||||||
|
proto layers.IPProtocol
|
||||||
|
}{
|
||||||
|
{"TCP", layers.IPProtocolTCP},
|
||||||
|
{"UDP", layers.IPProtocolUDP},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, sc := range scenarios {
|
||||||
|
for _, proto := range protocols {
|
||||||
|
b.Run(fmt.Sprintf("%s_%s", sc.name, proto.name), func(b *testing.B) {
|
||||||
|
// Configure stateful/stateless mode
|
||||||
|
if !sc.stateful {
|
||||||
|
require.NoError(b, os.Setenv("NB_DISABLE_CONNTRACK", "1"))
|
||||||
|
} else {
|
||||||
|
require.NoError(b, os.Setenv("NB_CONNTRACK_TIMEOUT", "1m"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create manager and basic setup
|
||||||
|
manager, _ := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
})
|
||||||
|
defer b.Cleanup(func() {
|
||||||
|
require.NoError(b, manager.Reset(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
manager.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.64.0.0"),
|
||||||
|
Mask: net.CIDRMask(10, 32),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply scenario-specific setup
|
||||||
|
sc.setupFunc(manager)
|
||||||
|
|
||||||
|
// Generate test packets
|
||||||
|
srcIP := generateRandomIPs(1)[0]
|
||||||
|
dstIP := generateRandomIPs(1)[0]
|
||||||
|
srcPort := uint16(1024 + b.N%60000)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
outbound := generatePacket(b, srcIP, dstIP, srcPort, dstPort, proto.proto)
|
||||||
|
inbound := generatePacket(b, dstIP, srcIP, dstPort, srcPort, proto.proto)
|
||||||
|
|
||||||
|
// For stateful scenarios, establish the connection
|
||||||
|
if sc.stateful {
|
||||||
|
manager.processOutgoingHooks(outbound)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Measure inbound packet processing
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
manager.dropFilter(inbound, manager.incomingRules)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkStateScaling measures how performance scales with connection table size
|
||||||
|
func BenchmarkStateScaling(b *testing.B) {
|
||||||
|
connCounts := []int{100, 1000, 10000, 100000}
|
||||||
|
|
||||||
|
for _, count := range connCounts {
|
||||||
|
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
||||||
|
manager, _ := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
})
|
||||||
|
b.Cleanup(func() {
|
||||||
|
require.NoError(b, manager.Reset(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
manager.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.64.0.0"),
|
||||||
|
Mask: net.CIDRMask(10, 32),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pre-populate connection table
|
||||||
|
srcIPs := generateRandomIPs(count)
|
||||||
|
dstIPs := generateRandomIPs(count)
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
|
||||||
|
uint16(1024+i), 80, layers.IPProtocolTCP)
|
||||||
|
manager.processOutgoingHooks(outbound)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test packet
|
||||||
|
testOut := generatePacket(b, srcIPs[0], dstIPs[0], 1024, 80, layers.IPProtocolTCP)
|
||||||
|
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
|
||||||
|
|
||||||
|
// First establish our test connection
|
||||||
|
manager.processOutgoingHooks(testOut)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
manager.dropFilter(testIn, manager.incomingRules)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkEstablishmentOverhead measures the overhead of connection establishment
|
||||||
|
func BenchmarkEstablishmentOverhead(b *testing.B) {
|
||||||
|
scenarios := []struct {
|
||||||
|
name string
|
||||||
|
established bool
|
||||||
|
}{
|
||||||
|
{"established", true},
|
||||||
|
{"new", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, sc := range scenarios {
|
||||||
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
|
manager, _ := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
})
|
||||||
|
b.Cleanup(func() {
|
||||||
|
require.NoError(b, manager.Reset(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
manager.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.64.0.0"),
|
||||||
|
Mask: net.CIDRMask(10, 32),
|
||||||
|
}
|
||||||
|
|
||||||
|
srcIP := generateRandomIPs(1)[0]
|
||||||
|
dstIP := generateRandomIPs(1)[0]
|
||||||
|
outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP)
|
||||||
|
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
|
||||||
|
|
||||||
|
if sc.established {
|
||||||
|
manager.processOutgoingHooks(outbound)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
manager.dropFilter(inbound, manager.incomingRules)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkRoutedNetworkReturn compares approaches for handling routed network return traffic
|
||||||
|
func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||||
|
scenarios := []struct {
|
||||||
|
name string
|
||||||
|
proto layers.IPProtocol
|
||||||
|
state string // "new", "established", "post_handshake" (TCP only)
|
||||||
|
setupFunc func(*Manager)
|
||||||
|
genPackets func(net.IP, net.IP) ([]byte, []byte) // generates appropriate packets for the scenario
|
||||||
|
desc string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "allow_non_wg_tcp_new",
|
||||||
|
proto: layers.IPProtocolTCP,
|
||||||
|
state: "new",
|
||||||
|
setupFunc: func(m *Manager) {
|
||||||
|
m.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.64.0.0"),
|
||||||
|
Mask: net.CIDRMask(10, 32),
|
||||||
|
}
|
||||||
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
|
},
|
||||||
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
|
return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP),
|
||||||
|
generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
|
||||||
|
},
|
||||||
|
desc: "Allow non-WG: TCP new connection",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allow_non_wg_tcp_established",
|
||||||
|
proto: layers.IPProtocolTCP,
|
||||||
|
state: "established",
|
||||||
|
setupFunc: func(m *Manager) {
|
||||||
|
m.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.64.0.0"),
|
||||||
|
Mask: net.CIDRMask(10, 32),
|
||||||
|
}
|
||||||
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
|
},
|
||||||
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
|
// Generate packets with ACK flag for established connection
|
||||||
|
return generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)),
|
||||||
|
generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPAck))
|
||||||
|
},
|
||||||
|
desc: "Allow non-WG: TCP established connection",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allow_non_wg_udp_new",
|
||||||
|
proto: layers.IPProtocolUDP,
|
||||||
|
state: "new",
|
||||||
|
setupFunc: func(m *Manager) {
|
||||||
|
m.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.64.0.0"),
|
||||||
|
Mask: net.CIDRMask(10, 32),
|
||||||
|
}
|
||||||
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
|
},
|
||||||
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
|
return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP),
|
||||||
|
generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP)
|
||||||
|
},
|
||||||
|
desc: "Allow non-WG: UDP new connection",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allow_non_wg_udp_established",
|
||||||
|
proto: layers.IPProtocolUDP,
|
||||||
|
state: "established",
|
||||||
|
setupFunc: func(m *Manager) {
|
||||||
|
m.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.64.0.0"),
|
||||||
|
Mask: net.CIDRMask(10, 32),
|
||||||
|
}
|
||||||
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
|
},
|
||||||
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
|
return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP),
|
||||||
|
generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP)
|
||||||
|
},
|
||||||
|
desc: "Allow non-WG: UDP established connection",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stateful_tcp_new",
|
||||||
|
proto: layers.IPProtocolTCP,
|
||||||
|
state: "new",
|
||||||
|
setupFunc: func(m *Manager) {
|
||||||
|
m.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("0.0.0.0"),
|
||||||
|
Mask: net.CIDRMask(0, 32),
|
||||||
|
}
|
||||||
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
|
},
|
||||||
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
|
return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP),
|
||||||
|
generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
|
||||||
|
},
|
||||||
|
desc: "Stateful: TCP new connection",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stateful_tcp_established",
|
||||||
|
proto: layers.IPProtocolTCP,
|
||||||
|
state: "established",
|
||||||
|
setupFunc: func(m *Manager) {
|
||||||
|
m.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("0.0.0.0"),
|
||||||
|
Mask: net.CIDRMask(0, 32),
|
||||||
|
}
|
||||||
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
|
},
|
||||||
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
|
// Generate established TCP packets (ACK flag)
|
||||||
|
return generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)),
|
||||||
|
generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPAck))
|
||||||
|
},
|
||||||
|
desc: "Stateful: TCP established connection",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stateful_tcp_post_handshake",
|
||||||
|
proto: layers.IPProtocolTCP,
|
||||||
|
state: "post_handshake",
|
||||||
|
setupFunc: func(m *Manager) {
|
||||||
|
m.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("0.0.0.0"),
|
||||||
|
Mask: net.CIDRMask(0, 32),
|
||||||
|
}
|
||||||
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
|
},
|
||||||
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
|
// Generate packets with PSH+ACK flags for data transfer
|
||||||
|
return generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPPush|conntrack.TCPAck)),
|
||||||
|
generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPPush|conntrack.TCPAck))
|
||||||
|
},
|
||||||
|
desc: "Stateful: TCP post-handshake data transfer",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stateful_udp_new",
|
||||||
|
proto: layers.IPProtocolUDP,
|
||||||
|
state: "new",
|
||||||
|
setupFunc: func(m *Manager) {
|
||||||
|
m.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("0.0.0.0"),
|
||||||
|
Mask: net.CIDRMask(0, 32),
|
||||||
|
}
|
||||||
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
|
},
|
||||||
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
|
return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP),
|
||||||
|
generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP)
|
||||||
|
},
|
||||||
|
desc: "Stateful: UDP new connection",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stateful_udp_established",
|
||||||
|
proto: layers.IPProtocolUDP,
|
||||||
|
state: "established",
|
||||||
|
setupFunc: func(m *Manager) {
|
||||||
|
m.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("0.0.0.0"),
|
||||||
|
Mask: net.CIDRMask(0, 32),
|
||||||
|
}
|
||||||
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
|
},
|
||||||
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
|
return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP),
|
||||||
|
generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP)
|
||||||
|
},
|
||||||
|
desc: "Stateful: UDP established connection",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, sc := range scenarios {
|
||||||
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
|
manager, _ := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
})
|
||||||
|
b.Cleanup(func() {
|
||||||
|
require.NoError(b, manager.Reset(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Setup scenario
|
||||||
|
sc.setupFunc(manager)
|
||||||
|
|
||||||
|
// Use IPs outside WG range for routed network simulation
|
||||||
|
srcIP := net.ParseIP("192.168.1.2")
|
||||||
|
dstIP := net.ParseIP("8.8.8.8")
|
||||||
|
outbound, inbound := sc.genPackets(srcIP, dstIP)
|
||||||
|
|
||||||
|
// For stateful cases and established connections
|
||||||
|
if !strings.Contains(sc.name, "allow_non_wg") ||
|
||||||
|
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
|
||||||
|
manager.processOutgoingHooks(outbound)
|
||||||
|
|
||||||
|
// For TCP post-handshake, simulate full handshake
|
||||||
|
if sc.state == "post_handshake" {
|
||||||
|
// SYN
|
||||||
|
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
|
||||||
|
manager.processOutgoingHooks(syn)
|
||||||
|
// SYN-ACK
|
||||||
|
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
|
manager.dropFilter(synack, manager.incomingRules)
|
||||||
|
// ACK
|
||||||
|
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
||||||
|
manager.processOutgoingHooks(ack)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
manager.dropFilter(inbound, manager.incomingRules)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var scenarios = []struct {
|
||||||
|
name string
|
||||||
|
stateful bool // Whether conntrack is enabled
|
||||||
|
rules bool // Whether to add return traffic rules
|
||||||
|
routed bool // Whether to test routed network traffic
|
||||||
|
connCount int // Number of concurrent connections
|
||||||
|
desc string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "stateless_with_rules_100conns",
|
||||||
|
stateful: false,
|
||||||
|
rules: true,
|
||||||
|
routed: false,
|
||||||
|
connCount: 100,
|
||||||
|
desc: "Pure stateless with return traffic rules, 100 conns",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stateless_with_rules_1000conns",
|
||||||
|
stateful: false,
|
||||||
|
rules: true,
|
||||||
|
routed: false,
|
||||||
|
connCount: 1000,
|
||||||
|
desc: "Pure stateless with return traffic rules, 1000 conns",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stateful_no_rules_100conns",
|
||||||
|
stateful: true,
|
||||||
|
rules: false,
|
||||||
|
routed: false,
|
||||||
|
connCount: 100,
|
||||||
|
desc: "Pure stateful tracking without rules, 100 conns",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stateful_no_rules_1000conns",
|
||||||
|
stateful: true,
|
||||||
|
rules: false,
|
||||||
|
routed: false,
|
||||||
|
connCount: 1000,
|
||||||
|
desc: "Pure stateful tracking without rules, 1000 conns",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stateful_with_rules_100conns",
|
||||||
|
stateful: true,
|
||||||
|
rules: true,
|
||||||
|
routed: false,
|
||||||
|
connCount: 100,
|
||||||
|
desc: "Combined stateful + rules (current implementation), 100 conns",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stateful_with_rules_1000conns",
|
||||||
|
stateful: true,
|
||||||
|
rules: true,
|
||||||
|
routed: false,
|
||||||
|
connCount: 1000,
|
||||||
|
desc: "Combined stateful + rules (current implementation), 1000 conns",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "routed_network_100conns",
|
||||||
|
stateful: true,
|
||||||
|
rules: false,
|
||||||
|
routed: true,
|
||||||
|
connCount: 100,
|
||||||
|
desc: "Routed network traffic (non-WG), 100 conns",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "routed_network_1000conns",
|
||||||
|
stateful: true,
|
||||||
|
rules: false,
|
||||||
|
routed: true,
|
||||||
|
connCount: 1000,
|
||||||
|
desc: "Routed network traffic (non-WG), 1000 conns",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkLongLivedConnections tests performance with realistic TCP traffic patterns
|
||||||
|
func BenchmarkLongLivedConnections(b *testing.B) {
|
||||||
|
for _, sc := range scenarios {
|
||||||
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
|
// Configure stateful/stateless mode
|
||||||
|
if !sc.stateful {
|
||||||
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
|
} else {
|
||||||
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, _ := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
})
|
||||||
|
defer b.Cleanup(func() {
|
||||||
|
require.NoError(b, manager.Reset(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
manager.SetNetwork(&net.IPNet{
|
||||||
|
IP: net.ParseIP("100.64.0.0"),
|
||||||
|
Mask: net.CIDRMask(10, 32),
|
||||||
|
})
|
||||||
|
|
||||||
|
// Setup initial state based on scenario
|
||||||
|
if sc.rules {
|
||||||
|
// Single rule to allow all return traffic from port 80
|
||||||
|
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||||
|
&fw.Port{Values: []int{80}},
|
||||||
|
nil,
|
||||||
|
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic")
|
||||||
|
require.NoError(b, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate IPs for connections
|
||||||
|
srcIPs := make([]net.IP, sc.connCount)
|
||||||
|
dstIPs := make([]net.IP, sc.connCount)
|
||||||
|
|
||||||
|
for i := 0; i < sc.connCount; i++ {
|
||||||
|
if sc.routed {
|
||||||
|
srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4()
|
||||||
|
dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4()
|
||||||
|
} else {
|
||||||
|
srcIPs[i] = generateRandomIPs(1)[0]
|
||||||
|
dstIPs[i] = generateRandomIPs(1)[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create established connections
|
||||||
|
for i := 0; i < sc.connCount; i++ {
|
||||||
|
// Initial SYN
|
||||||
|
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
|
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||||
|
manager.processOutgoingHooks(syn)
|
||||||
|
|
||||||
|
// SYN-ACK
|
||||||
|
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
|
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
|
manager.dropFilter(synack, manager.incomingRules)
|
||||||
|
|
||||||
|
// ACK
|
||||||
|
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
|
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||||
|
manager.processOutgoingHooks(ack)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare test packets simulating bidirectional traffic
|
||||||
|
inPackets := make([][]byte, sc.connCount)
|
||||||
|
outPackets := make([][]byte, sc.connCount)
|
||||||
|
for i := 0; i < sc.connCount; i++ {
|
||||||
|
// Server -> Client (inbound)
|
||||||
|
inPackets[i] = generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
|
80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck))
|
||||||
|
// Client -> Server (outbound)
|
||||||
|
outPackets[i] = generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
|
uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck))
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
connIdx := i % sc.connCount
|
||||||
|
|
||||||
|
// Simulate bidirectional traffic
|
||||||
|
// First outbound data
|
||||||
|
manager.processOutgoingHooks(outPackets[connIdx])
|
||||||
|
// Then inbound response - this is what we're actually measuring
|
||||||
|
manager.dropFilter(inPackets[connIdx], manager.incomingRules)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkShortLivedConnections tests performance with many short-lived connections
|
||||||
|
func BenchmarkShortLivedConnections(b *testing.B) {
|
||||||
|
for _, sc := range scenarios {
|
||||||
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
|
// Configure stateful/stateless mode
|
||||||
|
if !sc.stateful {
|
||||||
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
|
} else {
|
||||||
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, _ := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
})
|
||||||
|
defer b.Cleanup(func() {
|
||||||
|
require.NoError(b, manager.Reset(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
manager.SetNetwork(&net.IPNet{
|
||||||
|
IP: net.ParseIP("100.64.0.0"),
|
||||||
|
Mask: net.CIDRMask(10, 32),
|
||||||
|
})
|
||||||
|
|
||||||
|
// Setup initial state based on scenario
|
||||||
|
if sc.rules {
|
||||||
|
// Single rule to allow all return traffic from port 80
|
||||||
|
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||||
|
&fw.Port{Values: []int{80}},
|
||||||
|
nil,
|
||||||
|
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic")
|
||||||
|
require.NoError(b, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate IPs for connections
|
||||||
|
srcIPs := make([]net.IP, sc.connCount)
|
||||||
|
dstIPs := make([]net.IP, sc.connCount)
|
||||||
|
|
||||||
|
for i := 0; i < sc.connCount; i++ {
|
||||||
|
if sc.routed {
|
||||||
|
srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4()
|
||||||
|
dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4()
|
||||||
|
} else {
|
||||||
|
srcIPs[i] = generateRandomIPs(1)[0]
|
||||||
|
dstIPs[i] = generateRandomIPs(1)[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create packet patterns for a complete HTTP-like short connection:
|
||||||
|
// 1. Initial handshake (SYN, SYN-ACK, ACK)
|
||||||
|
// 2. HTTP Request (PSH+ACK from client)
|
||||||
|
// 3. HTTP Response (PSH+ACK from server)
|
||||||
|
// 4. Connection teardown (FIN+ACK, ACK, FIN+ACK, ACK)
|
||||||
|
type connPackets struct {
|
||||||
|
syn []byte
|
||||||
|
synAck []byte
|
||||||
|
ack []byte
|
||||||
|
request []byte
|
||||||
|
response []byte
|
||||||
|
finClient []byte
|
||||||
|
ackServer []byte
|
||||||
|
finServer []byte
|
||||||
|
ackClient []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate all possible connection patterns
|
||||||
|
patterns := make([]connPackets, sc.connCount)
|
||||||
|
for i := 0; i < sc.connCount; i++ {
|
||||||
|
patterns[i] = connPackets{
|
||||||
|
// Handshake
|
||||||
|
syn: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
|
uint16(1024+i), 80, uint16(conntrack.TCPSyn)),
|
||||||
|
synAck: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
|
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)),
|
||||||
|
ack: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
|
uint16(1024+i), 80, uint16(conntrack.TCPAck)),
|
||||||
|
|
||||||
|
// Data transfer
|
||||||
|
request: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
|
uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck)),
|
||||||
|
response: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
|
80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck)),
|
||||||
|
|
||||||
|
// Connection teardown
|
||||||
|
finClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
|
uint16(1024+i), 80, uint16(conntrack.TCPFin|conntrack.TCPAck)),
|
||||||
|
ackServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
|
80, uint16(1024+i), uint16(conntrack.TCPAck)),
|
||||||
|
finServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
|
80, uint16(1024+i), uint16(conntrack.TCPFin|conntrack.TCPAck)),
|
||||||
|
ackClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
|
uint16(1024+i), 80, uint16(conntrack.TCPAck)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
// Each iteration creates a new short-lived connection
|
||||||
|
connIdx := i % sc.connCount
|
||||||
|
p := patterns[connIdx]
|
||||||
|
|
||||||
|
// Connection establishment
|
||||||
|
manager.processOutgoingHooks(p.syn)
|
||||||
|
manager.dropFilter(p.synAck, manager.incomingRules)
|
||||||
|
manager.processOutgoingHooks(p.ack)
|
||||||
|
|
||||||
|
// Data transfer
|
||||||
|
manager.processOutgoingHooks(p.request)
|
||||||
|
manager.dropFilter(p.response, manager.incomingRules)
|
||||||
|
|
||||||
|
// Connection teardown
|
||||||
|
manager.processOutgoingHooks(p.finClient)
|
||||||
|
manager.dropFilter(p.ackServer, manager.incomingRules)
|
||||||
|
manager.dropFilter(p.finServer, manager.incomingRules)
|
||||||
|
manager.processOutgoingHooks(p.ackClient)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkParallelLongLivedConnections tests performance with realistic TCP traffic patterns in parallel
|
||||||
|
func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||||
|
for _, sc := range scenarios {
|
||||||
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
|
// Configure stateful/stateless mode
|
||||||
|
if !sc.stateful {
|
||||||
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
|
} else {
|
||||||
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, _ := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
})
|
||||||
|
defer b.Cleanup(func() {
|
||||||
|
require.NoError(b, manager.Reset(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
manager.SetNetwork(&net.IPNet{
|
||||||
|
IP: net.ParseIP("100.64.0.0"),
|
||||||
|
Mask: net.CIDRMask(10, 32),
|
||||||
|
})
|
||||||
|
|
||||||
|
// Setup initial state based on scenario
|
||||||
|
if sc.rules {
|
||||||
|
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||||
|
&fw.Port{Values: []int{80}},
|
||||||
|
nil,
|
||||||
|
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic")
|
||||||
|
require.NoError(b, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate IPs for connections
|
||||||
|
srcIPs := make([]net.IP, sc.connCount)
|
||||||
|
dstIPs := make([]net.IP, sc.connCount)
|
||||||
|
|
||||||
|
for i := 0; i < sc.connCount; i++ {
|
||||||
|
if sc.routed {
|
||||||
|
srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4()
|
||||||
|
dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4()
|
||||||
|
} else {
|
||||||
|
srcIPs[i] = generateRandomIPs(1)[0]
|
||||||
|
dstIPs[i] = generateRandomIPs(1)[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create established connections
|
||||||
|
for i := 0; i < sc.connCount; i++ {
|
||||||
|
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
|
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||||
|
manager.processOutgoingHooks(syn)
|
||||||
|
|
||||||
|
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
|
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
|
manager.dropFilter(synack, manager.incomingRules)
|
||||||
|
|
||||||
|
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
|
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||||
|
manager.processOutgoingHooks(ack)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pre-generate test packets
|
||||||
|
inPackets := make([][]byte, sc.connCount)
|
||||||
|
outPackets := make([][]byte, sc.connCount)
|
||||||
|
for i := 0; i < sc.connCount; i++ {
|
||||||
|
inPackets[i] = generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
|
80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck))
|
||||||
|
outPackets[i] = generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
|
uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck))
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
// Each goroutine gets its own counter to distribute load
|
||||||
|
counter := 0
|
||||||
|
for pb.Next() {
|
||||||
|
connIdx := counter % sc.connCount
|
||||||
|
counter++
|
||||||
|
|
||||||
|
// Simulate bidirectional traffic
|
||||||
|
manager.processOutgoingHooks(outPackets[connIdx])
|
||||||
|
manager.dropFilter(inPackets[connIdx], manager.incomingRules)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkParallelShortLivedConnections tests performance with many short-lived connections in parallel
|
||||||
|
func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||||
|
for _, sc := range scenarios {
|
||||||
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
|
// Configure stateful/stateless mode
|
||||||
|
if !sc.stateful {
|
||||||
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
|
} else {
|
||||||
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, _ := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
})
|
||||||
|
defer b.Cleanup(func() {
|
||||||
|
require.NoError(b, manager.Reset(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
manager.SetNetwork(&net.IPNet{
|
||||||
|
IP: net.ParseIP("100.64.0.0"),
|
||||||
|
Mask: net.CIDRMask(10, 32),
|
||||||
|
})
|
||||||
|
|
||||||
|
if sc.rules {
|
||||||
|
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||||
|
&fw.Port{Values: []int{80}},
|
||||||
|
nil,
|
||||||
|
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic")
|
||||||
|
require.NoError(b, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate IPs and pre-generate all packet patterns
|
||||||
|
srcIPs := make([]net.IP, sc.connCount)
|
||||||
|
dstIPs := make([]net.IP, sc.connCount)
|
||||||
|
for i := 0; i < sc.connCount; i++ {
|
||||||
|
if sc.routed {
|
||||||
|
srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4()
|
||||||
|
dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4()
|
||||||
|
} else {
|
||||||
|
srcIPs[i] = generateRandomIPs(1)[0]
|
||||||
|
dstIPs[i] = generateRandomIPs(1)[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type connPackets struct {
|
||||||
|
syn []byte
|
||||||
|
synAck []byte
|
||||||
|
ack []byte
|
||||||
|
request []byte
|
||||||
|
response []byte
|
||||||
|
finClient []byte
|
||||||
|
ackServer []byte
|
||||||
|
finServer []byte
|
||||||
|
ackClient []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
patterns := make([]connPackets, sc.connCount)
|
||||||
|
for i := 0; i < sc.connCount; i++ {
|
||||||
|
patterns[i] = connPackets{
|
||||||
|
syn: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
|
uint16(1024+i), 80, uint16(conntrack.TCPSyn)),
|
||||||
|
synAck: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
|
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)),
|
||||||
|
ack: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
|
uint16(1024+i), 80, uint16(conntrack.TCPAck)),
|
||||||
|
request: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
|
uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck)),
|
||||||
|
response: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
|
80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck)),
|
||||||
|
finClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
|
uint16(1024+i), 80, uint16(conntrack.TCPFin|conntrack.TCPAck)),
|
||||||
|
ackServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
|
80, uint16(1024+i), uint16(conntrack.TCPAck)),
|
||||||
|
finServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
|
80, uint16(1024+i), uint16(conntrack.TCPFin|conntrack.TCPAck)),
|
||||||
|
ackClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
|
uint16(1024+i), 80, uint16(conntrack.TCPAck)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
counter := 0
|
||||||
|
for pb.Next() {
|
||||||
|
connIdx := counter % sc.connCount
|
||||||
|
counter++
|
||||||
|
p := patterns[connIdx]
|
||||||
|
|
||||||
|
// Full connection lifecycle
|
||||||
|
manager.processOutgoingHooks(p.syn)
|
||||||
|
manager.dropFilter(p.synAck, manager.incomingRules)
|
||||||
|
manager.processOutgoingHooks(p.ack)
|
||||||
|
|
||||||
|
manager.processOutgoingHooks(p.request)
|
||||||
|
manager.dropFilter(p.response, manager.incomingRules)
|
||||||
|
|
||||||
|
manager.processOutgoingHooks(p.finClient)
|
||||||
|
manager.dropFilter(p.ackServer, manager.incomingRules)
|
||||||
|
manager.dropFilter(p.finServer, manager.incomingRules)
|
||||||
|
manager.processOutgoingHooks(p.ackClient)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateTCPPacketWithFlags creates a TCP packet with specific flags
|
||||||
|
func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort, flags uint16) []byte {
|
||||||
|
b.Helper()
|
||||||
|
|
||||||
|
ipv4 := &layers.IPv4{
|
||||||
|
TTL: 64,
|
||||||
|
Version: 4,
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
Protocol: layers.IPProtocolTCP,
|
||||||
|
}
|
||||||
|
|
||||||
|
tcp := &layers.TCP{
|
||||||
|
SrcPort: layers.TCPPort(srcPort),
|
||||||
|
DstPort: layers.TCPPort(dstPort),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set TCP flags
|
||||||
|
tcp.SYN = (flags & uint16(conntrack.TCPSyn)) != 0
|
||||||
|
tcp.ACK = (flags & uint16(conntrack.TCPAck)) != 0
|
||||||
|
tcp.PSH = (flags & uint16(conntrack.TCPPush)) != 0
|
||||||
|
tcp.RST = (flags & uint16(conntrack.TCPRst)) != 0
|
||||||
|
tcp.FIN = (flags & uint16(conntrack.TCPFin)) != 0
|
||||||
|
|
||||||
|
require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv4))
|
||||||
|
|
||||||
|
buf := gopacket.NewSerializeBuffer()
|
||||||
|
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||||
|
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test")))
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
@ -3,6 +3,7 @@ package uspfilter
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -11,6 +12,7 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
)
|
)
|
||||||
@ -185,10 +187,10 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
manager := &Manager{
|
manager, err := Create(&IFaceMock{
|
||||||
incomingRules: map[string]RuleSet{},
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
outgoingRules: map[string]RuleSet{},
|
})
|
||||||
}
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||||
|
|
||||||
@ -313,7 +315,7 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
t.Errorf("failed to set network layer for checksum: %v", err)
|
t.Errorf("failed to set network layer for checksum: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
payload := gopacket.Payload([]byte("test"))
|
payload := gopacket.Payload("test")
|
||||||
|
|
||||||
buf := gopacket.NewSerializeBuffer()
|
buf := gopacket.NewSerializeBuffer()
|
||||||
opts := gopacket.SerializeOptions{
|
opts := gopacket.SerializeOptions{
|
||||||
@ -325,7 +327,7 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.dropFilter(buf.Bytes(), m.outgoingRules, false) {
|
if m.dropFilter(buf.Bytes(), m.outgoingRules) {
|
||||||
t.Errorf("expected packet to be accepted")
|
t.Errorf("expected packet to be accepted")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -348,6 +350,9 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create Manager: %s", err)
|
t.Fatalf("Failed to create Manager: %s", err)
|
||||||
}
|
}
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, manager.Reset(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
// Add a UDP packet hook
|
// Add a UDP packet hook
|
||||||
hookFunc := func(data []byte) bool { return true }
|
hookFunc := func(data []byte) bool { return true }
|
||||||
@ -384,6 +389,88 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProcessOutgoingHooks(t *testing.T) {
|
||||||
|
manager, err := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
manager.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.10.0.0"),
|
||||||
|
Mask: net.CIDRMask(16, 32),
|
||||||
|
}
|
||||||
|
manager.udpTracker.Close()
|
||||||
|
manager.udpTracker = conntrack.NewUDPTracker(100 * time.Millisecond)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, manager.Reset(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
manager.decoders = sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
d := &decoder{
|
||||||
|
decoded: []gopacket.LayerType{},
|
||||||
|
}
|
||||||
|
d.parser = gopacket.NewDecodingLayerParser(
|
||||||
|
layers.LayerTypeIPv4,
|
||||||
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
|
)
|
||||||
|
d.parser.IgnoreUnsupported = true
|
||||||
|
return d
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
hookCalled := false
|
||||||
|
hookID := manager.AddUDPPacketHook(
|
||||||
|
false,
|
||||||
|
net.ParseIP("100.10.0.100"),
|
||||||
|
53,
|
||||||
|
func([]byte) bool {
|
||||||
|
hookCalled = true
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
)
|
||||||
|
require.NotEmpty(t, hookID)
|
||||||
|
|
||||||
|
// Create test UDP packet
|
||||||
|
ipv4 := &layers.IPv4{
|
||||||
|
TTL: 64,
|
||||||
|
Version: 4,
|
||||||
|
SrcIP: net.ParseIP("100.10.0.1"),
|
||||||
|
DstIP: net.ParseIP("100.10.0.100"),
|
||||||
|
Protocol: layers.IPProtocolUDP,
|
||||||
|
}
|
||||||
|
udp := &layers.UDP{
|
||||||
|
SrcPort: 51334,
|
||||||
|
DstPort: 53,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = udp.SetNetworkLayerForChecksum(ipv4)
|
||||||
|
require.NoError(t, err)
|
||||||
|
payload := gopacket.Payload("test")
|
||||||
|
|
||||||
|
buf := gopacket.NewSerializeBuffer()
|
||||||
|
opts := gopacket.SerializeOptions{
|
||||||
|
ComputeChecksums: true,
|
||||||
|
FixLengths: true,
|
||||||
|
}
|
||||||
|
err = gopacket.SerializeLayers(buf, opts, ipv4, udp, payload)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Test hook gets called
|
||||||
|
result := manager.processOutgoingHooks(buf.Bytes())
|
||||||
|
require.True(t, result)
|
||||||
|
require.True(t, hookCalled)
|
||||||
|
|
||||||
|
// Test non-UDP packet is ignored
|
||||||
|
ipv4.Protocol = layers.IPProtocolTCP
|
||||||
|
buf = gopacket.NewSerializeBuffer()
|
||||||
|
err = gopacket.SerializeLayers(buf, opts, ipv4)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
result = manager.processOutgoingHooks(buf.Bytes())
|
||||||
|
require.False(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
func TestUSPFilterCreatePerformance(t *testing.T) {
|
func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||||
@ -418,3 +505,213 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||||
|
manager, err := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
manager.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.10.0.0"),
|
||||||
|
Mask: net.CIDRMask(16, 32),
|
||||||
|
}
|
||||||
|
|
||||||
|
manager.udpTracker.Close() // Close the existing tracker
|
||||||
|
manager.udpTracker = conntrack.NewUDPTracker(200 * time.Millisecond)
|
||||||
|
manager.decoders = sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
d := &decoder{
|
||||||
|
decoded: []gopacket.LayerType{},
|
||||||
|
}
|
||||||
|
d.parser = gopacket.NewDecodingLayerParser(
|
||||||
|
layers.LayerTypeIPv4,
|
||||||
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
|
)
|
||||||
|
d.parser.IgnoreUnsupported = true
|
||||||
|
return d
|
||||||
|
},
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, manager.Reset(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Set up packet parameters
|
||||||
|
srcIP := net.ParseIP("100.10.0.1")
|
||||||
|
dstIP := net.ParseIP("100.10.0.100")
|
||||||
|
srcPort := uint16(51334)
|
||||||
|
dstPort := uint16(53)
|
||||||
|
|
||||||
|
// Create outbound packet
|
||||||
|
outboundIPv4 := &layers.IPv4{
|
||||||
|
TTL: 64,
|
||||||
|
Version: 4,
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
Protocol: layers.IPProtocolUDP,
|
||||||
|
}
|
||||||
|
outboundUDP := &layers.UDP{
|
||||||
|
SrcPort: layers.UDPPort(srcPort),
|
||||||
|
DstPort: layers.UDPPort(dstPort),
|
||||||
|
}
|
||||||
|
|
||||||
|
err = outboundUDP.SetNetworkLayerForChecksum(outboundIPv4)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
outboundBuf := gopacket.NewSerializeBuffer()
|
||||||
|
opts := gopacket.SerializeOptions{
|
||||||
|
ComputeChecksums: true,
|
||||||
|
FixLengths: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = gopacket.SerializeLayers(outboundBuf, opts,
|
||||||
|
outboundIPv4,
|
||||||
|
outboundUDP,
|
||||||
|
gopacket.Payload("test"),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Process outbound packet and verify connection tracking
|
||||||
|
drop := manager.DropOutgoing(outboundBuf.Bytes())
|
||||||
|
require.False(t, drop, "Initial outbound packet should not be dropped")
|
||||||
|
|
||||||
|
// Verify connection was tracked
|
||||||
|
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
|
||||||
|
|
||||||
|
require.True(t, exists, "Connection should be tracked after outbound packet")
|
||||||
|
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(srcIP), conn.SourceIP), "Source IP should match")
|
||||||
|
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(dstIP), conn.DestIP), "Destination IP should match")
|
||||||
|
require.Equal(t, srcPort, conn.SourcePort, "Source port should match")
|
||||||
|
require.Equal(t, dstPort, conn.DestPort, "Destination port should match")
|
||||||
|
|
||||||
|
// Create valid inbound response packet
|
||||||
|
inboundIPv4 := &layers.IPv4{
|
||||||
|
TTL: 64,
|
||||||
|
Version: 4,
|
||||||
|
SrcIP: dstIP, // Original destination is now source
|
||||||
|
DstIP: srcIP, // Original source is now destination
|
||||||
|
Protocol: layers.IPProtocolUDP,
|
||||||
|
}
|
||||||
|
inboundUDP := &layers.UDP{
|
||||||
|
SrcPort: layers.UDPPort(dstPort), // Original destination port is now source
|
||||||
|
DstPort: layers.UDPPort(srcPort), // Original source port is now destination
|
||||||
|
}
|
||||||
|
|
||||||
|
err = inboundUDP.SetNetworkLayerForChecksum(inboundIPv4)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
inboundBuf := gopacket.NewSerializeBuffer()
|
||||||
|
err = gopacket.SerializeLayers(inboundBuf, opts,
|
||||||
|
inboundIPv4,
|
||||||
|
inboundUDP,
|
||||||
|
gopacket.Payload("response"),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// Test roundtrip response handling over time
|
||||||
|
checkPoints := []struct {
|
||||||
|
sleep time.Duration
|
||||||
|
shouldAllow bool
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
sleep: 0,
|
||||||
|
shouldAllow: true,
|
||||||
|
description: "Immediate response should be allowed",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sleep: 50 * time.Millisecond,
|
||||||
|
shouldAllow: true,
|
||||||
|
description: "Response within timeout should be allowed",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sleep: 100 * time.Millisecond,
|
||||||
|
shouldAllow: true,
|
||||||
|
description: "Response at half timeout should be allowed",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// tracker hasn't updated conn for 250ms -> greater than 200ms timeout
|
||||||
|
sleep: 250 * time.Millisecond,
|
||||||
|
shouldAllow: false,
|
||||||
|
description: "Response after timeout should be dropped",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, cp := range checkPoints {
|
||||||
|
time.Sleep(cp.sleep)
|
||||||
|
|
||||||
|
drop = manager.dropFilter(inboundBuf.Bytes(), manager.incomingRules)
|
||||||
|
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
||||||
|
|
||||||
|
// If the connection should still be valid, verify it exists
|
||||||
|
if cp.shouldAllow {
|
||||||
|
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
|
||||||
|
require.True(t, exists, "Connection should still exist during valid window")
|
||||||
|
require.True(t, time.Since(conn.GetLastSeen()) < manager.udpTracker.Timeout(),
|
||||||
|
"LastSeen should be updated for valid responses")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test invalid response packets (while connection is expired)
|
||||||
|
invalidCases := []struct {
|
||||||
|
name string
|
||||||
|
modifyFunc func(*layers.IPv4, *layers.UDP)
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "wrong source IP",
|
||||||
|
modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) {
|
||||||
|
ip.SrcIP = net.ParseIP("100.10.0.101")
|
||||||
|
},
|
||||||
|
description: "Response from wrong IP should be dropped",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong destination IP",
|
||||||
|
modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) {
|
||||||
|
ip.DstIP = net.ParseIP("100.10.0.2")
|
||||||
|
},
|
||||||
|
description: "Response to wrong IP should be dropped",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong source port",
|
||||||
|
modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) {
|
||||||
|
udp.SrcPort = 54
|
||||||
|
},
|
||||||
|
description: "Response from wrong port should be dropped",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong destination port",
|
||||||
|
modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) {
|
||||||
|
udp.DstPort = 51335
|
||||||
|
},
|
||||||
|
description: "Response to wrong port should be dropped",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new outbound connection for invalid tests
|
||||||
|
drop = manager.processOutgoingHooks(outboundBuf.Bytes())
|
||||||
|
require.False(t, drop, "Second outbound packet should not be dropped")
|
||||||
|
|
||||||
|
for _, tc := range invalidCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
testIPv4 := *inboundIPv4
|
||||||
|
testUDP := *inboundUDP
|
||||||
|
|
||||||
|
tc.modifyFunc(&testIPv4, &testUDP)
|
||||||
|
|
||||||
|
err = testUDP.SetNetworkLayerForChecksum(&testIPv4)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
testBuf := gopacket.NewSerializeBuffer()
|
||||||
|
err = gopacket.SerializeLayers(testBuf, opts,
|
||||||
|
&testIPv4,
|
||||||
|
&testUDP,
|
||||||
|
gopacket.Payload("response"),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify the invalid packet is dropped
|
||||||
|
drop = manager.dropFilter(testBuf.Bytes(), manager.incomingRules)
|
||||||
|
require.True(t, drop, tc.description)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
12
client/iface/bind/control_android.go
Normal file
12
client/iface/bind/control_android.go
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
wireguard "golang.zx2c4.com/wireguard/conn"
|
||||||
|
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// ControlFns is not thread safe and should only be modified during init.
|
||||||
|
*wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket)
|
||||||
|
}
|
@ -162,12 +162,13 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
|||||||
params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
|
params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
|
||||||
var networks []ice.NetworkType
|
var networks []ice.NetworkType
|
||||||
switch {
|
switch {
|
||||||
case addr.IP.To4() != nil:
|
|
||||||
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
|
|
||||||
|
|
||||||
case addr.IP.To16() != nil:
|
case addr.IP.To16() != nil:
|
||||||
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
|
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
|
||||||
|
|
||||||
|
case addr.IP.To4() != nil:
|
||||||
|
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr())
|
params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr())
|
||||||
}
|
}
|
||||||
|
@ -27,14 +27,14 @@ import (
|
|||||||
type status int
|
type status int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
defaultModuleDir = "/lib/modules"
|
unknown status = 1
|
||||||
unknown status = iota
|
unloaded status = 2
|
||||||
unloaded
|
unloading status = 3
|
||||||
unloading
|
loading status = 4
|
||||||
loading
|
live status = 5
|
||||||
live
|
inuse status = 6
|
||||||
inuse
|
defaultModuleDir = "/lib/modules"
|
||||||
envDisableWireGuardKernel = "NB_WG_KERNEL_DISABLED"
|
envDisableWireGuardKernel = "NB_WG_KERNEL_DISABLED"
|
||||||
)
|
)
|
||||||
|
|
||||||
type module struct {
|
type module struct {
|
||||||
|
@@ -15,6 +15,10 @@ func IsEnabled() bool {

 func ListenAddr() string {
 	sPort := os.Getenv("NB_SOCKS5_LISTENER_PORT")
+	if sPort == "" {
+		return listenAddr(DefaultSocks5Port)
+	}
+
 	port, err := strconv.Atoi(sPort)
 	if err != nil {
 		log.Warnf("invalid socks5 listener port, unable to convert it to int, falling back to default: %d", DefaultSocks5Port)
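A minimal sketch of the fallback behaviour introduced above, rewritten as a hypothetical standalone helper for illustration only. The real function lives in the NetBird SOCKS5 package and uses its own listenAddr and DefaultSocks5Port; the constant and address format below are assumptions.

	package main

	import (
		"fmt"
		"os"
		"strconv"
	)

	const defaultSocks5Port = 1080 // assumed default, for illustration only

	// listenAddrFor mirrors the env handling in the hunk above: an empty or
	// invalid NB_SOCKS5_LISTENER_PORT falls back to the default port.
	func listenAddrFor() string {
		sPort := os.Getenv("NB_SOCKS5_LISTENER_PORT")
		if sPort == "" {
			return fmt.Sprintf("0.0.0.0:%d", defaultSocks5Port)
		}
		port, err := strconv.Atoi(sPort)
		if err != nil {
			return fmt.Sprintf("0.0.0.0:%d", defaultSocks5Port)
		}
		return fmt.Sprintf("0.0.0.0:%d", port)
	}

	func main() {
		fmt.Println(listenAddrFor())
	}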
@@ -46,6 +46,7 @@ type ConfigInput struct {
 	ManagementURL    string
 	AdminURL         string
 	ConfigPath       string
+	StateFilePath    string
 	PreSharedKey     *string
 	ServerSSHAllowed *bool
 	NATExternalIPs   []string
@@ -60,6 +61,11 @@ type ConfigInput struct {
 	DNSRouteInterval  *time.Duration
 	ClientCertPath    string
 	ClientCertKeyPath string
+
+	DisableClientRoutes *bool
+	DisableServerRoutes *bool
+	DisableDNS          *bool
+	DisableFirewall     *bool
 }

 // Config Configuration type
@@ -77,6 +83,12 @@ type Config struct {
 	RosenpassEnabled    bool
 	RosenpassPermissive bool
 	ServerSSHAllowed    *bool
+
+	DisableClientRoutes bool
+	DisableServerRoutes bool
+	DisableDNS          bool
+	DisableFirewall     bool
+
 	// SSHKey is a private SSH key in a PEM format
 	SSHKey string

@@ -105,10 +117,10 @@ type Config struct {

 	// DNSRouteInterval is the interval in which the DNS routes are updated
 	DNSRouteInterval time.Duration
-	//Path to a certificate used for mTLS authentication
+	// Path to a certificate used for mTLS authentication
 	ClientCertPath string

-	//Path to corresponding private key of ClientCertPath
+	// Path to corresponding private key of ClientCertPath
 	ClientCertKeyPath string

 	ClientCertKeyPair *tls.Certificate `json:"-"`
@@ -116,7 +128,7 @@ type Config struct {

 // ReadConfig read config file and return with Config. If it is not exists create a new with default values
 func ReadConfig(configPath string) (*Config, error) {
-	if configFileIsExists(configPath) {
+	if fileExists(configPath) {
 		err := util.EnforcePermission(configPath)
 		if err != nil {
 			log.Errorf("failed to enforce permission on config dir: %v", err)
@@ -149,7 +161,7 @@ func ReadConfig(configPath string) (*Config, error) {

 // UpdateConfig update existing configuration according to input configuration and return with the configuration
 func UpdateConfig(input ConfigInput) (*Config, error) {
-	if !configFileIsExists(input.ConfigPath) {
+	if !fileExists(input.ConfigPath) {
 		return nil, status.Errorf(codes.NotFound, "config file doesn't exist")
 	}

@@ -158,13 +170,13 @@ func UpdateConfig(input ConfigInput) (*Config, error) {

 // UpdateOrCreateConfig reads existing config or generates a new one
 func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
-	if !configFileIsExists(input.ConfigPath) {
+	if !fileExists(input.ConfigPath) {
 		log.Infof("generating new config %s", input.ConfigPath)
 		cfg, err := createNewConfig(input)
 		if err != nil {
 			return nil, err
 		}
-		err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg)
+		err = util.WriteJsonWithRestrictedPermission(context.Background(), input.ConfigPath, cfg)
 		return cfg, err
 	}

@@ -185,7 +197,7 @@ func CreateInMemoryConfig(input ConfigInput) (*Config, error) {

 // WriteOutConfig write put the prepared config to the given path
 func WriteOutConfig(path string, config *Config) error {
-	return util.WriteJson(path, config)
+	return util.WriteJson(context.Background(), path, config)
 }

 // createNewConfig creates a new config generating a new Wireguard key and saving to file
@@ -215,7 +227,7 @@ func update(input ConfigInput) (*Config, error) {
 	}

 	if updated {
-		if err := util.WriteJson(input.ConfigPath, config); err != nil {
+		if err := util.WriteJson(context.Background(), input.ConfigPath, config); err != nil {
 			return nil, err
 		}
 	}
@@ -401,7 +413,46 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
 		config.DNSRouteInterval = dynamic.DefaultInterval
 		log.Infof("using default DNS route interval %s", config.DNSRouteInterval)
 		updated = true
+	}
+
+	if input.DisableClientRoutes != nil && *input.DisableClientRoutes != config.DisableClientRoutes {
+		if *input.DisableClientRoutes {
+			log.Infof("disabling client routes")
+		} else {
+			log.Infof("enabling client routes")
+		}
+		config.DisableClientRoutes = *input.DisableClientRoutes
+		updated = true
+	}
+
+	if input.DisableServerRoutes != nil && *input.DisableServerRoutes != config.DisableServerRoutes {
+		if *input.DisableServerRoutes {
+			log.Infof("disabling server routes")
+		} else {
+			log.Infof("enabling server routes")
+		}
+		config.DisableServerRoutes = *input.DisableServerRoutes
+		updated = true
+	}
+
+	if input.DisableDNS != nil && *input.DisableDNS != config.DisableDNS {
+		if *input.DisableDNS {
+			log.Infof("disabling DNS configuration")
+		} else {
+			log.Infof("enabling DNS configuration")
+		}
+		config.DisableDNS = *input.DisableDNS
+		updated = true
+	}
+
+	if input.DisableFirewall != nil && *input.DisableFirewall != config.DisableFirewall {
+		if *input.DisableFirewall {
+			log.Infof("disabling firewall configuration")
+		} else {
+			log.Infof("enabling firewall configuration")
+		}
+		config.DisableFirewall = *input.DisableFirewall
+		updated = true
 	}

 	if input.ClientCertKeyPath != "" {
@@ -472,11 +523,19 @@ func isPreSharedKeyHidden(preSharedKey *string) bool {
 	return false
 }

-func configFileIsExists(path string) bool {
+func fileExists(path string) bool {
 	_, err := os.Stat(path)
 	return !os.IsNotExist(err)
 }

+func createFile(path string) error {
+	file, err := os.Create(path)
+	if err != nil {
+		return err
+	}
+	return file.Close()
+}
+
 // UpdateOldManagementURL checks whether client can switch to the new Management URL with port 443 and the management domain.
 // If it can switch, then it updates the config and returns a new one. Otherwise, it returns the provided config.
 // The check is performed only for the NetBird's managed version.
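A minimal sketch of how the new pointer flags in ConfigInput could be used. It assumes the config package path client/internal as in the rest of the client, and that the caller lives inside the NetBird module (the internal package cannot be imported from outside the repository); the config path is illustrative.

	package main

	import (
		"fmt"

		"github.com/netbirdio/netbird/client/internal"
	)

	func main() {
		disableDNS := true

		// A nil pointer means "leave the stored value untouched"; an explicit
		// &false re-enables the feature.
		cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
			ConfigPath: "/etc/netbird/config.json", // illustrative path
			DisableDNS: &disableDNS,
		})
		if err != nil {
			panic(err)
		}
		fmt.Printf("DNS disabled: %v, client routes disabled: %v\n", cfg.DisableDNS, cfg.DisableClientRoutes)
	}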
@@ -40,6 +40,8 @@ type ConnectClient struct {
 	statusRecorder *peer.Status
 	engine         *Engine
 	engineMutex    sync.Mutex
+
+	persistNetworkMap bool
 }

 func NewConnectClient(
@@ -89,6 +91,7 @@ func (c *ConnectClient) RunOniOS(
 	fileDescriptor int32,
 	networkChangeListener listener.NetworkChangeListener,
 	dnsManager dns.IosDnsManager,
+	stateFilePath string,
 ) error {
 	// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
 	debug.SetGCPercent(5)
@@ -97,6 +100,7 @@ func (c *ConnectClient) RunOniOS(
 		FileDescriptor:        fileDescriptor,
 		NetworkChangeListener: networkChangeListener,
 		DnsManager:            dnsManager,
+		StateFilePath:         stateFilePath,
 	}
 	return c.run(mobileDependency, nil, nil)
 }
@@ -157,7 +161,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold

 	engineCtx, cancel := context.WithCancel(c.ctx)
 	defer func() {
-		c.statusRecorder.MarkManagementDisconnected(state.err)
+		_, err := state.Status()
+		c.statusRecorder.MarkManagementDisconnected(err)
 		c.statusRecorder.CleanLocalPeerState()
 		cancel()
 	}()
@@ -231,6 +236,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold

 	relayURLs, token := parseRelayInfo(loginResp)
 	relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String())
+	c.statusRecorder.SetRelayMgr(relayManager)
 	if len(relayURLs) > 0 {
 		if token != nil {
 			if err := relayManager.UpdateToken(token); err != nil {
@@ -241,9 +247,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
 		log.Infof("connecting to the Relay service(s): %s", strings.Join(relayURLs, ", "))
 		if err = relayManager.Serve(); err != nil {
 			log.Error(err)
-			return wrapErr(err)
 		}
-		c.statusRecorder.SetRelayMgr(relayManager)
 	}

 	peerConfig := loginResp.GetPeerConfig()
@@ -258,7 +262,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold

 	c.engineMutex.Lock()
 	c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
+	c.engine.SetNetworkMapPersistence(c.persistNetworkMap)
 	c.engineMutex.Unlock()

 	if err := c.engine.Start(); err != nil {
@@ -336,6 +340,19 @@ func (c *ConnectClient) Engine() *Engine {
 	return e
 }

+// Status returns the current client status
+func (c *ConnectClient) Status() StatusType {
+	if c == nil {
+		return StatusIdle
+	}
+	status, err := CtxGetState(c.ctx).Status()
+	if err != nil {
+		return StatusIdle
+	}
+
+	return status
+}
+
 func (c *ConnectClient) Stop() error {
 	if c == nil {
 		return nil
@@ -362,6 +379,21 @@ func (c *ConnectClient) isContextCancelled() bool {
 	}
 }

+// SetNetworkMapPersistence enables or disables network map persistence.
+// When enabled, the last received network map will be stored and can be retrieved
+// through the Engine's getLatestNetworkMap method. When disabled, any stored
+// network map will be cleared.
+func (c *ConnectClient) SetNetworkMapPersistence(enabled bool) {
+	c.engineMutex.Lock()
+	c.persistNetworkMap = enabled
+	c.engineMutex.Unlock()
+
+	engine := c.Engine()
+	if engine != nil {
+		engine.SetNetworkMapPersistence(enabled)
+	}
+}
+
 // createEngineConfig converts configuration received from Management Service to EngineConfig
 func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
 	nm := false
@@ -383,6 +415,11 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
 		RosenpassPermissive: config.RosenpassPermissive,
 		ServerSSHAllowed:    util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
 		DNSRouteInterval:    config.DNSRouteInterval,
+
+		DisableClientRoutes: config.DisableClientRoutes,
+		DisableServerRoutes: config.DisableServerRoutes,
+		DisableDNS:          config.DisableDNS,
+		DisableFirewall:     config.DisableFirewall,
 	}

 	if config.PreSharedKey != "" {
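A minimal sketch of driving the new network map persistence switch, assumed to live in the client/internal package next to connect.go (the helper name and the log line are illustrative, not part of the commit).

	package internal

	import log "github.com/sirupsen/logrus"

	// enableNetworkMapPersistence turns persistence on before or after the engine
	// starts; ConnectClient forwards the flag to an already running engine.
	func enableNetworkMapPersistence(c *ConnectClient) {
		if c == nil {
			return
		}
		c.SetNetworkMapPersistence(true)
		// Status falls back to StatusIdle when no engine context exists yet.
		log.Infof("network map persistence enabled, client status: %v", c.Status())
	}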
18 client/internal/dns/consts.go Normal file
@@ -0,0 +1,18 @@
//go:build !android

package dns

import (
	"github.com/netbirdio/netbird/client/configs"
	"os"
	"path/filepath"
)

var fileUncleanShutdownResolvConfLocation string

func init() {
	fileUncleanShutdownResolvConfLocation = os.Getenv("NB_UNCLEAN_SHUTDOWN_RESOLV_FILE")
	if fileUncleanShutdownResolvConfLocation == "" {
		fileUncleanShutdownResolvConfLocation = filepath.Join(configs.StateDir, "resolv.conf")
	}
}
@@ -1,5 +0,0 @@
-package dns
-
-const (
-	fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf"
-)
@@ -1,7 +0,0 @@
-//go:build !android
-
-package dns
-
-const (
-	fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
-)
225 client/internal/dns/handler_chain.go Normal file
@@ -0,0 +1,225 @@
package dns

import (
	"slices"
	"strings"
	"sync"

	"github.com/miekg/dns"
	log "github.com/sirupsen/logrus"
)

const (
	PriorityDNSRoute    = 100
	PriorityMatchDomain = 50
	PriorityDefault     = 0
)

type SubdomainMatcher interface {
	dns.Handler
	MatchSubdomains() bool
}

type HandlerEntry struct {
	Handler         dns.Handler
	Priority        int
	Pattern         string
	OrigPattern     string
	IsWildcard      bool
	StopHandler     handlerWithStop
	MatchSubdomains bool
}

// HandlerChain represents a prioritized chain of DNS handlers
type HandlerChain struct {
	mu       sync.RWMutex
	handlers []HandlerEntry
}

// ResponseWriterChain wraps a dns.ResponseWriter to track if handler wants to continue chain
type ResponseWriterChain struct {
	dns.ResponseWriter
	origPattern    string
	shouldContinue bool
}

func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
	// Check if this is a continue signal (NXDOMAIN with Zero bit set)
	if m.Rcode == dns.RcodeNameError && m.MsgHdr.Zero {
		w.shouldContinue = true
		return nil
	}
	return w.ResponseWriter.WriteMsg(m)
}

func NewHandlerChain() *HandlerChain {
	return &HandlerChain{
		handlers: make([]HandlerEntry, 0),
	}
}

// GetOrigPattern returns the original pattern of the handler that wrote the response
func (w *ResponseWriterChain) GetOrigPattern() string {
	return w.origPattern
}

// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) {
	c.mu.Lock()
	defer c.mu.Unlock()

	pattern = strings.ToLower(dns.Fqdn(pattern))
	origPattern := pattern
	isWildcard := strings.HasPrefix(pattern, "*.")
	if isWildcard {
		pattern = pattern[2:]
	}

	// First remove any existing handler with same pattern (case-insensitive) and priority
	for i := len(c.handlers) - 1; i >= 0; i-- {
		if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
			if c.handlers[i].StopHandler != nil {
				c.handlers[i].StopHandler.stop()
			}
			c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
			break
		}
	}

	// Check if handler implements SubdomainMatcher interface
	matchSubdomains := false
	if matcher, ok := handler.(SubdomainMatcher); ok {
		matchSubdomains = matcher.MatchSubdomains()
	}

	log.Debugf("adding handler pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d",
		pattern, origPattern, isWildcard, matchSubdomains, priority)

	entry := HandlerEntry{
		Handler:         handler,
		Priority:        priority,
		Pattern:         pattern,
		OrigPattern:     origPattern,
		IsWildcard:      isWildcard,
		StopHandler:     stopHandler,
		MatchSubdomains: matchSubdomains,
	}

	// Insert handler in priority order
	pos := 0
	for i, h := range c.handlers {
		if h.Priority < priority {
			pos = i
			break
		}
		pos = i + 1
	}

	c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
}

// RemoveHandler removes a handler for the given pattern and priority
func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
	c.mu.Lock()
	defer c.mu.Unlock()

	pattern = dns.Fqdn(pattern)

	// Find and remove handlers matching both original pattern (case-insensitive) and priority
	for i := len(c.handlers) - 1; i >= 0; i-- {
		entry := c.handlers[i]
		if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
			if entry.StopHandler != nil {
				entry.StopHandler.stop()
			}
			c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
			return
		}
	}
}

// HasHandlers returns true if there are any handlers remaining for the given pattern
func (c *HandlerChain) HasHandlers(pattern string) bool {
	c.mu.RLock()
	defer c.mu.RUnlock()

	pattern = strings.ToLower(dns.Fqdn(pattern))
	for _, entry := range c.handlers {
		if strings.EqualFold(entry.Pattern, pattern) {
			return true
		}
	}
	return false
}

func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
	if len(r.Question) == 0 {
		return
	}

	qname := strings.ToLower(r.Question[0].Name)
	log.Tracef("handling DNS request for domain=%s", qname)

	c.mu.RLock()
	handlers := slices.Clone(c.handlers)
	c.mu.RUnlock()

	if log.IsLevelEnabled(log.TraceLevel) {
		log.Tracef("current handlers (%d):", len(handlers))
		for _, h := range handlers {
			log.Tracef("  - pattern: domain=%s original: domain=%s wildcard=%v priority=%d",
				h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority)
		}
	}

	// Try handlers in priority order
	for _, entry := range handlers {
		var matched bool
		switch {
		case entry.Pattern == ".":
			matched = true
		case entry.IsWildcard:
			parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
			matched = len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
		default:
			// For non-wildcard patterns:
			// If handler wants subdomain matching, allow suffix match
			// Otherwise require exact match
			if entry.MatchSubdomains {
				matched = strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern)
			} else {
				matched = strings.EqualFold(qname, entry.Pattern)
			}
		}

		if !matched {
			log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v matched=false",
				qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard)
			continue
		}

		log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v",
			qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains)

		chainWriter := &ResponseWriterChain{
			ResponseWriter: w,
			origPattern:    entry.OrigPattern,
		}
		entry.Handler.ServeDNS(chainWriter, r)

		// If handler wants to continue, try next handler
		if chainWriter.shouldContinue {
			log.Tracef("handler requested continue to next handler")
			continue
		}
		return
	}

	// No handler matched or all handlers passed
	log.Tracef("no handler found for domain=%s", qname)
	resp := &dns.Msg{}
	resp.SetRcode(r, dns.RcodeNameError)
	if err := w.WriteMsg(resp); err != nil {
		log.Errorf("failed to write DNS response: %v", err)
	}
}
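A minimal usage sketch of the handler chain added above. The staticResolver type, the domain names, and the IPs are illustrative only and not part of the commit; the sketch is assumed to live in the same client/internal/dns package. It shows the two mechanisms the chain relies on: priority ordering between entries, and the "continue" signal a handler sends by answering NXDOMAIN with the Zero bit set.

	// chain_example.go - illustrative sketch only
	package dns

	import (
		"fmt"
		"strings"

		"github.com/miekg/dns"
	)

	// staticResolver answers A queries for a single name and signals the chain
	// to continue (NXDOMAIN + Zero bit) for everything else.
	type staticResolver struct {
		name string
		ip   string
	}

	func (s *staticResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
		resp := new(dns.Msg)
		if len(r.Question) > 0 && strings.EqualFold(r.Question[0].Name, s.name) {
			resp.SetReply(r)
			rr, err := dns.NewRR(fmt.Sprintf("%s 300 IN A %s", s.name, s.ip))
			if err == nil {
				resp.Answer = append(resp.Answer, rr)
			}
		} else {
			resp.SetRcode(r, dns.RcodeNameError)
			resp.MsgHdr.Zero = true // ResponseWriterChain treats this as "try the next handler"
		}
		_ = w.WriteMsg(resp)
	}

	func newExampleChain() *HandlerChain {
		chain := NewHandlerChain()
		// The higher-priority exact entry is consulted first; the wildcard entry
		// only sees queries the first handler passes on.
		chain.AddHandler("peer.netbird.cloud.", &staticResolver{name: "peer.netbird.cloud.", ip: "100.90.0.10"}, PriorityDNSRoute, nil)
		chain.AddHandler("*.netbird.cloud.", &staticResolver{name: "other.netbird.cloud.", ip: "100.90.0.11"}, PriorityMatchDomain, nil)
		return chain
	}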
679 client/internal/dns/handler_chain_test.go Normal file
@@ -0,0 +1,679 @@
package dns_test

import (
	"net"
	"testing"

	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"

	nbdns "github.com/netbirdio/netbird/client/internal/dns"
)

// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order
func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
	chain := nbdns.NewHandlerChain()

	// Create mock handlers for different priorities
	defaultHandler := &nbdns.MockHandler{}
	matchDomainHandler := &nbdns.MockHandler{}
	dnsRouteHandler := &nbdns.MockHandler{}

	// Setup handlers with different priorities
	chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil)
	chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain, nil)
	chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute, nil)

	// Create test request
	r := new(dns.Msg)
	r.SetQuestion("example.com.", dns.TypeA)

	// Create test writer
	w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}

	// Setup expectations - only highest priority handler should be called
	dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once()
	matchDomainHandler.On("ServeDNS", mock.Anything, r).Maybe()
	defaultHandler.On("ServeDNS", mock.Anything, r).Maybe()

	// Execute
	chain.ServeDNS(w, r)

	// Verify all expectations were met
	dnsRouteHandler.AssertExpectations(t)
	matchDomainHandler.AssertExpectations(t)
	defaultHandler.AssertExpectations(t)
}

// TestHandlerChain_ServeDNS_DomainMatching tests various domain matching scenarios
func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
	tests := []struct {
		name            string
		handlerDomain   string
		queryDomain     string
		isWildcard      bool
		matchSubdomains bool
		shouldMatch     bool
	}{
		{
			name:            "exact match",
			handlerDomain:   "example.com.",
			queryDomain:     "example.com.",
			isWildcard:      false,
			matchSubdomains: false,
			shouldMatch:     true,
		},
		{
			name:            "subdomain with non-wildcard and MatchSubdomains true",
			handlerDomain:   "example.com.",
			queryDomain:     "sub.example.com.",
			isWildcard:      false,
			matchSubdomains: true,
			shouldMatch:     true,
		},
		{
			name:            "subdomain with non-wildcard and MatchSubdomains false",
			handlerDomain:   "example.com.",
			queryDomain:     "sub.example.com.",
			isWildcard:      false,
			matchSubdomains: false,
			shouldMatch:     false,
		},
		{
			name:            "wildcard match",
			handlerDomain:   "*.example.com.",
			queryDomain:     "sub.example.com.",
			isWildcard:      true,
			matchSubdomains: false,
			shouldMatch:     true,
		},
		{
			name:            "wildcard no match on apex",
			handlerDomain:   "*.example.com.",
			queryDomain:     "example.com.",
			isWildcard:      true,
			matchSubdomains: false,
			shouldMatch:     false,
		},
		{
			name:            "root zone match",
			handlerDomain:   ".",
			queryDomain:     "anything.com.",
			isWildcard:      false,
			matchSubdomains: false,
			shouldMatch:     true,
		},
		{
			name:            "no match different domain",
			handlerDomain:   "example.com.",
			queryDomain:     "example.org.",
			isWildcard:      false,
			matchSubdomains: false,
			shouldMatch:     false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			chain := nbdns.NewHandlerChain()
			var handler dns.Handler

			if tt.matchSubdomains {
				mockSubHandler := &nbdns.MockSubdomainHandler{Subdomains: true}
				handler = mockSubHandler
				if tt.shouldMatch {
					mockSubHandler.On("ServeDNS", mock.Anything, mock.Anything).Once()
				}
			} else {
				mockHandler := &nbdns.MockHandler{}
				handler = mockHandler
				if tt.shouldMatch {
					mockHandler.On("ServeDNS", mock.Anything, mock.Anything).Once()
				}
			}

			pattern := tt.handlerDomain
			if tt.isWildcard {
				pattern = "*." + tt.handlerDomain[2:]
			}

			chain.AddHandler(pattern, handler, nbdns.PriorityDefault, nil)

			r := new(dns.Msg)
			r.SetQuestion(tt.queryDomain, dns.TypeA)
			w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}

			chain.ServeDNS(w, r)

			if h, ok := handler.(*nbdns.MockHandler); ok {
				h.AssertExpectations(t)
			} else if h, ok := handler.(*nbdns.MockSubdomainHandler); ok {
				h.AssertExpectations(t)
			}
		})
	}
}

// TestHandlerChain_ServeDNS_OverlappingDomains tests behavior with overlapping domain patterns
func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
	tests := []struct {
		name     string
		handlers []struct {
			pattern  string
			priority int
		}
		queryDomain     string
		expectedCalls   int
		expectedHandler int // index of the handler that should be called
	}{
		{
			name: "wildcard and exact same priority - exact should win",
			handlers: []struct {
				pattern  string
				priority int
			}{
				{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
				{pattern: "example.com.", priority: nbdns.PriorityDefault},
			},
			queryDomain:     "example.com.",
			expectedCalls:   1,
			expectedHandler: 1, // exact match handler should be called
		},
		{
			name: "higher priority wildcard over lower priority exact",
			handlers: []struct {
				pattern  string
				priority int
			}{
				{pattern: "example.com.", priority: nbdns.PriorityDefault},
				{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
			},
			queryDomain:     "test.example.com.",
			expectedCalls:   1,
			expectedHandler: 1, // higher priority wildcard handler should be called
		},
		{
			name: "multiple wildcards different priorities",
			handlers: []struct {
				pattern  string
				priority int
			}{
				{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
				{pattern: "*.example.com.", priority: nbdns.PriorityMatchDomain},
				{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
			},
			queryDomain:     "test.example.com.",
			expectedCalls:   1,
			expectedHandler: 2, // highest priority handler should be called
		},
		{
			name: "subdomain with mix of patterns",
			handlers: []struct {
				pattern  string
				priority int
			}{
				{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
				{pattern: "test.example.com.", priority: nbdns.PriorityMatchDomain},
				{pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute},
			},
			queryDomain:     "sub.test.example.com.",
			expectedCalls:   1,
			expectedHandler: 2, // highest priority matching handler should be called
		},
		{
			name: "root zone with specific domain",
			handlers: []struct {
				pattern  string
				priority int
			}{
				{pattern: ".", priority: nbdns.PriorityDefault},
				{pattern: "example.com.", priority: nbdns.PriorityDNSRoute},
			},
			queryDomain:     "example.com.",
			expectedCalls:   1,
			expectedHandler: 1, // higher priority specific domain should win over root
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			chain := nbdns.NewHandlerChain()
			var handlers []*nbdns.MockHandler

			// Setup handlers and expectations
			for i := range tt.handlers {
				handler := &nbdns.MockHandler{}
				handlers = append(handlers, handler)

				// Set expectation based on whether this handler should be called
				if i == tt.expectedHandler {
					handler.On("ServeDNS", mock.Anything, mock.Anything).Once()
				} else {
					handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe()
				}

				chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority, nil)
			}

			// Create and execute request
			r := new(dns.Msg)
			r.SetQuestion(tt.queryDomain, dns.TypeA)
			w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
			chain.ServeDNS(w, r)

			// Verify expectations
			for _, handler := range handlers {
				handler.AssertExpectations(t)
			}
		})
	}
}

// TestHandlerChain_ServeDNS_ChainContinuation tests the chain continuation functionality
func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
	chain := nbdns.NewHandlerChain()

	// Create handlers
	handler1 := &nbdns.MockHandler{}
	handler2 := &nbdns.MockHandler{}
	handler3 := &nbdns.MockHandler{}

	// Add handlers in priority order
	chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil)
	chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain, nil)
	chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault, nil)

	// Create test request
	r := new(dns.Msg)
	r.SetQuestion("example.com.", dns.TypeA)

	// Setup mock responses to simulate chain continuation
	handler1.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
		// First handler signals continue
		w := args.Get(0).(*nbdns.ResponseWriterChain)
		resp := new(dns.Msg)
		resp.SetRcode(r, dns.RcodeNameError)
		resp.MsgHdr.Zero = true // Signal to continue
		assert.NoError(t, w.WriteMsg(resp))
	}).Once()

	handler2.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
		// Second handler signals continue
		w := args.Get(0).(*nbdns.ResponseWriterChain)
		resp := new(dns.Msg)
		resp.SetRcode(r, dns.RcodeNameError)
		resp.MsgHdr.Zero = true
		assert.NoError(t, w.WriteMsg(resp))
	}).Once()

	handler3.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
		// Last handler responds normally
		w := args.Get(0).(*nbdns.ResponseWriterChain)
		resp := new(dns.Msg)
		resp.SetRcode(r, dns.RcodeSuccess)
		assert.NoError(t, w.WriteMsg(resp))
	}).Once()

	// Execute
	w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
	chain.ServeDNS(w, r)

	// Verify all handlers were called in order
	handler1.AssertExpectations(t)
	handler2.AssertExpectations(t)
	handler3.AssertExpectations(t)
}

// mockResponseWriter implements dns.ResponseWriter for testing
type mockResponseWriter struct {
	mock.Mock
}

func (m *mockResponseWriter) LocalAddr() net.Addr       { return nil }
func (m *mockResponseWriter) RemoteAddr() net.Addr      { return nil }
func (m *mockResponseWriter) WriteMsg(*dns.Msg) error   { return nil }
func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
func (m *mockResponseWriter) Close() error              { return nil }
func (m *mockResponseWriter) TsigStatus() error         { return nil }
func (m *mockResponseWriter) TsigTimersOnly(bool)       {}
func (m *mockResponseWriter) Hijack()                   {}

func TestHandlerChain_PriorityDeregistration(t *testing.T) {
	tests := []struct {
		name string
		ops  []struct {
			action   string // "add" or "remove"
			pattern  string
			priority int
		}
		query         string
		expectedCalls map[int]bool // map[priority]shouldBeCalled
	}{
		{
			name: "remove high priority keeps lower priority handler",
			ops: []struct {
				action   string
				pattern  string
				priority int
			}{
				{"add", "example.com.", nbdns.PriorityDNSRoute},
				{"add", "example.com.", nbdns.PriorityMatchDomain},
				{"remove", "example.com.", nbdns.PriorityDNSRoute},
			},
			query: "example.com.",
			expectedCalls: map[int]bool{
				nbdns.PriorityDNSRoute:    false,
				nbdns.PriorityMatchDomain: true,
			},
		},
		{
			name: "remove lower priority keeps high priority handler",
			ops: []struct {
				action   string
				pattern  string
				priority int
			}{
				{"add", "example.com.", nbdns.PriorityDNSRoute},
				{"add", "example.com.", nbdns.PriorityMatchDomain},
				{"remove", "example.com.", nbdns.PriorityMatchDomain},
			},
			query: "example.com.",
			expectedCalls: map[int]bool{
				nbdns.PriorityDNSRoute:    true,
				nbdns.PriorityMatchDomain: false,
			},
		},
		{
			name: "remove all handlers in order",
			ops: []struct {
				action   string
				pattern  string
				priority int
			}{
				{"add", "example.com.", nbdns.PriorityDNSRoute},
				{"add", "example.com.", nbdns.PriorityMatchDomain},
				{"add", "example.com.", nbdns.PriorityDefault},
				{"remove", "example.com.", nbdns.PriorityDNSRoute},
				{"remove", "example.com.", nbdns.PriorityMatchDomain},
			},
			query: "example.com.",
			expectedCalls: map[int]bool{
				nbdns.PriorityDNSRoute:    false,
				nbdns.PriorityMatchDomain: false,
				nbdns.PriorityDefault:     true,
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			chain := nbdns.NewHandlerChain()
			handlers := make(map[int]*nbdns.MockHandler)

			// Execute operations
			for _, op := range tt.ops {
				if op.action == "add" {
					handler := &nbdns.MockHandler{}
					handlers[op.priority] = handler
					chain.AddHandler(op.pattern, handler, op.priority, nil)
				} else {
					chain.RemoveHandler(op.pattern, op.priority)
				}
			}

			// Create test request
			r := new(dns.Msg)
			r.SetQuestion(tt.query, dns.TypeA)
			w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}

			// Setup expectations
			for priority, handler := range handlers {
				if shouldCall, exists := tt.expectedCalls[priority]; exists && shouldCall {
					handler.On("ServeDNS", mock.Anything, r).Once()
				} else {
					handler.On("ServeDNS", mock.Anything, r).Maybe()
				}
			}

			// Execute request
			chain.ServeDNS(w, r)

			// Verify expectations
			for _, handler := range handlers {
				handler.AssertExpectations(t)
			}

			// Verify handler exists check
			for priority, shouldExist := range tt.expectedCalls {
				if shouldExist {
					assert.True(t, chain.HasHandlers(tt.ops[0].pattern),
						"Handler chain should have handlers for pattern after removing priority %d", priority)
				}
			}
		})
	}
}

func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
	chain := nbdns.NewHandlerChain()

	testDomain := "example.com."
	testQuery := "test.example.com."

	// Create handlers with MatchSubdomains enabled
	routeHandler := &nbdns.MockSubdomainHandler{Subdomains: true}
	matchHandler := &nbdns.MockSubdomainHandler{Subdomains: true}
	defaultHandler := &nbdns.MockSubdomainHandler{Subdomains: true}

	// Create test request that will be reused
	r := new(dns.Msg)
	r.SetQuestion(testQuery, dns.TypeA)

	// Add handlers in mixed order
	chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault, nil)
	chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute, nil)
	chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain, nil)

	// Test 1: Initial state with all three handlers
	w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
	// Highest priority handler (routeHandler) should be called
	routeHandler.On("ServeDNS", mock.Anything, r).Return().Once()

	chain.ServeDNS(w, r)
	routeHandler.AssertExpectations(t)

	// Test 2: Remove highest priority handler
	chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute)
	assert.True(t, chain.HasHandlers(testDomain))

	w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
	// Now middle priority handler (matchHandler) should be called
	matchHandler.On("ServeDNS", mock.Anything, r).Return().Once()

	chain.ServeDNS(w, r)
	matchHandler.AssertExpectations(t)

	// Test 3: Remove middle priority handler
	chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
	assert.True(t, chain.HasHandlers(testDomain))

	w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
	// Now lowest priority handler (defaultHandler) should be called
	defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once()

	chain.ServeDNS(w, r)
	defaultHandler.AssertExpectations(t)

	// Test 4: Remove last handler
	chain.RemoveHandler(testDomain, nbdns.PriorityDefault)

	assert.False(t, chain.HasHandlers(testDomain))
}

func TestHandlerChain_CaseSensitivity(t *testing.T) {
	tests := []struct {
		name        string
		scenario    string
		addHandlers []struct {
			pattern     string
			priority    int
			subdomains  bool
			shouldMatch bool
		}
		query         string
		expectedCalls int
	}{
		{
			name:     "case insensitive exact match",
			scenario: "handler registered lowercase, query uppercase",
			addHandlers: []struct {
				pattern     string
				priority    int
				subdomains  bool
				shouldMatch bool
			}{
				{"example.com.", nbdns.PriorityDefault, false, true},
			},
			query:         "EXAMPLE.COM.",
			expectedCalls: 1,
		},
		{
			name:     "case insensitive wildcard match",
			scenario: "handler registered mixed case wildcard, query different case",
			addHandlers: []struct {
				pattern     string
				priority    int
				subdomains  bool
				shouldMatch bool
			}{
				{"*.Example.Com.", nbdns.PriorityDefault, false, true},
			},
			query:         "sub.EXAMPLE.COM.",
			expectedCalls: 1,
		},
		{
			name:     "multiple handlers different case same domain",
			scenario: "second handler should replace first despite case difference",
			addHandlers: []struct {
				pattern     string
				priority    int
				subdomains  bool
				shouldMatch bool
			}{
				{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
				{"example.com.", nbdns.PriorityDefault, false, true},
			},
			query:         "ExAmPlE.cOm.",
			expectedCalls: 1,
		},
		{
			name:     "subdomain matching case insensitive",
			scenario: "handler with MatchSubdomains true should match regardless of case",
			addHandlers: []struct {
				pattern     string
				priority    int
				subdomains  bool
				shouldMatch bool
			}{
				{"example.com.", nbdns.PriorityDefault, true, true},
			},
			query:         "SUB.EXAMPLE.COM.",
			expectedCalls: 1,
		},
		{
			name:     "root zone case insensitive",
			scenario: "root zone handler should match regardless of case",
			addHandlers: []struct {
				pattern     string
				priority    int
				subdomains  bool
				shouldMatch bool
			}{
				{".", nbdns.PriorityDefault, false, true},
			},
			query:         "EXAMPLE.COM.",
			expectedCalls: 1,
		},
		{
			name:     "multiple handlers different priority",
			scenario: "should call higher priority handler despite case differences",
			addHandlers: []struct {
				pattern     string
				priority    int
				subdomains  bool
				shouldMatch bool
			}{
				{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
				{"example.com.", nbdns.PriorityMatchDomain, false, false},
				{"Example.Com.", nbdns.PriorityDNSRoute, false, true},
			},
			query:         "example.com.",
			expectedCalls: 1,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			chain := nbdns.NewHandlerChain()
			handlerCalls := make(map[string]bool) // track which patterns were called

			// Add handlers according to test case
			for _, h := range tt.addHandlers {
				var handler dns.Handler
				pattern := h.pattern // capture pattern for closure

				if h.subdomains {
					subHandler := &nbdns.MockSubdomainHandler{
						Subdomains: true,
					}
					if h.shouldMatch {
						subHandler.On("ServeDNS", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
							handlerCalls[pattern] = true
							w := args.Get(0).(dns.ResponseWriter)
							r := args.Get(1).(*dns.Msg)
							resp := new(dns.Msg)
							resp.SetRcode(r, dns.RcodeSuccess)
							assert.NoError(t, w.WriteMsg(resp))
						}).Once()
					}
					handler = subHandler
				} else {
					mockHandler := &nbdns.MockHandler{}
					if h.shouldMatch {
						mockHandler.On("ServeDNS", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
							handlerCalls[pattern] = true
							w := args.Get(0).(dns.ResponseWriter)
							r := args.Get(1).(*dns.Msg)
							resp := new(dns.Msg)
							resp.SetRcode(r, dns.RcodeSuccess)
							assert.NoError(t, w.WriteMsg(resp))
						}).Once()
					}
					handler = mockHandler
				}

				chain.AddHandler(pattern, handler, h.priority, nil)
			}

			// Execute request
			r := new(dns.Msg)
			r.SetQuestion(tt.query, dns.TypeA)
			chain.ServeDNS(&mockResponseWriter{}, r)

			// Verify each handler was called exactly as expected
			for _, h := range tt.addHandlers {
				wasCalled := handlerCalls[h.pattern]
				assert.Equal(t, h.shouldMatch, wasCalled,
					"Handler for pattern %q was %s when it should%s have been",
					h.pattern,
					map[bool]string{true: "called", false: "not called"}[wasCalled],
					map[bool]string{true: "", false: " not"}[wasCalled == h.shouldMatch])
			}

			// Verify total number of calls
			assert.Equal(t, tt.expectedCalls, len(handlerCalls),
				"Wrong number of total handler calls")
		})
	}
}
@@ -102,3 +102,17 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD

 	return config
 }
+
+type noopHostConfigurator struct{}
+
+func (n noopHostConfigurator) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error {
+	return nil
+}
+
+func (n noopHostConfigurator) restoreHostDNS() error {
+	return nil
+}
+
+func (n noopHostConfigurator) supportCustomPort() bool {
+	return true
+}
@@ -17,12 +17,24 @@ type localResolver struct {
 	records sync.Map
 }

+func (d *localResolver) MatchSubdomains() bool {
+	return true
+}
+
 func (d *localResolver) stop() {
 }

+// String returns a string representation of the local resolver
+func (d *localResolver) String() string {
+	return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap))
+}
+
 // ServeDNS handles a DNS request
 func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
-	log.Tracef("received question: %#v", r.Question[0])
+	if len(r.Question) > 0 {
+		log.Tracef("received question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
+	}
+
 	replyMessage := &dns.Msg{}
 	replyMessage.SetReply(r)
 	replyMessage.RecursionAvailable = true
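The local resolver above now reports MatchSubdomains() as true, so a registration for a zone also receives queries for names beneath it. A minimal sketch of the same opt-in for a custom handler, assumed to live in the same package; the zoneHandler type is hypothetical and not part of the commit.

	package dns

	import "github.com/miekg/dns"

	// zoneHandler opts into subdomain matching: because it implements
	// SubdomainMatcher, a registration for "corp.example." also covers
	// queries such as "db.corp.example.".
	type zoneHandler struct{}

	func (h *zoneHandler) MatchSubdomains() bool { return true }

	func (h *zoneHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
		resp := new(dns.Msg)
		resp.SetReply(r) // a real handler would resolve the name here
		_ = w.WriteMsg(resp)
	}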
@@ -3,14 +3,30 @@ package dns
 import (
 	"fmt"

+	"github.com/miekg/dns"
+
 	nbdns "github.com/netbirdio/netbird/dns"
 )

 // MockServer is the mock instance of a dns server
 type MockServer struct {
 	InitializeFunc        func() error
 	StopFunc              func()
 	UpdateDNSServerFunc   func(serial uint64, update nbdns.Config) error
+	RegisterHandlerFunc   func([]string, dns.Handler, int)
+	DeregisterHandlerFunc func([]string, int)
+}
+
+func (m *MockServer) RegisterHandler(domains []string, handler dns.Handler, priority int) {
+	if m.RegisterHandlerFunc != nil {
+		m.RegisterHandlerFunc(domains, handler, priority)
+	}
+}
+
+func (m *MockServer) DeregisterHandler(domains []string, priority int) {
+	if m.DeregisterHandlerFunc != nil {
+		m.DeregisterHandlerFunc(domains, priority)
+	}
 }

 // Initialize mock implementation of Initialize from Server interface
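A small illustrative test sketch using the extended MockServer above to observe a handler registration. The test name is hypothetical; it assumes it sits in the same package so that MockServer and PriorityMatchDomain are in scope.

	package dns

	import (
		"testing"

		"github.com/miekg/dns"
	)

	func TestRegisterHandlerIsForwarded(t *testing.T) {
		var gotDomains []string
		var gotPriority int

		srv := &MockServer{
			RegisterHandlerFunc: func(domains []string, _ dns.Handler, priority int) {
				gotDomains = domains
				gotPriority = priority
			},
		}

		srv.RegisterHandler([]string{"example.com"}, dns.HandlerFunc(func(dns.ResponseWriter, *dns.Msg) {}), PriorityMatchDomain)

		if len(gotDomains) != 1 || gotDomains[0] != "example.com" || gotPriority != PriorityMatchDomain {
			t.Fatalf("unexpected registration: %v %d", gotDomains, gotPriority)
		}
	}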
@@ -7,7 +7,6 @@ import (
 	"runtime"
 	"strings"
 	"sync"
-	"time"

 	"github.com/miekg/dns"
 	"github.com/mitchellh/hashstructure/v2"
@@ -31,6 +30,8 @@ type IosDnsManager interface {

 // Server is a dns server interface
 type Server interface {
+	RegisterHandler(domains []string, handler dns.Handler, priority int)
+	DeregisterHandler(domains []string, priority int)
 	Initialize() error
 	Stop()
 	DnsIP() string
@@ -46,15 +47,18 @@ type registeredHandlerMap map[string]handlerWithStop
 type DefaultServer struct {
 	ctx       context.Context
 	ctxCancel context.CancelFunc
+	disableSys bool
 	mux       sync.Mutex
 	service   service
 	dnsMuxMap registeredHandlerMap
+	handlerPriorities map[string]int
 	localResolver *localResolver
 	wgInterface   WGIface
 	hostManager   hostManager
 	updateSerial  uint64
 	previousConfigHash uint64
 	currentConfig      HostDNSConfig
+	handlerChain       *HandlerChain

 	// permanent related properties
 	permanent bool
@@ -75,12 +79,20 @@ type handlerWithStop interface {
 }

 type muxUpdate struct {
 	domain  string
 	handler handlerWithStop
+	priority int
 }

 // NewDefaultServer returns a new dns server
-func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string, statusRecorder *peer.Status, stateManager *statemanager.Manager) (*DefaultServer, error) {
+func NewDefaultServer(
+	ctx context.Context,
+	wgInterface WGIface,
+	customAddress string,
+	statusRecorder *peer.Status,
+	stateManager *statemanager.Manager,
+	disableSys bool,
+) (*DefaultServer, error) {
 	var addrPort *netip.AddrPort
 	if customAddress != "" {
 		parsedAddrPort, err := netip.ParseAddrPort(customAddress)
@@ -97,7 +109,7 @@ func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress st
 		dnsService = newServiceViaListener(wgInterface, addrPort)
 	}

-	return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager), nil
+	return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager, disableSys), nil
 }

 // NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
@@ -108,9 +120,10 @@ func NewDefaultServerPermanentUpstream(
 	config nbdns.Config,
 	listener listener.NetworkChangeListener,
 	statusRecorder *peer.Status,
+	disableSys bool,
 ) *DefaultServer {
 	log.Debugf("host dns address list is: %v", hostsDnsList)
-	ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil)
+	ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys)
 	ds.hostsDNSHolder.set(hostsDnsList)
 	ds.permanent = true
 	ds.addHostRootZone()
@@ -127,19 +140,30 @@ func NewDefaultServerIos(
 	wgInterface WGIface,
 	iosDnsManager IosDnsManager,
 	statusRecorder *peer.Status,
+	disableSys bool,
 ) *DefaultServer {
-	ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil)
+	ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys)
 	ds.iosDnsManager = iosDnsManager
 	return ds
 }

-func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status, stateManager *statemanager.Manager) *DefaultServer {
+func newDefaultServer(
+	ctx context.Context,
+	wgInterface WGIface,
+	dnsService service,
+	statusRecorder *peer.Status,
+	stateManager *statemanager.Manager,
+	disableSys bool,
+) *DefaultServer {
|
||||||
ctx, stop := context.WithCancel(ctx)
|
ctx, stop := context.WithCancel(ctx)
|
||||||
defaultServer := &DefaultServer{
|
defaultServer := &DefaultServer{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
ctxCancel: stop,
|
ctxCancel: stop,
|
||||||
service: dnsService,
|
disableSys: disableSys,
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
service: dnsService,
|
||||||
|
handlerChain: NewHandlerChain(),
|
||||||
|
dnsMuxMap: make(registeredHandlerMap),
|
||||||
|
handlerPriorities: make(map[string]int),
|
||||||
localResolver: &localResolver{
|
localResolver: &localResolver{
|
||||||
registeredMap: make(registrationMap),
|
registeredMap: make(registrationMap),
|
||||||
},
|
},
|
||||||
@ -152,6 +176,51 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
|
|||||||
return defaultServer
|
return defaultServer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler, priority int) {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
s.registerHandler(domains, handler, priority)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
|
||||||
|
log.Debugf("registering handler %s with priority %d", handler, priority)
|
||||||
|
|
||||||
|
for _, domain := range domains {
|
||||||
|
if domain == "" {
|
||||||
|
log.Warn("skipping empty domain")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.handlerChain.AddHandler(domain, handler, priority, nil)
|
||||||
|
s.handlerPriorities[domain] = priority
|
||||||
|
s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) DeregisterHandler(domains []string, priority int) {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
s.deregisterHandler(domains, priority)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
||||||
|
log.Debugf("deregistering handler %v with priority %d", domains, priority)
|
||||||
|
|
||||||
|
for _, domain := range domains {
|
||||||
|
s.handlerChain.RemoveHandler(domain, priority)
|
||||||
|
|
||||||
|
// Only deregister from service if no handlers remain
|
||||||
|
if !s.handlerChain.HasHandlers(domain) {
|
||||||
|
if domain == "" {
|
||||||
|
log.Warn("skipping empty domain")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.service.DeregisterMux(nbdns.NormalizeZone(domain))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
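registerHandler and deregisterHandler above only touch the underlying service mux when a domain gains its first handler or loses its last one. A stripped-down sketch of that bookkeeping, with a plain map standing in for the package's HandlerChain:

package main

import "fmt"

// chain tracks handlers per domain keyed by priority; the mux entry for a domain
// should only be released once no priority is left for it.
type chain struct {
	handlers map[string]map[int]string // domain -> priority -> handler name
}

func newChain() *chain { return &chain{handlers: map[string]map[int]string{}} }

func (c *chain) add(domain string, priority int, name string) {
	if c.handlers[domain] == nil {
		c.handlers[domain] = map[int]string{}
	}
	c.handlers[domain][priority] = name
}

// remove reports whether the domain is now empty, i.e. the mux entry can go.
func (c *chain) remove(domain string, priority int) bool {
	delete(c.handlers[domain], priority)
	return len(c.handlers[domain]) == 0
}

func main() {
	c := newChain()
	c.add("example.com.", 100, "upstream")
	c.add("example.com.", 200, "route")

	fmt.Println(c.remove("example.com.", 200)) // false: upstream still serves the zone
	fmt.Println(c.remove("example.com.", 100)) // true: now the mux entry can be dropped
}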
// Initialize instantiate host manager and the dns service
|
// Initialize instantiate host manager and the dns service
|
||||||
func (s *DefaultServer) Initialize() (err error) {
|
func (s *DefaultServer) Initialize() (err error) {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
@ -169,6 +238,13 @@ func (s *DefaultServer) Initialize() (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.stateManager.RegisterState(&ShutdownState{})
|
s.stateManager.RegisterState(&ShutdownState{})
|
||||||
|
|
||||||
|
if s.disableSys {
|
||||||
|
log.Info("system DNS is disabled, not setting up host manager")
|
||||||
|
s.hostManager = &noopHostConfigurator{}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
s.hostManager, err = s.initialize()
|
s.hostManager, err = s.initialize()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("initialize: %w", err)
|
return fmt.Errorf("initialize: %w", err)
|
||||||
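When disableSys is set, Initialize above plugs in a noopHostConfigurator so the rest of the server keeps working without touching the OS resolver. Its definition is not part of this diff; the sketch below is only a guess at the shape of such a no-op configurator, with an invented interface for illustration:

package main

import "fmt"

// hostConfig is a stand-in for HostDNSConfig; the interface below is hypothetical
// and only illustrates the "do nothing, always succeed" idea.
type hostConfig struct{ RouteAll bool }

type hostConfigurator interface {
	applyDNSConfig(cfg hostConfig) error
	restoreHostDNS() error
}

type noopHostConfigurator struct{}

func (noopHostConfigurator) applyDNSConfig(hostConfig) error { return nil } // never touches the OS resolver
func (noopHostConfigurator) restoreHostDNS() error           { return nil }

func main() {
	var hm hostConfigurator = noopHostConfigurator{}
	fmt.Println(hm.applyDNSConfig(hostConfig{RouteAll: true})) // <nil>
}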
@ -217,47 +293,47 @@ func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
|
|||||||
|
|
||||||
// UpdateDNSServer processes an update received from the management service
|
// UpdateDNSServer processes an update received from the management service
|
||||||
func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
||||||
select {
|
if s.ctx.Err() != nil {
|
||||||
case <-s.ctx.Done():
|
|
||||||
log.Infof("not updating DNS server as context is closed")
|
log.Infof("not updating DNS server as context is closed")
|
||||||
return s.ctx.Err()
|
return s.ctx.Err()
|
||||||
default:
|
}
|
||||||
if serial < s.updateSerial {
|
|
||||||
return fmt.Errorf("not applying dns update, error: "+
|
|
||||||
"network update is %d behind the last applied update", s.updateSerial-serial)
|
|
||||||
}
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
if s.hostManager == nil {
|
if serial < s.updateSerial {
|
||||||
return fmt.Errorf("dns service is not initialized yet")
|
return fmt.Errorf("not applying dns update, error: "+
|
||||||
}
|
"network update is %d behind the last applied update", s.updateSerial-serial)
|
||||||
|
}
|
||||||
|
|
||||||
hash, err := hashstructure.Hash(update, hashstructure.FormatV2, &hashstructure.HashOptions{
|
s.mux.Lock()
|
||||||
ZeroNil: true,
|
defer s.mux.Unlock()
|
||||||
IgnoreZeroValue: true,
|
|
||||||
SlicesAsSets: true,
|
|
||||||
UseStringer: true,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("unable to hash the dns configuration update, got error: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.previousConfigHash == hash {
|
if s.hostManager == nil {
|
||||||
log.Debugf("not applying the dns configuration update as there is nothing new")
|
return fmt.Errorf("dns service is not initialized yet")
|
||||||
s.updateSerial = serial
|
}
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.applyConfiguration(update); err != nil {
|
hash, err := hashstructure.Hash(update, hashstructure.FormatV2, &hashstructure.HashOptions{
|
||||||
return fmt.Errorf("apply configuration: %w", err)
|
ZeroNil: true,
|
||||||
}
|
IgnoreZeroValue: true,
|
||||||
|
SlicesAsSets: true,
|
||||||
|
UseStringer: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("unable to hash the dns configuration update, got error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.previousConfigHash == hash {
|
||||||
|
log.Debugf("not applying the dns configuration update as there is nothing new")
|
||||||
s.updateSerial = serial
|
s.updateSerial = serial
|
||||||
s.previousConfigHash = hash
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := s.applyConfiguration(update); err != nil {
|
||||||
|
return fmt.Errorf("apply configuration: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.updateSerial = serial
|
||||||
|
s.previousConfigHash = hash
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) SearchDomains() []string {
|
func (s *DefaultServer) SearchDomains() []string {
|
||||||
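The UpdateDNSServer change above replaces a select on s.ctx.Done() with an early ctx.Err() check. A tiny sketch showing that the guard clause behaves the same as the non-blocking select while keeping the function body flat:

package main

import (
	"context"
	"fmt"
)

func update(ctx context.Context, serial, last uint64) error {
	// equivalent to: select { case <-ctx.Done(): return ctx.Err(); default: }
	if ctx.Err() != nil {
		return ctx.Err()
	}
	if serial < last {
		return fmt.Errorf("update %d is behind the last applied update %d", serial, last)
	}
	return nil
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	fmt.Println(update(ctx, 2, 1)) // <nil>
	cancel()
	fmt.Println(update(ctx, 3, 1)) // context canceled
}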
@ -323,12 +399,12 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
log.Error(err)
|
log.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// persist dns state right away
|
go func() {
|
||||||
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
|
// persist dns state right away
|
||||||
defer cancel()
|
if err := s.stateManager.PersistState(s.ctx); err != nil {
|
||||||
if err := s.stateManager.PersistState(ctx); err != nil {
|
log.Errorf("Failed to persist dns state: %v", err)
|
||||||
log.Errorf("Failed to persist dns state: %v", err)
|
}
|
||||||
}
|
}()
|
||||||
|
|
||||||
if s.searchDomainNotifier != nil {
|
if s.searchDomainNotifier != nil {
|
||||||
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
|
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
|
||||||
@ -344,14 +420,14 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
|
|||||||
localRecords := make(map[string]nbdns.SimpleRecord, 0)
|
localRecords := make(map[string]nbdns.SimpleRecord, 0)
|
||||||
|
|
||||||
for _, customZone := range customZones {
|
for _, customZone := range customZones {
|
||||||
|
|
||||||
if len(customZone.Records) == 0 {
|
if len(customZone.Records) == 0 {
|
||||||
return nil, nil, fmt.Errorf("received an empty list of records")
|
return nil, nil, fmt.Errorf("received an empty list of records")
|
||||||
}
|
}
|
||||||
|
|
||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
muxUpdates = append(muxUpdates, muxUpdate{
|
||||||
domain: customZone.Domain,
|
domain: customZone.Domain,
|
||||||
handler: s.localResolver,
|
handler: s.localResolver,
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
})
|
})
|
||||||
|
|
||||||
for _, record := range customZone.Records {
|
for _, record := range customZone.Records {
|
||||||
@ -413,8 +489,9 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
|||||||
|
|
||||||
if nsGroup.Primary {
|
if nsGroup.Primary {
|
||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
muxUpdates = append(muxUpdates, muxUpdate{
|
||||||
domain: nbdns.RootZone,
|
domain: nbdns.RootZone,
|
||||||
handler: handler,
|
handler: handler,
|
||||||
|
priority: PriorityDefault,
|
||||||
})
|
})
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -430,8 +507,9 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
|||||||
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
|
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
|
||||||
}
|
}
|
||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
muxUpdates = append(muxUpdates, muxUpdate{
|
||||||
domain: domain,
|
domain: domain,
|
||||||
handler: handler,
|
handler: handler,
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
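Both handler-update builders above now attach a priority to each muxUpdate: the root zone of a primary nameserver group gets the default priority, named domains get the match-domain priority. A compact sketch of that selection logic with stand-in constants:

package main

import "fmt"

// Stand-in constants; the real values come from the package.
const (
	priorityDefault     = 0
	priorityMatchDomain = 100
)

type muxUpdate struct {
	domain   string
	priority int
}

func buildUpdates(primary bool, domains []string) []muxUpdate {
	if primary {
		return []muxUpdate{{domain: ".", priority: priorityDefault}} // root zone catch-all
	}
	updates := make([]muxUpdate, 0, len(domains))
	for _, d := range domains {
		updates = append(updates, muxUpdate{domain: d, priority: priorityMatchDomain})
	}
	return updates
}

func main() {
	fmt.Println(buildUpdates(true, nil))
	fmt.Println(buildUpdates(false, []string{"corp.example.com."}))
}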
@ -441,12 +519,16 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
|||||||
|
|
||||||
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
||||||
muxUpdateMap := make(registeredHandlerMap)
|
muxUpdateMap := make(registeredHandlerMap)
|
||||||
|
handlersByPriority := make(map[string]int)
|
||||||
|
|
||||||
var isContainRootUpdate bool
|
var isContainRootUpdate bool
|
||||||
|
|
||||||
|
// First register new handlers
|
||||||
for _, update := range muxUpdates {
|
for _, update := range muxUpdates {
|
||||||
s.service.RegisterMux(update.domain, update.handler)
|
s.registerHandler([]string{update.domain}, update.handler, update.priority)
|
||||||
muxUpdateMap[update.domain] = update.handler
|
muxUpdateMap[update.domain] = update.handler
|
||||||
|
handlersByPriority[update.domain] = update.priority
|
||||||
|
|
||||||
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
|
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
|
||||||
existingHandler.stop()
|
existingHandler.stop()
|
||||||
}
|
}
|
||||||
@ -456,6 +538,7 @@ func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Then deregister old handlers not in the update
|
||||||
for key, existingHandler := range s.dnsMuxMap {
|
for key, existingHandler := range s.dnsMuxMap {
|
||||||
_, found := muxUpdateMap[key]
|
_, found := muxUpdateMap[key]
|
||||||
if !found {
|
if !found {
|
||||||
@ -464,12 +547,16 @@ func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
|||||||
existingHandler.stop()
|
existingHandler.stop()
|
||||||
} else {
|
} else {
|
||||||
existingHandler.stop()
|
existingHandler.stop()
|
||||||
s.service.DeregisterMux(key)
|
// Deregister with the priority that was used to register
|
||||||
|
if oldPriority, ok := s.handlerPriorities[key]; ok {
|
||||||
|
s.deregisterHandler([]string{key}, oldPriority)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s.dnsMuxMap = muxUpdateMap
|
s.dnsMuxMap = muxUpdateMap
|
||||||
|
s.handlerPriorities = handlersByPriority
|
||||||
}
|
}
|
||||||
|
|
||||||
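updateMux above registers the new handlers first and only afterwards deregisters stale ones, using the priority recorded when they were registered, so an entry at another priority is never torn down by mistake. A sketch of that two-phase diff over plain maps:

package main

import "fmt"

// applyUpdate returns the new domain->priority map after registering everything in
// the update and deregistering domains that disappeared, at their old priority.
func applyUpdate(current, update map[string]int) map[string]int {
	// phase 1: register everything from the update
	for domain, prio := range update {
		fmt.Printf("register %s at %d\n", domain, prio)
	}
	// phase 2: deregister domains no longer present, with the priority they had
	for domain, oldPrio := range current {
		if _, stillWanted := update[domain]; !stillWanted {
			fmt.Printf("deregister %s at %d\n", domain, oldPrio)
		}
	}
	return update
}

func main() {
	current := map[string]int{"old.example.com.": 100}
	update := map[string]int{"new.example.com.": 100}
	_ = applyUpdate(current, update)
}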
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
|
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
|
||||||
@ -518,13 +605,13 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
if nsGroup.Primary {
|
if nsGroup.Primary {
|
||||||
removeIndex[nbdns.RootZone] = -1
|
removeIndex[nbdns.RootZone] = -1
|
||||||
s.currentConfig.RouteAll = false
|
s.currentConfig.RouteAll = false
|
||||||
s.service.DeregisterMux(nbdns.RootZone)
|
s.deregisterHandler([]string{nbdns.RootZone}, PriorityDefault)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, item := range s.currentConfig.Domains {
|
for i, item := range s.currentConfig.Domains {
|
||||||
if _, found := removeIndex[item.Domain]; found {
|
if _, found := removeIndex[item.Domain]; found {
|
||||||
s.currentConfig.Domains[i].Disabled = true
|
s.currentConfig.Domains[i].Disabled = true
|
||||||
s.service.DeregisterMux(item.Domain)
|
s.deregisterHandler([]string{item.Domain}, PriorityMatchDomain)
|
||||||
removeIndex[item.Domain] = i
|
removeIndex[item.Domain] = i
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -533,12 +620,11 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
|
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// persist dns state right away
|
go func() {
|
||||||
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
|
if err := s.stateManager.PersistState(s.ctx); err != nil {
|
||||||
defer cancel()
|
l.Errorf("Failed to persist dns state: %v", err)
|
||||||
if err := s.stateManager.PersistState(ctx); err != nil {
|
}
|
||||||
l.Errorf("Failed to persist dns state: %v", err)
|
}()
|
||||||
}
|
|
||||||
|
|
||||||
if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 {
|
if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 {
|
||||||
s.addHostRootZone()
|
s.addHostRootZone()
|
||||||
@ -556,7 +642,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
s.currentConfig.Domains[i].Disabled = false
|
s.currentConfig.Domains[i].Disabled = false
|
||||||
s.service.RegisterMux(domain, handler)
|
s.registerHandler([]string{domain}, handler, PriorityMatchDomain)
|
||||||
}
|
}
|
||||||
|
|
||||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||||
@ -564,10 +650,13 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
|
|
||||||
if nsGroup.Primary {
|
if nsGroup.Primary {
|
||||||
s.currentConfig.RouteAll = true
|
s.currentConfig.RouteAll = true
|
||||||
s.service.RegisterMux(nbdns.RootZone, handler)
|
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault)
|
||||||
}
|
}
|
||||||
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
|
|
||||||
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
if s.hostManager != nil {
|
||||||
|
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
|
||||||
|
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s.updateNSState(nsGroup, nil, true)
|
s.updateNSState(nsGroup, nil, true)
|
||||||
@ -595,7 +684,8 @@ func (s *DefaultServer) addHostRootZone() {
|
|||||||
}
|
}
|
||||||
handler.deactivate = func(error) {}
|
handler.deactivate = func(error) {}
|
||||||
handler.reactivate = func() {}
|
handler.reactivate = func() {}
|
||||||
s.service.RegisterMux(nbdns.RootZone, handler)
|
|
||||||
|
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
|
func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
|
||||||
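The upstreamCallbacks changes route deactivation and reactivation through registerHandler/deregisterHandler with explicit priorities. A self-contained sketch of that callback pair; the registrar type is a trivial stand-in for the DefaultServer methods:

package main

import "fmt"

type registrar struct{}

func (registrar) register(domains []string, priority int)   { fmt.Println("register", domains, priority) }
func (registrar) deregister(domains []string, priority int) { fmt.Println("deregister", domains, priority) }

// upstreamCallbacks returns a pair of closures: deactivate pulls the upstream's
// registrations when probing fails, reactivate puts them back on recovery.
func upstreamCallbacks(r registrar, domains []string, priority int) (deactivate func(error), reactivate func()) {
	deactivate = func(err error) {
		fmt.Println("upstream unreachable:", err)
		r.deregister(domains, priority)
	}
	reactivate = func() {
		r.register(domains, priority)
	}
	return deactivate, reactivate
}

func main() {
	deactivate, reactivate := upstreamCallbacks(registrar{}, []string{"example.com."}, 100)
	deactivate(fmt.Errorf("timeout"))
	reactivate()
}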
|
@ -11,7 +11,9 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
@ -292,7 +294,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
t.Log(err)
|
t.Log(err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil)
|
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -401,7 +403,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil)
|
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create DNS server: %v", err)
|
t.Errorf("create DNS server: %v", err)
|
||||||
return
|
return
|
||||||
@ -496,7 +498,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
|||||||
|
|
||||||
for _, testCase := range testCases {
|
for _, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}, nil)
|
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}, nil, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("%v", err)
|
t.Fatalf("%v", err)
|
||||||
}
|
}
|
||||||
@ -512,7 +514,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
|||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsServer.service.RegisterMux("netbird.cloud", dnsServer.localResolver)
|
dnsServer.registerHandler([]string{"netbird.cloud"}, dnsServer.localResolver, 1)
|
||||||
|
|
||||||
resolver := &net.Resolver{
|
resolver := &net.Resolver{
|
||||||
PreferGo: true,
|
PreferGo: true,
|
||||||
@ -560,7 +562,9 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
|||||||
localResolver: &localResolver{
|
localResolver: &localResolver{
|
||||||
registeredMap: make(registrationMap),
|
registeredMap: make(registrationMap),
|
||||||
},
|
},
|
||||||
hostManager: hostManager,
|
handlerChain: NewHandlerChain(),
|
||||||
|
handlerPriorities: make(map[string]int),
|
||||||
|
hostManager: hostManager,
|
||||||
currentConfig: HostDNSConfig{
|
currentConfig: HostDNSConfig{
|
||||||
Domains: []DomainConfig{
|
Domains: []DomainConfig{
|
||||||
{false, "domain0", false},
|
{false, "domain0", false},
|
||||||
@ -629,7 +633,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
|||||||
|
|
||||||
var dnsList []string
|
var dnsList []string
|
||||||
dnsConfig := nbdns.Config{}
|
dnsConfig := nbdns.Config{}
|
||||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, &peer.Status{})
|
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, &peer.Status{}, false)
|
||||||
err = dnsServer.Initialize()
|
err = dnsServer.Initialize()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to initialize DNS server: %v", err)
|
t.Errorf("failed to initialize DNS server: %v", err)
|
||||||
@ -653,7 +657,7 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer wgIFace.Close()
|
defer wgIFace.Close()
|
||||||
dnsConfig := nbdns.Config{}
|
dnsConfig := nbdns.Config{}
|
||||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{})
|
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{}, false)
|
||||||
err = dnsServer.Initialize()
|
err = dnsServer.Initialize()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to initialize DNS server: %v", err)
|
t.Errorf("failed to initialize DNS server: %v", err)
|
||||||
@ -745,7 +749,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer wgIFace.Close()
|
defer wgIFace.Close()
|
||||||
dnsConfig := nbdns.Config{}
|
dnsConfig := nbdns.Config{}
|
||||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{})
|
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{}, false)
|
||||||
err = dnsServer.Initialize()
|
err = dnsServer.Initialize()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to initialize DNS server: %v", err)
|
t.Errorf("failed to initialize DNS server: %v", err)
|
||||||
@ -782,7 +786,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
|||||||
Port: 53,
|
Port: 53,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Domains: []string{"customdomain.com"},
|
Domains: []string{"google.com"},
|
||||||
Primary: false,
|
Primary: false,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -804,7 +808,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
|||||||
if ips[0] != zoneRecords[0].RData {
|
if ips[0] != zoneRecords[0].RData {
|
||||||
t.Fatalf("invalid zone record: %v", err)
|
t.Fatalf("invalid zone record: %v", err)
|
||||||
}
|
}
|
||||||
_, err = resolver.LookupHost(context.Background(), "customdomain.com")
|
_, err = resolver.LookupHost(context.Background(), "google.com")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to resolve: %s", err)
|
t.Errorf("failed to resolve: %s", err)
|
||||||
}
|
}
|
||||||
@ -872,3 +876,86 @@ func newDnsResolver(ip string, port int) *net.Resolver {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MockHandler implements dns.Handler interface for testing
|
||||||
|
type MockHandler struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
m.Called(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MockSubdomainHandler struct {
|
||||||
|
MockHandler
|
||||||
|
Subdomains bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSubdomainHandler) MatchSubdomains() bool {
|
||||||
|
return m.Subdomains
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlerChain_DomainPriorities(t *testing.T) {
|
||||||
|
chain := NewHandlerChain()
|
||||||
|
|
||||||
|
dnsRouteHandler := &MockHandler{}
|
||||||
|
upstreamHandler := &MockSubdomainHandler{
|
||||||
|
Subdomains: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute, nil)
|
||||||
|
chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain, nil)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
query string
|
||||||
|
expectedHandler dns.Handler
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "exact domain with dns route handler",
|
||||||
|
query: "example.com.",
|
||||||
|
expectedHandler: dnsRouteHandler,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "subdomain should use upstream handler",
|
||||||
|
query: "sub.example.com.",
|
||||||
|
expectedHandler: upstreamHandler,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "deep subdomain should use upstream handler",
|
||||||
|
query: "deep.sub.example.com.",
|
||||||
|
expectedHandler: upstreamHandler,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
r := new(dns.Msg)
|
||||||
|
r.SetQuestion(tc.query, dns.TypeA)
|
||||||
|
w := &ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
|
|
||||||
|
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
|
||||||
|
mh.On("ServeDNS", mock.Anything, r).Once()
|
||||||
|
} else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok {
|
||||||
|
mh.On("ServeDNS", mock.Anything, r).Once()
|
||||||
|
}
|
||||||
|
|
||||||
|
chain.ServeDNS(w, r)
|
||||||
|
|
||||||
|
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
|
||||||
|
mh.AssertExpectations(t)
|
||||||
|
} else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok {
|
||||||
|
mh.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset mocks
|
||||||
|
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
|
||||||
|
mh.ExpectedCalls = nil
|
||||||
|
mh.Calls = nil
|
||||||
|
} else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok {
|
||||||
|
mh.ExpectedCalls = nil
|
||||||
|
mh.Calls = nil
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -105,6 +105,7 @@ func (s *serviceViaListener) Stop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
|
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
|
||||||
|
log.Debugf("registering dns handler for pattern: %s", pattern)
|
||||||
s.dnsMux.Handle(pattern, handler)
|
s.dnsMux.Handle(pattern, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -66,6 +66,15 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) *
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// String returns a string representation of the upstream resolver
|
||||||
|
func (u *upstreamResolverBase) String() string {
|
||||||
|
return fmt.Sprintf("upstream %v", u.upstreamServers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *upstreamResolverBase) MatchSubdomains() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverBase) stop() {
|
func (u *upstreamResolverBase) stop() {
|
||||||
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
|
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
|
||||||
u.cancel()
|
u.cancel()
|
||||||
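MatchSubdomains above is an optional capability layered on top of dns.Handler: the upstream resolver opts in, while plain handlers never have to know about it. A sketch of detecting it with a type assertion:

package main

import (
	"fmt"

	"github.com/miekg/dns"
)

// subdomainMatcher is the optional interface a chain can probe for.
type subdomainMatcher interface {
	MatchSubdomains() bool
}

func matchesSubdomains(h dns.Handler) bool {
	if m, ok := h.(subdomainMatcher); ok {
		return m.MatchSubdomains()
	}
	return false
}

type upstream struct{}

func (upstream) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {}
func (upstream) MatchSubdomains() bool                     { return true }

func main() {
	plain := dns.HandlerFunc(func(dns.ResponseWriter, *dns.Msg) {})
	fmt.Println(matchesSubdomains(plain))      // false
	fmt.Println(matchesSubdomains(upstream{})) // true
}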
|
157
client/internal/dnsfwd/forwarder.go
Normal file
@ -0,0 +1,157 @@
|
|||||||
|
package dnsfwd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
const errResolveFailed = "failed to resolve query for domain=%s: %v"
|
||||||
|
|
||||||
|
type DNSForwarder struct {
|
||||||
|
listenAddress string
|
||||||
|
ttl uint32
|
||||||
|
domains []string
|
||||||
|
|
||||||
|
dnsServer *dns.Server
|
||||||
|
mux *dns.ServeMux
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDNSForwarder(listenAddress string, ttl uint32) *DNSForwarder {
|
||||||
|
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
|
||||||
|
return &DNSForwarder{
|
||||||
|
listenAddress: listenAddress,
|
||||||
|
ttl: ttl,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *DNSForwarder) Listen(domains []string) error {
|
||||||
|
log.Infof("listen DNS forwarder on address=%s", f.listenAddress)
|
||||||
|
mux := dns.NewServeMux()
|
||||||
|
|
||||||
|
dnsServer := &dns.Server{
|
||||||
|
Addr: f.listenAddress,
|
||||||
|
Net: "udp",
|
||||||
|
Handler: mux,
|
||||||
|
}
|
||||||
|
f.dnsServer = dnsServer
|
||||||
|
f.mux = mux
|
||||||
|
|
||||||
|
f.UpdateDomains(domains)
|
||||||
|
|
||||||
|
return dnsServer.ListenAndServe()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *DNSForwarder) UpdateDomains(domains []string) {
|
||||||
|
log.Debugf("Updating domains from %v to %v", f.domains, domains)
|
||||||
|
|
||||||
|
for _, d := range f.domains {
|
||||||
|
f.mux.HandleRemove(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
newDomains := filterDomains(domains)
|
||||||
|
for _, d := range newDomains {
|
||||||
|
f.mux.HandleFunc(d, f.handleDNSQuery)
|
||||||
|
}
|
||||||
|
f.domains = newDomains
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *DNSForwarder) Close(ctx context.Context) error {
|
||||||
|
if f.dnsServer == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return f.dnsServer.ShutdownContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
||||||
|
if len(query.Question) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v",
|
||||||
|
query.Question[0].Name, query.Question[0].Qtype, query.Question[0].Qclass)
|
||||||
|
|
||||||
|
question := query.Question[0]
|
||||||
|
domain := question.Name
|
||||||
|
|
||||||
|
resp := query.SetReply(query)
|
||||||
|
|
||||||
|
ips, err := net.LookupIP(domain)
|
||||||
|
if err != nil {
|
||||||
|
var dnsErr *net.DNSError
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case errors.As(err, &dnsErr):
|
||||||
|
resp.Rcode = dns.RcodeServerFailure
|
||||||
|
if dnsErr.IsNotFound {
|
||||||
|
// Pass through NXDOMAIN
|
||||||
|
resp.Rcode = dns.RcodeNameError
|
||||||
|
}
|
||||||
|
|
||||||
|
if dnsErr.Server != "" {
|
||||||
|
log.Warnf("failed to resolve query for domain=%s server=%s: %v", domain, dnsErr.Server, err)
|
||||||
|
} else {
|
||||||
|
log.Warnf(errResolveFailed, domain, err)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
resp.Rcode = dns.RcodeServerFailure
|
||||||
|
log.Warnf(errResolveFailed, domain, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
log.Errorf("failed to write failure DNS response: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ip := range ips {
|
||||||
|
var respRecord dns.RR
|
||||||
|
if ip.To4() == nil {
|
||||||
|
log.Tracef("resolved domain=%s to IPv6=%s", domain, ip)
|
||||||
|
rr := dns.AAAA{
|
||||||
|
AAAA: ip,
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: domain,
|
||||||
|
Rrtype: dns.TypeAAAA,
|
||||||
|
Class: dns.ClassINET,
|
||||||
|
Ttl: f.ttl,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
respRecord = &rr
|
||||||
|
} else {
|
||||||
|
log.Tracef("resolved domain=%s to IPv4=%s", domain, ip)
|
||||||
|
rr := dns.A{
|
||||||
|
A: ip,
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: domain,
|
||||||
|
Rrtype: dns.TypeA,
|
||||||
|
Class: dns.ClassINET,
|
||||||
|
Ttl: f.ttl,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
respRecord = &rr
|
||||||
|
}
|
||||||
|
resp.Answer = append(resp.Answer, respRecord)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
log.Errorf("failed to write DNS response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// filterDomains returns a list of normalized domains
|
||||||
|
func filterDomains(domains []string) []string {
|
||||||
|
newDomains := make([]string, 0, len(domains))
|
||||||
|
for _, d := range domains {
|
||||||
|
if d == "" {
|
||||||
|
log.Warn("empty domain in DNS forwarder")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newDomains = append(newDomains, nbdns.NormalizeZone(d))
|
||||||
|
}
|
||||||
|
return newDomains
|
||||||
|
}
|
111
client/internal/dnsfwd/manager.go
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
package dnsfwd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also
|
||||||
|
ListenPort = 5353
|
||||||
|
dnsTTL = 60 //seconds
|
||||||
|
)
|
||||||
|
|
||||||
|
type Manager struct {
|
||||||
|
firewall firewall.Manager
|
||||||
|
|
||||||
|
fwRules []firewall.Rule
|
||||||
|
dnsForwarder *DNSForwarder
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManager(fw firewall.Manager) *Manager {
|
||||||
|
return &Manager{
|
||||||
|
firewall: fw,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Start(domains []string) error {
|
||||||
|
log.Infof("starting DNS forwarder")
|
||||||
|
if m.dnsForwarder != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.allowDNSFirewall(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL)
|
||||||
|
go func() {
|
||||||
|
if err := m.dnsForwarder.Listen(domains); err != nil {
|
||||||
|
// todo handle close error if it is exists
|
||||||
|
log.Errorf("failed to start DNS forwarder, err: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) UpdateDomains(domains []string) {
|
||||||
|
if m.dnsForwarder == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
m.dnsForwarder.UpdateDomains(domains)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Stop(ctx context.Context) error {
|
||||||
|
if m.dnsForwarder == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var mErr *multierror.Error
|
||||||
|
if err := m.dropDNSFirewall(); err != nil {
|
||||||
|
mErr = multierror.Append(mErr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.dnsForwarder.Close(ctx); err != nil {
|
||||||
|
mErr = multierror.Append(mErr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.dnsForwarder = nil
|
||||||
|
return nberrors.FormatErrorOrNil(mErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Manager) allowDNSFirewall() error {
|
||||||
|
dport := &firewall.Port{
|
||||||
|
IsRange: false,
|
||||||
|
Values: []int{ListenPort},
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.firewall == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "")
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to add allow DNS router rules, err: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
h.fwRules = dnsRules
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Manager) dropDNSFirewall() error {
|
||||||
|
var mErr *multierror.Error
|
||||||
|
for _, rule := range h.fwRules {
|
||||||
|
if err := h.firewall.DeletePeerRule(rule); err != nil {
|
||||||
|
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
h.fwRules = nil
|
||||||
|
return nberrors.FormatErrorOrNil(mErr)
|
||||||
|
}
|
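The manager above keeps Start idempotent, guards UpdateDomains behind a nil check, and aggregates cleanup errors in Stop. A simplified sketch of that lifecycle; the forwarder type here is a stand-in for dnsfwd.DNSForwarder:

package main

import (
	"context"
	"fmt"

	"github.com/hashicorp/go-multierror"
)

type forwarder struct{}

func (f *forwarder) Close(context.Context) error { return nil }

type manager struct {
	fwd *forwarder
}

// Start is a no-op if the forwarder is already running.
func (m *manager) Start() error {
	if m.fwd != nil {
		return nil
	}
	m.fwd = &forwarder{}
	return nil
}

// Stop collects cleanup errors and clears the forwarder so Start can run again.
func (m *manager) Stop(ctx context.Context) error {
	if m.fwd == nil {
		return nil
	}
	var mErr *multierror.Error
	if err := m.fwd.Close(ctx); err != nil {
		mErr = multierror.Append(mErr, err)
	}
	m.fwd = nil
	return mErr.ErrorOrNil()
}

func main() {
	m := &manager{}
	fmt.Println(m.Start(), m.Start(), m.Stop(context.Background())) // <nil> <nil> <nil>
}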
@ -4,13 +4,13 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"maps"
|
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
"slices"
|
"slices"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@ -20,6 +20,7 @@ import (
|
|||||||
"github.com/pion/stun/v2"
|
"github.com/pion/stun/v2"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall"
|
"github.com/netbirdio/netbird/client/firewall"
|
||||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
@ -28,16 +29,18 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl"
|
"github.com/netbirdio/netbird/client/internal/acl"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||||
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
"github.com/netbirdio/netbird/client/internal/relay"
|
"github.com/netbirdio/netbird/client/internal/relay"
|
||||||
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||||
|
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
@ -61,6 +64,7 @@ import (
|
|||||||
const (
|
const (
|
||||||
PeerConnectionTimeoutMax = 45000 // ms
|
PeerConnectionTimeoutMax = 45000 // ms
|
||||||
PeerConnectionTimeoutMin = 30000 // ms
|
PeerConnectionTimeoutMin = 30000 // ms
|
||||||
|
connInitLimit = 200
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrResetConnection = fmt.Errorf("reset connection")
|
var ErrResetConnection = fmt.Errorf("reset connection")
|
||||||
@ -104,6 +108,11 @@ type EngineConfig struct {
|
|||||||
ServerSSHAllowed bool
|
ServerSSHAllowed bool
|
||||||
|
|
||||||
DNSRouteInterval time.Duration
|
DNSRouteInterval time.Duration
|
||||||
|
|
||||||
|
DisableClientRoutes bool
|
||||||
|
DisableServerRoutes bool
|
||||||
|
DisableDNS bool
|
||||||
|
DisableFirewall bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
||||||
@ -114,7 +123,7 @@ type Engine struct {
|
|||||||
// mgmClient is a Management Service client
|
// mgmClient is a Management Service client
|
||||||
mgmClient mgm.Client
|
mgmClient mgm.Client
|
||||||
// peerConns is a map that holds all the peers that are known to this peer
|
// peerConns is a map that holds all the peers that are known to this peer
|
||||||
peerConns map[string]*peer.Conn
|
peerStore *peerstore.Store
|
||||||
|
|
||||||
beforePeerHook nbnet.AddHookFunc
|
beforePeerHook nbnet.AddHookFunc
|
||||||
afterPeerHook nbnet.RemoveHookFunc
|
afterPeerHook nbnet.RemoveHookFunc
|
||||||
@ -134,10 +143,6 @@ type Engine struct {
|
|||||||
TURNs []*stun.URI
|
TURNs []*stun.URI
|
||||||
stunTurn atomic.Value
|
stunTurn atomic.Value
|
||||||
|
|
||||||
// clientRoutes is the most recent list of clientRoutes received from the Management Service
|
|
||||||
clientRoutes route.HAMap
|
|
||||||
clientRoutesMu sync.RWMutex
|
|
||||||
|
|
||||||
clientCtx context.Context
|
clientCtx context.Context
|
||||||
clientCancel context.CancelFunc
|
clientCancel context.CancelFunc
|
||||||
|
|
||||||
@ -158,9 +163,10 @@ type Engine struct {
|
|||||||
|
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
|
|
||||||
firewall manager.Manager
|
firewall manager.Manager
|
||||||
routeManager routemanager.Manager
|
routeManager routemanager.Manager
|
||||||
acl acl.Manager
|
acl acl.Manager
|
||||||
|
dnsForwardMgr *dnsfwd.Manager
|
||||||
|
|
||||||
dnsServer dns.Server
|
dnsServer dns.Server
|
||||||
|
|
||||||
@ -171,7 +177,12 @@ type Engine struct {
|
|||||||
|
|
||||||
relayManager *relayClient.Manager
|
relayManager *relayClient.Manager
|
||||||
stateManager *statemanager.Manager
|
stateManager *statemanager.Manager
|
||||||
srWatcher *guard.SRWatcher
|
srWatcher *guard.SRWatcher
|
||||||
|
|
||||||
|
// Network map persistence
|
||||||
|
persistNetworkMap bool
|
||||||
|
latestNetworkMap *mgmProto.NetworkMap
|
||||||
|
connSemaphore *semaphoregroup.SemaphoreGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
// Peer is an instance of the Connection Peer
|
// Peer is an instance of the Connection Peer
|
||||||
@ -226,7 +237,7 @@ func NewEngineWithProbes(
|
|||||||
signaler: peer.NewSignaler(signalClient, config.WgPrivateKey),
|
signaler: peer.NewSignaler(signalClient, config.WgPrivateKey),
|
||||||
mgmClient: mgmClient,
|
mgmClient: mgmClient,
|
||||||
relayManager: relayManager,
|
relayManager: relayManager,
|
||||||
peerConns: make(map[string]*peer.Conn),
|
peerStore: peerstore.NewConnStore(),
|
||||||
syncMsgMux: &sync.Mutex{},
|
syncMsgMux: &sync.Mutex{},
|
||||||
config: config,
|
config: config,
|
||||||
mobileDep: mobileDep,
|
mobileDep: mobileDep,
|
||||||
@ -237,6 +248,18 @@ func NewEngineWithProbes(
|
|||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
probes: probes,
|
probes: probes,
|
||||||
checks: checks,
|
checks: checks,
|
||||||
|
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
||||||
|
}
|
||||||
|
if runtime.GOOS == "ios" {
|
||||||
|
if !fileExists(mobileDep.StateFilePath) {
|
||||||
|
err := createFile(mobileDep.StateFilePath)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to create state file: %v", err)
|
||||||
|
// we are not exiting as we can run without the state manager
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
engine.stateManager = statemanager.New(mobileDep.StateFilePath)
|
||||||
}
|
}
|
||||||
if path := statemanager.GetDefaultStatePath(); path != "" {
|
if path := statemanager.GetDefaultStatePath(); path != "" {
|
||||||
engine.stateManager = statemanager.New(path)
|
engine.stateManager = statemanager.New(path)
|
||||||
@ -267,19 +290,26 @@ func (e *Engine) Stop() error {
|
|||||||
e.routeManager.Stop(e.stateManager)
|
e.routeManager.Stop(e.stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if e.dnsForwardMgr != nil {
|
||||||
|
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||||
|
log.Errorf("failed to stop DNS forward: %v", err)
|
||||||
|
}
|
||||||
|
e.dnsForwardMgr = nil
|
||||||
|
}
|
||||||
|
|
||||||
if e.srWatcher != nil {
|
if e.srWatcher != nil {
|
||||||
e.srWatcher.Close()
|
e.srWatcher.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
|
||||||
|
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
|
||||||
|
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
|
||||||
|
|
||||||
err := e.removeAllPeers()
|
err := e.removeAllPeers()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to remove all peers: %s", err)
|
return fmt.Errorf("failed to remove all peers: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
e.clientRoutesMu.Lock()
|
|
||||||
e.clientRoutes = nil
|
|
||||||
e.clientRoutesMu.Unlock()
|
|
||||||
|
|
||||||
if e.cancel != nil {
|
if e.cancel != nil {
|
||||||
e.cancel()
|
e.cancel()
|
||||||
}
|
}
|
||||||
@ -297,7 +327,7 @@ func (e *Engine) Stop() error {
|
|||||||
if err := e.stateManager.Stop(ctx); err != nil {
|
if err := e.stateManager.Stop(ctx); err != nil {
|
||||||
return fmt.Errorf("failed to stop state manager: %w", err)
|
return fmt.Errorf("failed to stop state manager: %w", err)
|
||||||
}
|
}
|
||||||
if err := e.stateManager.PersistState(ctx); err != nil {
|
if err := e.stateManager.PersistState(context.Background()); err != nil {
|
||||||
log.Errorf("failed to persist state: %v", err)
|
log.Errorf("failed to persist state: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -349,8 +379,21 @@ func (e *Engine) Start() error {
|
|||||||
}
|
}
|
||||||
e.dnsServer = dnsServer
|
e.dnsServer = dnsServer
|
||||||
|
|
||||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes)
|
e.routeManager = routemanager.NewManager(routemanager.ManagerConfig{
|
||||||
beforePeerHook, afterPeerHook, err := e.routeManager.Init(e.stateManager)
|
Context: e.ctx,
|
||||||
|
PublicKey: e.config.WgPrivateKey.PublicKey().String(),
|
||||||
|
DNSRouteInterval: e.config.DNSRouteInterval,
|
||||||
|
WGInterface: e.wgInterface,
|
||||||
|
StatusRecorder: e.statusRecorder,
|
||||||
|
RelayManager: e.relayManager,
|
||||||
|
InitialRoutes: initialRoutes,
|
||||||
|
StateManager: e.stateManager,
|
||||||
|
DNSServer: dnsServer,
|
||||||
|
PeerStore: e.peerStore,
|
||||||
|
DisableClientRoutes: e.config.DisableClientRoutes,
|
||||||
|
DisableServerRoutes: e.config.DisableServerRoutes,
|
||||||
|
})
|
||||||
|
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to initialize route manager: %s", err)
|
log.Errorf("Failed to initialize route manager: %s", err)
|
||||||
} else {
|
} else {
|
||||||
@ -367,17 +410,8 @@ func (e *Engine) Start() error {
|
|||||||
return fmt.Errorf("create wg interface: %w", err)
|
return fmt.Errorf("create wg interface: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager)
|
if err := e.createFirewall(); err != nil {
|
||||||
if err != nil {
|
return err
|
||||||
log.Errorf("failed creating firewall manager: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if e.firewall != nil && e.firewall.IsServerRouteSupported() {
|
|
||||||
err = e.routeManager.EnableServerRouter(e.firewall)
|
|
||||||
if err != nil {
|
|
||||||
e.close()
|
|
||||||
return fmt.Errorf("enable server router: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
e.udpMux, err = e.wgInterface.Up()
|
e.udpMux, err = e.wgInterface.Up()
|
||||||
@ -419,6 +453,61 @@ func (e *Engine) Start() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *Engine) createFirewall() error {
|
||||||
|
if e.config.DisableFirewall {
|
||||||
|
log.Infof("firewall is disabled")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager)
|
||||||
|
if err != nil || e.firewall == nil {
|
||||||
|
log.Errorf("failed creating firewall manager: %s", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.initFirewall(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) initFirewall() error {
|
||||||
|
if e.firewall.IsServerRouteSupported() {
|
||||||
|
if err := e.routeManager.EnableServerRouter(e.firewall); err != nil {
|
||||||
|
e.close()
|
||||||
|
return fmt.Errorf("enable server router: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.rpManager == nil || !e.config.RosenpassEnabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rosenpassPort := e.rpManager.GetAddress().Port
|
||||||
|
port := manager.Port{Values: []int{rosenpassPort}}
|
||||||
|
|
||||||
|
// this rule is static and will be torn down on engine down by the firewall manager
|
||||||
|
if _, err := e.firewall.AddPeerFiltering(
|
||||||
|
net.IP{0, 0, 0, 0},
|
||||||
|
manager.ProtocolUDP,
|
||||||
|
nil,
|
||||||
|
&port,
|
||||||
|
manager.RuleDirectionIN,
|
||||||
|
manager.ActionAccept,
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
); err != nil {
|
||||||
|
log.Errorf("failed to allow rosenpass interface traffic: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("rosenpass interface traffic allowed on port %d", rosenpassPort)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// modifyPeers updates peers that have been modified (e.g. IP address has been changed).
|
// modifyPeers updates peers that have been modified (e.g. IP address has been changed).
|
||||||
// It closes the existing connection, removes it from the peerConns map, and creates a new one.
|
// It closes the existing connection, removes it from the peerConns map, and creates a new one.
|
||||||
func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||||
@ -427,8 +516,8 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
|||||||
var modified []*mgmProto.RemotePeerConfig
|
var modified []*mgmProto.RemotePeerConfig
|
||||||
for _, p := range peersUpdate {
|
for _, p := range peersUpdate {
|
||||||
peerPubKey := p.GetWgPubKey()
|
peerPubKey := p.GetWgPubKey()
|
||||||
if peerConn, ok := e.peerConns[peerPubKey]; ok {
|
if allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey); ok {
|
||||||
if peerConn.WgConfig().AllowedIps != strings.Join(p.AllowedIps, ",") {
|
if allowedIPs != strings.Join(p.AllowedIps, ",") {
|
||||||
modified = append(modified, p)
|
modified = append(modified, p)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -459,17 +548,12 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
|||||||
// removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service.
|
// removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service.
|
||||||
// It also removes peers that have been modified (e.g. change of IP address). They will be added again in addPeers method.
|
// It also removes peers that have been modified (e.g. change of IP address). They will be added again in addPeers method.
|
||||||
func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||||
currentPeers := make([]string, 0, len(e.peerConns))
|
|
||||||
for p := range e.peerConns {
|
|
||||||
currentPeers = append(currentPeers, p)
|
|
||||||
}
|
|
||||||
|
|
||||||
newPeers := make([]string, 0, len(peersUpdate))
|
newPeers := make([]string, 0, len(peersUpdate))
|
||||||
for _, p := range peersUpdate {
|
for _, p := range peersUpdate {
|
||||||
newPeers = append(newPeers, p.GetWgPubKey())
|
newPeers = append(newPeers, p.GetWgPubKey())
|
||||||
}
|
}
|
||||||
|
|
||||||
toRemove := util.SliceDiff(currentPeers, newPeers)
|
toRemove := util.SliceDiff(e.peerStore.PeersPubKey(), newPeers)
|
||||||
|
|
||||||
for _, p := range toRemove {
|
for _, p := range toRemove {
|
||||||
err := e.removePeer(p)
|
err := e.removePeer(p)
|
||||||
@ -483,7 +567,7 @@ func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
|||||||
|
|
||||||
func (e *Engine) removeAllPeers() error {
|
func (e *Engine) removeAllPeers() error {
|
||||||
log.Debugf("removing all peer connections")
|
log.Debugf("removing all peer connections")
|
||||||
for p := range e.peerConns {
|
for _, p := range e.peerStore.PeersPubKey() {
|
||||||
err := e.removePeer(p)
|
err := e.removePeer(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -507,9 +591,8 @@ func (e *Engine) removePeer(peerKey string) error {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
conn, exists := e.peerConns[peerKey]
|
conn, exists := e.peerStore.Remove(peerKey)
|
||||||
if exists {
|
if exists {
|
||||||
delete(e.peerConns, peerKey)
|
|
||||||
conn.Close()
|
conn.Close()
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
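The engine changes replace the raw peerConns map with a peerstore.Store whose Remove hands back the connection for closing and whose PeersPubKey feeds the diff in removePeers. The store's implementation is not shown in this hunk; below is a guess at a minimal mutex-guarded version, with Conn standing in for peer.Conn:

package main

import (
	"fmt"
	"sync"
)

type Conn struct{ key string }

// Store is a concurrency-safe map of public key to connection.
type Store struct {
	mu    sync.RWMutex
	peers map[string]*Conn
}

func NewStore() *Store { return &Store{peers: make(map[string]*Conn)} }

func (s *Store) Add(key string, c *Conn) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.peers[key] = c
}

// Remove returns the connection, if any, so the caller can close it outside the lock.
func (s *Store) Remove(key string) (*Conn, bool) {
	s.mu.Lock()
	defer s.mu.Unlock()
	c, ok := s.peers[key]
	delete(s.peers, key)
	return c, ok
}

func (s *Store) PeersPubKey() []string {
	s.mu.RLock()
	defer s.mu.RUnlock()
	keys := make([]string, 0, len(s.peers))
	for k := range s.peers {
		keys = append(keys, k)
	}
	return keys
}

func main() {
	st := NewStore()
	st.Add("pubkey-A", &Conn{key: "pubkey-A"})
	if c, ok := st.Remove("pubkey-A"); ok {
		fmt.Println("closing", c.key)
	}
	fmt.Println(st.PeersPubKey()) // []
}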
@@ -538,6 +621,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
 
 		relayMsg := wCfg.GetRelay()
 		if relayMsg != nil {
+			// when we receive token we expect valid address list too
 			c := &auth.Token{
 				Payload:   relayMsg.GetTokenPayload(),
 				Signature: relayMsg.GetTokenSignature(),
@@ -546,9 +630,16 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
 				log.Errorf("failed to update relay token: %v", err)
 				return fmt.Errorf("update relay token: %w", err)
 			}
 
+			e.relayManager.UpdateServerURLs(relayMsg.Urls)
+
+			// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
+			// We can ignore all errors because the guard will manage the reconnection retries.
+			_ = e.relayManager.Serve()
+		} else {
+			e.relayManager.UpdateServerURLs(nil)
 		}
 
-		// todo update relay address in the relay manager
 		// todo update signal
 	}
 
@@ -556,13 +647,22 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
 			return err
 		}
 
-	if update.GetNetworkMap() != nil {
-		// only apply new changes and ignore old ones
-		err := e.updateNetworkMap(update.GetNetworkMap())
-		if err != nil {
-			return err
-		}
+	nm := update.GetNetworkMap()
+	if nm == nil {
+		return nil
 	}
 
+	// Store network map if persistence is enabled
+	if e.persistNetworkMap {
+		e.latestNetworkMap = nm
+		log.Debugf("network map persisted with serial %d", nm.GetSerial())
+	}
+
+	// only apply new changes and ignore old ones
+	if err := e.updateNetworkMap(nm); err != nil {
+		return err
+	}
+
 	return nil
 }
 
@@ -641,6 +741,10 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
 }
 
 func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
+	if e.wgInterface == nil {
+		return errors.New("wireguard interface is not initialized")
+	}
+
 	if e.wgInterface.Address().String() != conf.Address {
 		oldAddr := e.wgInterface.Address().String()
 		log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address)
@@ -659,12 +763,13 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
 		}
 	}
 
-	e.statusRecorder.UpdateLocalPeerState(peer.LocalPeerState{
-		IP:              e.config.WgAddr,
-		PubKey:          e.config.WgPrivateKey.PublicKey().String(),
-		KernelInterface: device.WireGuardModuleIsLoaded(),
-		FQDN:            conf.GetFqdn(),
-	})
+	state := e.statusRecorder.GetLocalPeerState()
+	state.IP = e.config.WgAddr
+	state.PubKey = e.config.WgPrivateKey.PublicKey().String()
+	state.KernelInterface = device.WireGuardModuleIsLoaded()
+	state.FQDN = conf.GetFqdn()
+
+	e.statusRecorder.UpdateLocalPeerState(state)
 
 	return nil
 }
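Editor's note: the updateConfig change switches from writing a fresh LocalPeerState literal to a read-modify-write of the recorded state. A hedged sketch of that pattern, using only the recorder methods and fields visible in this diff (the helper itself and its package are illustrative):

package example

import "github.com/netbirdio/netbird/client/internal/peer"

// applyAddressUpdate copies the current local peer state, touches only the
// fields that change on a config update, and writes the copy back so
// unrelated fields (for example Routes) are preserved.
func applyAddressUpdate(rec *peer.Status, wgAddr, fqdn string) {
	state := rec.GetLocalPeerState() // returns a clone, per the status.go change below
	state.IP = wgAddr
	state.FQDN = fqdn
	rec.UpdateLocalPeerState(state)
}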
@@ -732,7 +837,6 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error {
 }
 
 func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
-
 	// intentionally leave it before checking serial because for now it can happen that peer IP changed but serial didn't
 	if networkMap.GetPeerConfig() != nil {
 		err := e.updateConfig(networkMap.GetPeerConfig())
@@ -752,20 +856,16 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
 		e.acl.ApplyFiltering(networkMap)
 	}
 
-	protoRoutes := networkMap.GetRoutes()
-	if protoRoutes == nil {
-		protoRoutes = []*mgmProto.Route{}
-	}
+	// DNS forwarder
+	dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
+	dnsRouteDomains := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), networkMap.GetRoutes())
+	e.updateDNSForwarder(dnsRouteFeatureFlag, dnsRouteDomains)
 
-	_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
-	if err != nil {
+	routes := toRoutes(networkMap.GetRoutes())
+	if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil {
 		log.Errorf("failed to update clientRoutes, err: %v", err)
 	}
 
-	e.clientRoutesMu.Lock()
-	e.clientRoutes = clientRoutes
-	e.clientRoutesMu.Unlock()
-
 	log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
 
 	e.updateOfflinePeers(networkMap.GetOfflinePeers())
@@ -813,8 +913,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
 		protoDNSConfig = &mgmProto.DNSConfig{}
 	}
 
-	err = e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig))
-	if err != nil {
+	if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig)); err != nil {
 		log.Errorf("failed to update dns server, err: %v", err)
 	}
 
@@ -827,7 +926,18 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
 	return nil
 }
 
+func toDNSFeatureFlag(networkMap *mgmProto.NetworkMap) bool {
+	if networkMap.PeerConfig != nil {
+		return networkMap.PeerConfig.RoutingPeerDnsResolutionEnabled
+	}
+	return false
+}
+
 func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
+	if protoRoutes == nil {
+		protoRoutes = []*mgmProto.Route{}
+	}
+
 	routes := make([]*route.Route, 0)
 	for _, protoRoute := range protoRoutes {
 		var prefix netip.Prefix
@@ -838,6 +948,7 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
 				continue
 			}
 		}
+
 		convertedRoute := &route.Route{
 			ID:      route.ID(protoRoute.ID),
 			Network: prefix,
@@ -854,6 +965,23 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
 	return routes
 }
 
+func toRouteDomains(myPubKey string, protoRoutes []*mgmProto.Route) []string {
+	if protoRoutes == nil {
+		protoRoutes = []*mgmProto.Route{}
+	}
+
+	var dnsRoutes []string
+	for _, protoRoute := range protoRoutes {
+		if len(protoRoute.Domains) == 0 {
+			continue
+		}
+		if protoRoute.Peer == myPubKey {
+			dnsRoutes = append(dnsRoutes, protoRoute.Domains...)
+		}
+	}
+	return dnsRoutes
+}
+
 func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
 	dnsUpdate := nbdns.Config{
 		ServiceEnable: protoDNSConfig.GetServiceEnable(),
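Editor's note: toRouteDomains above only collects domains from routes published by this peer. An illustrative example of that selection rule, assuming the engine package's existing imports (fmt, mgmProto) and the Route fields used in the diff; the keys and domain names are invented:

func ExampleToRouteDomains() {
	myKey := "pubkey-A" // hypothetical public key of the local peer
	routes := []*mgmProto.Route{
		{Peer: "pubkey-A", Domains: []string{"internal.example.com"}},
		{Peer: "pubkey-B", Domains: []string{"other.example.com"}}, // published by another peer, ignored
		{Peer: "pubkey-A"},                                         // no domains, skipped
	}
	fmt.Println(toRouteDomains(myKey, routes))
	// Output: [internal.example.com]
}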
@ -928,12 +1056,16 @@ func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
|||||||
func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
|
func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
|
||||||
peerKey := peerConfig.GetWgPubKey()
|
peerKey := peerConfig.GetWgPubKey()
|
||||||
peerIPs := peerConfig.GetAllowedIps()
|
peerIPs := peerConfig.GetAllowedIps()
|
||||||
if _, ok := e.peerConns[peerKey]; !ok {
|
if _, ok := e.peerStore.PeerConn(peerKey); !ok {
|
||||||
conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ","))
|
conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ","))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create peer connection: %w", err)
|
return fmt.Errorf("create peer connection: %w", err)
|
||||||
}
|
}
|
||||||
e.peerConns[peerKey] = conn
|
|
||||||
|
if ok := e.peerStore.AddPeerConn(peerKey, conn); !ok {
|
||||||
|
conn.Close()
|
||||||
|
return fmt.Errorf("peer already exists: %s", peerKey)
|
||||||
|
}
|
||||||
|
|
||||||
if e.beforePeerHook != nil && e.afterPeerHook != nil {
|
if e.beforePeerHook != nil && e.afterPeerHook != nil {
|
||||||
conn.AddBeforeAddPeerHook(e.beforePeerHook)
|
conn.AddBeforeAddPeerHook(e.beforePeerHook)
|
||||||
@ -1001,7 +1133,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher)
|
peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher, e.connSemaphore)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -1022,8 +1154,8 @@ func (e *Engine) receiveSignalEvents() {
|
|||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
conn := e.peerConns[msg.Key]
|
conn, ok := e.peerStore.PeerConn(msg.Key)
|
||||||
if conn == nil {
|
if !ok {
|
||||||
return fmt.Errorf("wrongly addressed message %s", msg.Key)
|
return fmt.Errorf("wrongly addressed message %s", msg.Key)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1081,7 +1213,7 @@ func (e *Engine) receiveSignalEvents() {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
go conn.OnRemoteCandidate(candidate, e.GetClientRoutes())
|
go conn.OnRemoteCandidate(candidate, e.routeManager.GetClientRoutes())
|
||||||
case sProto.Body_MODE:
|
case sProto.Body_MODE:
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1239,6 +1371,7 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
|
|||||||
if e.dnsServer != nil {
|
if e.dnsServer != nil {
|
||||||
return nil, e.dnsServer, nil
|
return nil, e.dnsServer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
switch runtime.GOOS {
|
switch runtime.GOOS {
|
||||||
case "android":
|
case "android":
|
||||||
routes, dnsConfig, err := e.readInitialSettings()
|
routes, dnsConfig, err := e.readInitialSettings()
|
||||||
@ -1252,14 +1385,17 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
|
|||||||
*dnsConfig,
|
*dnsConfig,
|
||||||
e.mobileDep.NetworkChangeListener,
|
e.mobileDep.NetworkChangeListener,
|
||||||
e.statusRecorder,
|
e.statusRecorder,
|
||||||
|
e.config.DisableDNS,
|
||||||
)
|
)
|
||||||
go e.mobileDep.DnsReadyListener.OnReady()
|
go e.mobileDep.DnsReadyListener.OnReady()
|
||||||
return routes, dnsServer, nil
|
return routes, dnsServer, nil
|
||||||
|
|
||||||
case "ios":
|
case "ios":
|
||||||
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder)
|
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
|
||||||
return nil, dnsServer, nil
|
return nil, dnsServer, nil
|
||||||
|
|
||||||
default:
|
default:
|
||||||
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager)
|
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
@ -1268,26 +1404,6 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetClientRoutes returns the current routes from the route map
|
|
||||||
func (e *Engine) GetClientRoutes() route.HAMap {
|
|
||||||
e.clientRoutesMu.RLock()
|
|
||||||
defer e.clientRoutesMu.RUnlock()
|
|
||||||
|
|
||||||
return maps.Clone(e.clientRoutes)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
|
|
||||||
func (e *Engine) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
|
||||||
e.clientRoutesMu.RLock()
|
|
||||||
defer e.clientRoutesMu.RUnlock()
|
|
||||||
|
|
||||||
routes := make(map[route.NetID][]*route.Route, len(e.clientRoutes))
|
|
||||||
for id, v := range e.clientRoutes {
|
|
||||||
routes[id.NetID()] = v
|
|
||||||
}
|
|
||||||
return routes
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetRouteManager returns the route manager
|
// GetRouteManager returns the route manager
|
||||||
func (e *Engine) GetRouteManager() routemanager.Manager {
|
func (e *Engine) GetRouteManager() routemanager.Manager {
|
||||||
return e.routeManager
|
return e.routeManager
|
||||||
@ -1372,9 +1488,8 @@ func (e *Engine) receiveProbeEvents() {
|
|||||||
go e.probes.WgProbe.Receive(e.ctx, func() bool {
|
go e.probes.WgProbe.Receive(e.ctx, func() bool {
|
||||||
log.Debug("received wg probe request")
|
log.Debug("received wg probe request")
|
||||||
|
|
||||||
for _, peer := range e.peerConns {
|
for _, key := range e.peerStore.PeersPubKey() {
|
||||||
key := peer.GetKey()
|
wgStats, err := e.wgInterface.GetStats(key)
|
||||||
wgStats, err := peer.WgConfig().WgInterface.GetStats(key)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to get wg stats for peer %s: %s", key, err)
|
log.Debugf("failed to get wg stats for peer %s: %s", key, err)
|
||||||
}
|
}
|
||||||
@ -1451,7 +1566,7 @@ func (e *Engine) startNetworkMonitor() {
|
|||||||
|
|
||||||
func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
|
func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
|
||||||
var vpnRoutes []netip.Prefix
|
var vpnRoutes []netip.Prefix
|
||||||
for _, routes := range e.GetClientRoutes() {
|
for _, routes := range e.routeManager.GetClientRoutes() {
|
||||||
if len(routes) > 0 && routes[0] != nil {
|
if len(routes) > 0 && routes[0] != nil {
|
||||||
vpnRoutes = append(vpnRoutes, routes[0].Network)
|
vpnRoutes = append(vpnRoutes, routes[0].Network)
|
||||||
}
|
}
|
||||||
@@ -1479,8 +1594,93 @@ func (e *Engine) stopDNSServer() {
 	e.statusRecorder.UpdateDNSStates(nsGroupStates)
 }
 
+// SetNetworkMapPersistence enables or disables network map persistence
+func (e *Engine) SetNetworkMapPersistence(enabled bool) {
+	e.syncMsgMux.Lock()
+	defer e.syncMsgMux.Unlock()
+
+	if enabled == e.persistNetworkMap {
+		return
+	}
+	e.persistNetworkMap = enabled
+	log.Debugf("Network map persistence is set to %t", enabled)
+
+	if !enabled {
+		e.latestNetworkMap = nil
+	}
+}
+
+// GetLatestNetworkMap returns the stored network map if persistence is enabled
+func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
+	e.syncMsgMux.Lock()
+	defer e.syncMsgMux.Unlock()
+
+	if !e.persistNetworkMap {
+		return nil, errors.New("network map persistence is disabled")
+	}
+
+	if e.latestNetworkMap == nil {
+		//nolint:nilnil
+		return nil, nil
+	}
+
+	log.Debugf("Retrieving latest network map with size %d bytes", proto.Size(e.latestNetworkMap))
+	nm, ok := proto.Clone(e.latestNetworkMap).(*mgmProto.NetworkMap)
+	if !ok {
+
+		return nil, fmt.Errorf("failed to clone network map")
+	}
+
+	return nm, nil
+}
+
+// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag
+func (e *Engine) updateDNSForwarder(enabled bool, domains []string) {
+	if !enabled {
+		if e.dnsForwardMgr == nil {
+			return
+		}
+		if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
+			log.Errorf("failed to stop DNS forward: %v", err)
+		}
+		return
+	}
+
+	if len(domains) > 0 {
+		log.Infof("enable domain router service for domains: %v", domains)
+		if e.dnsForwardMgr == nil {
+			e.dnsForwardMgr = dnsfwd.NewManager(e.firewall)
+
+			if err := e.dnsForwardMgr.Start(domains); err != nil {
+				log.Errorf("failed to start DNS forward: %v", err)
+				e.dnsForwardMgr = nil
+			}
+		} else {
+			log.Infof("update domain router service for domains: %v", domains)
+			e.dnsForwardMgr.UpdateDomains(domains)
+		}
+	} else if e.dnsForwardMgr != nil {
+		log.Infof("disable domain router service")
+		if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
+			log.Errorf("failed to stop DNS forward: %v", err)
+		}
+		e.dnsForwardMgr = nil
+	}
+}
+
 // isChecksEqual checks if two slices of checks are equal.
 func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
+	for _, check := range checks {
+		sort.Slice(check.Files, func(i, j int) bool {
+			return check.Files[i] < check.Files[j]
+		})
+	}
+	for _, oCheck := range oChecks {
+		sort.Slice(oCheck.Files, func(i, j int) bool {
+			return oCheck.Files[i] < oCheck.Files[j]
+		})
+	}
+
 	return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
 		return slices.Equal(checks.Files, oChecks.Files)
 	})
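Editor's note: a hedged usage sketch for the persistence additions above; the helper and its logging are illustrative, only SetNetworkMapPersistence, GetLatestNetworkMap and GetSerial are taken from the diff:

// dumpLatestNetworkMap enables persistence and later inspects the stored map.
func dumpLatestNetworkMap(e *Engine) {
	e.SetNetworkMapPersistence(true) // keep the last applied network map from now on

	nm, err := e.GetLatestNetworkMap()
	if err != nil {
		log.Errorf("network map persistence is disabled: %v", err)
		return
	}
	if nm == nil {
		return // persistence is on, but no network map has been received yet
	}
	log.Debugf("latest network map serial: %d", nm.GetSerial())
}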
|
@ -39,6 +39,8 @@ import (
|
|||||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
@ -245,12 +247,22 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
nil)
|
nil)
|
||||||
|
|
||||||
wgIface := &iface.MockWGIface{
|
wgIface := &iface.MockWGIface{
|
||||||
|
NameFunc: func() string { return "utun102" },
|
||||||
RemovePeerFunc: func(peerKey string) error {
|
RemovePeerFunc: func(peerKey string) error {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
engine.wgInterface = wgIface
|
engine.wgInterface = wgIface
|
||||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil)
|
engine.routeManager = routemanager.NewManager(routemanager.ManagerConfig{
|
||||||
|
Context: ctx,
|
||||||
|
PublicKey: key.PublicKey().String(),
|
||||||
|
DNSRouteInterval: time.Minute,
|
||||||
|
WGInterface: engine.wgInterface,
|
||||||
|
StatusRecorder: engine.statusRecorder,
|
||||||
|
RelayManager: relayMgr,
|
||||||
|
})
|
||||||
|
_, _, err = engine.routeManager.Init()
|
||||||
|
require.NoError(t, err)
|
||||||
engine.dnsServer = &dns.MockServer{
|
engine.dnsServer = &dns.MockServer{
|
||||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||||
}
|
}
|
||||||
@ -388,8 +400,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(engine.peerConns) != c.expectedLen {
|
if len(engine.peerStore.PeersPubKey()) != c.expectedLen {
|
||||||
t.Errorf("expecting Engine.peerConns to be of size %d, got %d", c.expectedLen, len(engine.peerConns))
|
t.Errorf("expecting Engine.peerConns to be of size %d, got %d", c.expectedLen, len(engine.peerStore.PeersPubKey()))
|
||||||
}
|
}
|
||||||
|
|
||||||
if engine.networkSerial != c.expectedSerial {
|
if engine.networkSerial != c.expectedSerial {
|
||||||
@ -397,7 +409,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, p := range c.expectedPeers {
|
for _, p := range c.expectedPeers {
|
||||||
conn, ok := engine.peerConns[p.GetWgPubKey()]
|
conn, ok := engine.peerStore.PeerConn(p.GetWgPubKey())
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Errorf("expecting Engine.peerConns to contain peer %s", p)
|
t.Errorf("expecting Engine.peerConns to contain peer %s", p)
|
||||||
}
|
}
|
||||||
@ -622,10 +634,10 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
|||||||
}{}
|
}{}
|
||||||
|
|
||||||
mockRouteManager := &routemanager.MockManager{
|
mockRouteManager := &routemanager.MockManager{
|
||||||
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
|
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
|
||||||
input.inputSerial = updateSerial
|
input.inputSerial = updateSerial
|
||||||
input.inputRoutes = newRoutes
|
input.inputRoutes = newRoutes
|
||||||
return nil, nil, testCase.inputErr
|
return testCase.inputErr
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -798,8 +810,8 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
|||||||
assert.NoError(t, err, "shouldn't return error")
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
mockRouteManager := &routemanager.MockManager{
|
mockRouteManager := &routemanager.MockManager{
|
||||||
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
|
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
|
||||||
return nil, nil, nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1006,6 +1018,99 @@ func Test_ParseNATExternalIPMappings(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_CheckFilesEqual(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
inputChecks1 []*mgmtProto.Checks
|
||||||
|
inputChecks2 []*mgmtProto.Checks
|
||||||
|
expectedBool bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Equal Files In Equal Order Should Return True",
|
||||||
|
inputChecks1: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile1",
|
||||||
|
"testfile2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
inputChecks2: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile1",
|
||||||
|
"testfile2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedBool: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Equal Files In Reverse Order Should Return True",
|
||||||
|
inputChecks1: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile1",
|
||||||
|
"testfile2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
inputChecks2: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile2",
|
||||||
|
"testfile1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedBool: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Unequal Files Should Return False",
|
||||||
|
inputChecks1: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile1",
|
||||||
|
"testfile2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
inputChecks2: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile1",
|
||||||
|
"testfile3",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedBool: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Compared With Empty Should Return False",
|
||||||
|
inputChecks1: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile1",
|
||||||
|
"testfile2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
inputChecks2: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedBool: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
result := isChecksEqual(testCase.inputChecks1, testCase.inputChecks2)
|
||||||
|
assert.Equal(t, testCase.expectedBool, result, "result should match expected bool")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
|
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
|
||||||
key, err := wgtypes.GeneratePrivateKey()
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -1100,7 +1205,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
}
|
}
|
||||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||||
|
|
||||||
store, cleanUp, err := server.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir)
|
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
@ -1122,7 +1227,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
||||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil)
|
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
@ -1141,7 +1246,8 @@ func getConnectedPeers(e *Engine) int {
|
|||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
i := 0
|
i := 0
|
||||||
for _, conn := range e.peerConns {
|
for _, id := range e.peerStore.PeersPubKey() {
|
||||||
|
conn, _ := e.peerStore.PeerConn(id)
|
||||||
if conn.Status() == peer.StatusConnected {
|
if conn.Status() == peer.StatusConnected {
|
||||||
i++
|
i++
|
||||||
}
|
}
|
||||||
@ -1153,5 +1259,5 @@ func getPeers(e *Engine) int {
|
|||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
return len(e.peerConns)
|
return len(e.peerStore.PeersPubKey())
|
||||||
}
|
}
|
||||||
|
@@ -19,4 +19,5 @@ type MobileDependency struct {
 	// iOS only
 	DnsManager     dns.IosDnsManager
 	FileDescriptor int32
+	StateFilePath  string
 }
|
@ -23,6 +23,7 @@ import (
|
|||||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
|
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ConnPriority int
|
type ConnPriority int
|
||||||
@ -83,7 +84,6 @@ type Conn struct {
|
|||||||
signaler *Signaler
|
signaler *Signaler
|
||||||
relayManager *relayClient.Manager
|
relayManager *relayClient.Manager
|
||||||
allowedIP net.IP
|
allowedIP net.IP
|
||||||
allowedNet string
|
|
||||||
handshaker *Handshaker
|
handshaker *Handshaker
|
||||||
|
|
||||||
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
|
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
|
||||||
@ -105,13 +105,14 @@ type Conn struct {
|
|||||||
wgProxyICE wgproxy.Proxy
|
wgProxyICE wgproxy.Proxy
|
||||||
wgProxyRelay wgproxy.Proxy
|
wgProxyRelay wgproxy.Proxy
|
||||||
|
|
||||||
guard *guard.Guard
|
guard *guard.Guard
|
||||||
|
semaphore *semaphoregroup.SemaphoreGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewConn creates a new not opened Conn to the remote peer.
|
// NewConn creates a new not opened Conn to the remote peer.
|
||||||
// To establish a connection run Conn.Open
|
// To establish a connection run Conn.Open
|
||||||
func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher) (*Conn, error) {
|
func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher, semaphore *semaphoregroup.SemaphoreGroup) (*Conn, error) {
|
||||||
allowedIP, allowedNet, err := net.ParseCIDR(config.WgConfig.AllowedIps)
|
allowedIP, _, err := net.ParseCIDR(config.WgConfig.AllowedIps)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to parse allowedIPS: %v", err)
|
log.Errorf("failed to parse allowedIPS: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -129,9 +130,9 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
|
|||||||
signaler: signaler,
|
signaler: signaler,
|
||||||
relayManager: relayManager,
|
relayManager: relayManager,
|
||||||
allowedIP: allowedIP,
|
allowedIP: allowedIP,
|
||||||
allowedNet: allowedNet.String(),
|
|
||||||
statusRelay: NewAtomicConnStatus(),
|
statusRelay: NewAtomicConnStatus(),
|
||||||
statusICE: NewAtomicConnStatus(),
|
statusICE: NewAtomicConnStatus(),
|
||||||
|
semaphore: semaphore,
|
||||||
}
|
}
|
||||||
|
|
||||||
rFns := WorkerRelayCallbacks{
|
rFns := WorkerRelayCallbacks{
|
||||||
@ -171,6 +172,7 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
|
|||||||
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
|
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
|
||||||
// be used.
|
// be used.
|
||||||
func (conn *Conn) Open() {
|
func (conn *Conn) Open() {
|
||||||
|
conn.semaphore.Add(conn.ctx)
|
||||||
conn.log.Debugf("open connection to peer")
|
conn.log.Debugf("open connection to peer")
|
||||||
|
|
||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
@ -193,6 +195,7 @@ func (conn *Conn) Open() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) {
|
func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) {
|
||||||
|
defer conn.semaphore.Done(conn.ctx)
|
||||||
conn.waitInitialRandomSleepTime(ctx)
|
conn.waitInitialRandomSleepTime(ctx)
|
||||||
|
|
||||||
err := conn.handshaker.sendOffer()
|
err := conn.handshaker.sendOffer()
|
||||||
@ -594,14 +597,13 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd
|
|||||||
}
|
}
|
||||||
|
|
||||||
if conn.onConnected != nil {
|
if conn.onConnected != nil {
|
||||||
conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedNet, remoteRosenpassAddr)
|
conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedIP.String(), remoteRosenpassAddr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) waitInitialRandomSleepTime(ctx context.Context) {
|
func (conn *Conn) waitInitialRandomSleepTime(ctx context.Context) {
|
||||||
minWait := 100
|
maxWait := 300
|
||||||
maxWait := 800
|
duration := time.Duration(rand.Intn(maxWait)) * time.Millisecond
|
||||||
duration := time.Duration(rand.Intn(maxWait-minWait)+minWait) * time.Millisecond
|
|
||||||
|
|
||||||
timeout := time.NewTimer(duration)
|
timeout := time.NewTimer(duration)
|
||||||
defer timeout.Stop()
|
defer timeout.Stop()
|
||||||
@ -745,6 +747,11 @@ func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
|
|||||||
conn.wgProxyRelay = proxy
|
conn.wgProxyRelay = proxy
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AllowedIP returns the allowed IP of the remote peer
|
||||||
|
func (conn *Conn) AllowedIP() net.IP {
|
||||||
|
return conn.allowedIP
|
||||||
|
}
|
||||||
|
|
||||||
func isController(config ConnConfig) bool {
|
func isController(config ConnConfig) bool {
|
||||||
return config.LocalKey > config.Key
|
return config.LocalKey > config.Key
|
||||||
}
|
}
|
||||||
|
@ -14,6 +14,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/peer/ice"
|
"github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
|
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||||
)
|
)
|
||||||
|
|
||||||
var connConf = ConnConfig{
|
var connConf = ConnConfig{
|
||||||
@ -46,7 +47,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
|
|||||||
|
|
||||||
func TestConn_GetKey(t *testing.T) {
|
func TestConn_GetKey(t *testing.T) {
|
||||||
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
||||||
conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil, swWatcher)
|
conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -58,7 +59,7 @@ func TestConn_GetKey(t *testing.T) {
|
|||||||
|
|
||||||
func TestConn_OnRemoteOffer(t *testing.T) {
|
func TestConn_OnRemoteOffer(t *testing.T) {
|
||||||
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
||||||
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher)
|
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -92,7 +93,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
|||||||
|
|
||||||
func TestConn_OnRemoteAnswer(t *testing.T) {
|
func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||||
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
||||||
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher)
|
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -125,7 +126,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
func TestConn_Status(t *testing.T) {
|
func TestConn_Status(t *testing.T) {
|
||||||
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
||||||
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher)
|
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -17,6 +17,11 @@ import (
|
|||||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ResolvedDomainInfo struct {
|
||||||
|
Prefixes []netip.Prefix
|
||||||
|
ParentDomain domain.Domain
|
||||||
|
}
|
||||||
|
|
||||||
// State contains the latest state of a peer
|
// State contains the latest state of a peer
|
||||||
type State struct {
|
type State struct {
|
||||||
Mux *sync.RWMutex
|
Mux *sync.RWMutex
|
||||||
@ -79,6 +84,12 @@ type LocalPeerState struct {
|
|||||||
Routes map[string]struct{}
|
Routes map[string]struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clone returns a copy of the LocalPeerState
|
||||||
|
func (l LocalPeerState) Clone() LocalPeerState {
|
||||||
|
l.Routes = maps.Clone(l.Routes)
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
// SignalState contains the latest state of a signal connection
|
// SignalState contains the latest state of a signal connection
|
||||||
type SignalState struct {
|
type SignalState struct {
|
||||||
URL string
|
URL string
|
||||||
@ -138,7 +149,7 @@ type Status struct {
|
|||||||
rosenpassEnabled bool
|
rosenpassEnabled bool
|
||||||
rosenpassPermissive bool
|
rosenpassPermissive bool
|
||||||
nsGroupStates []NSGroupState
|
nsGroupStates []NSGroupState
|
||||||
resolvedDomainsStates map[domain.Domain][]netip.Prefix
|
resolvedDomainsStates map[domain.Domain]ResolvedDomainInfo
|
||||||
|
|
||||||
// To reduce the number of notification invocation this bool will be true when need to call the notification
|
// To reduce the number of notification invocation this bool will be true when need to call the notification
|
||||||
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
|
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
|
||||||
@ -156,7 +167,7 @@ func NewRecorder(mgmAddress string) *Status {
|
|||||||
offlinePeers: make([]State, 0),
|
offlinePeers: make([]State, 0),
|
||||||
notifier: newNotifier(),
|
notifier: newNotifier(),
|
||||||
mgmAddress: mgmAddress,
|
mgmAddress: mgmAddress,
|
||||||
resolvedDomainsStates: make(map[domain.Domain][]netip.Prefix),
|
resolvedDomainsStates: map[domain.Domain]ResolvedDomainInfo{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -496,7 +507,7 @@ func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
|
|||||||
func (d *Status) GetLocalPeerState() LocalPeerState {
|
func (d *Status) GetLocalPeerState() LocalPeerState {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
return d.localPeer
|
return d.localPeer.Clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateLocalPeerState updates local peer status
|
// UpdateLocalPeerState updates local peer status
|
||||||
@@ -591,16 +602,27 @@ func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) {
 	d.nsGroupStates = dnsStates
 }
 
-func (d *Status) UpdateResolvedDomainsStates(domain domain.Domain, prefixes []netip.Prefix) {
+func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix) {
 	d.mux.Lock()
 	defer d.mux.Unlock()
-	d.resolvedDomainsStates[domain] = prefixes
+
+	// Store both the original domain pattern and resolved domain
+	d.resolvedDomainsStates[resolvedDomain] = ResolvedDomainInfo{
+		Prefixes:     prefixes,
+		ParentDomain: originalDomain,
+	}
 }
 
 func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
 	d.mux.Lock()
 	defer d.mux.Unlock()
-	delete(d.resolvedDomainsStates, domain)
+
+	// Remove all entries that have this domain as their parent
+	for k, v := range d.resolvedDomainsStates {
+		if v.ParentDomain == domain {
+			delete(d.resolvedDomainsStates, k)
+		}
+	}
 }
 
 func (d *Status) GetRosenpassState() RosenpassState {
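Editor's note: the resolved-domain bookkeeping above keys each entry by the concrete resolved name and remembers the original pattern as its parent. A hedged illustration, assuming the peer, domain and netip imports and that domain.Domain is convertible from a string; the names are invented:

func illustrateDomainCleanup(st *peer.Status, prefixes []netip.Prefix) {
	parent := domain.Domain("*.example.com")
	st.UpdateResolvedDomainsStates(parent, domain.Domain("api.example.com"), prefixes)
	st.UpdateResolvedDomainsStates(parent, domain.Domain("cdn.example.com"), prefixes)

	// Deleting by the parent pattern drops both child entries at once.
	st.DeleteResolvedDomainsStates(parent)
}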
@ -676,25 +698,23 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
|
|||||||
// extend the list of stun, turn servers with relay address
|
// extend the list of stun, turn servers with relay address
|
||||||
relayStates := slices.Clone(d.relayStates)
|
relayStates := slices.Clone(d.relayStates)
|
||||||
|
|
||||||
var relayState relay.ProbeResult
|
|
||||||
|
|
||||||
// if the server connection is not established then we will use the general address
|
// if the server connection is not established then we will use the general address
|
||||||
// in case of connection we will use the instance specific address
|
// in case of connection we will use the instance specific address
|
||||||
instanceAddr, err := d.relayMgr.RelayInstanceAddress()
|
instanceAddr, err := d.relayMgr.RelayInstanceAddress()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO add their status
|
// TODO add their status
|
||||||
if errors.Is(err, relayClient.ErrRelayClientNotConnected) {
|
for _, r := range d.relayMgr.ServerURLs() {
|
||||||
for _, r := range d.relayMgr.ServerURLs() {
|
relayStates = append(relayStates, relay.ProbeResult{
|
||||||
relayStates = append(relayStates, relay.ProbeResult{
|
URI: r,
|
||||||
URI: r,
|
Err: err,
|
||||||
})
|
})
|
||||||
}
|
|
||||||
return relayStates
|
|
||||||
}
|
}
|
||||||
relayState.Err = err
|
return relayStates
|
||||||
}
|
}
|
||||||
|
|
||||||
relayState.URI = instanceAddr
|
relayState := relay.ProbeResult{
|
||||||
|
URI: instanceAddr,
|
||||||
|
}
|
||||||
return append(relayStates, relayState)
|
return append(relayStates, relayState)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -704,7 +724,7 @@ func (d *Status) GetDNSStates() []NSGroupState {
|
|||||||
return d.nsGroupStates
|
return d.nsGroupStates
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) GetResolvedDomainsStates() map[domain.Domain][]netip.Prefix {
|
func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
return maps.Clone(d.resolvedDomainsStates)
|
return maps.Clone(d.resolvedDomainsStates)
|
||||||
|
@ -46,8 +46,6 @@ type WorkerICE struct {
|
|||||||
hasRelayOnLocally bool
|
hasRelayOnLocally bool
|
||||||
conn WorkerICECallbacks
|
conn WorkerICECallbacks
|
||||||
|
|
||||||
selectedPriority ConnPriority
|
|
||||||
|
|
||||||
agent *ice.Agent
|
agent *ice.Agent
|
||||||
muxAgent sync.Mutex
|
muxAgent sync.Mutex
|
||||||
|
|
||||||
@ -95,10 +93,8 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
|
|
||||||
var preferredCandidateTypes []ice.CandidateType
|
var preferredCandidateTypes []ice.CandidateType
|
||||||
if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" {
|
if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" {
|
||||||
w.selectedPriority = connPriorityICEP2P
|
|
||||||
preferredCandidateTypes = icemaker.CandidateTypesP2P()
|
preferredCandidateTypes = icemaker.CandidateTypesP2P()
|
||||||
} else {
|
} else {
|
||||||
w.selectedPriority = connPriorityICETurn
|
|
||||||
preferredCandidateTypes = icemaker.CandidateTypes()
|
preferredCandidateTypes = icemaker.CandidateTypes()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -159,7 +155,7 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
RelayedOnLocal: isRelayCandidate(pair.Local),
|
RelayedOnLocal: isRelayCandidate(pair.Local),
|
||||||
}
|
}
|
||||||
w.log.Debugf("on ICE conn read to use ready")
|
w.log.Debugf("on ICE conn read to use ready")
|
||||||
go w.conn.OnConnReady(w.selectedPriority, ci)
|
go w.conn.OnConnReady(selectedPriority(pair), ci)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
|
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
|
||||||
@@ -268,7 +264,13 @@ func (w *WorkerICE) closeAgent(cancel context.CancelFunc) {
 func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
 	// wait local endpoint configuration
 	time.Sleep(time.Second)
-	addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", pair.Remote.Address(), remoteWgPort))
+	addrString := pair.Remote.Address()
+	parsed, err := netip.ParseAddr(addrString)
+	if (err == nil) && (parsed.Is6()) {
+		addrString = fmt.Sprintf("[%s]", addrString)
+		//IPv6 Literals need to be wrapped in brackets for Resolve*Addr()
+	}
+	addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", addrString, remoteWgPort))
 	if err != nil {
 		w.log.Warnf("got an error while resolving the udp address, err: %s", err)
 		return
@@ -394,3 +396,11 @@ func isRelayed(pair *ice.CandidatePair) bool {
 	}
 	return false
 }
+
+func selectedPriority(pair *ice.CandidatePair) ConnPriority {
+	if isRelayed(pair) {
+		return connPriorityICETurn
+	} else {
+		return connPriorityICEP2P
+	}
+}
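Editor's note: the bracket wrapping added above is needed because net.ResolveUDPAddr rejects a bare IPv6 literal joined with ":port". A small self-contained sketch of the same check (the helper name is illustrative):

package main

import (
	"fmt"
	"net"
	"net/netip"
)

// punchAddr mirrors the diff's logic: IPv6 literals must be wrapped in
// brackets before being combined with a port for ResolveUDPAddr.
func punchAddr(host string, port int) (*net.UDPAddr, error) {
	if parsed, err := netip.ParseAddr(host); err == nil && parsed.Is6() {
		host = fmt.Sprintf("[%s]", host)
	}
	return net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", host, port))
}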
client/internal/peerstore/store.go (new file, 87 lines)
@@ -0,0 +1,87 @@
+package peerstore
+
+import (
+	"net"
+	"sync"
+
+	"golang.org/x/exp/maps"
+
+	"github.com/netbirdio/netbird/client/internal/peer"
+)
+
+// Store is a thread-safe store for peer connections.
+type Store struct {
+	peerConns   map[string]*peer.Conn
+	peerConnsMu sync.RWMutex
+}
+
+func NewConnStore() *Store {
+	return &Store{
+		peerConns: make(map[string]*peer.Conn),
+	}
+}
+
+func (s *Store) AddPeerConn(pubKey string, conn *peer.Conn) bool {
+	s.peerConnsMu.Lock()
+	defer s.peerConnsMu.Unlock()
+
+	_, ok := s.peerConns[pubKey]
+	if ok {
+		return false
+	}
+
+	s.peerConns[pubKey] = conn
+	return true
+}
+
+func (s *Store) Remove(pubKey string) (*peer.Conn, bool) {
+	s.peerConnsMu.Lock()
+	defer s.peerConnsMu.Unlock()
+
+	p, ok := s.peerConns[pubKey]
+	if !ok {
+		return nil, false
+	}
+	delete(s.peerConns, pubKey)
+	return p, true
+}
+
+func (s *Store) AllowedIPs(pubKey string) (string, bool) {
+	s.peerConnsMu.RLock()
+	defer s.peerConnsMu.RUnlock()
+
+	p, ok := s.peerConns[pubKey]
+	if !ok {
+		return "", false
+	}
+	return p.WgConfig().AllowedIps, true
+}
+
+func (s *Store) AllowedIP(pubKey string) (net.IP, bool) {
+	s.peerConnsMu.RLock()
+	defer s.peerConnsMu.RUnlock()
+
+	p, ok := s.peerConns[pubKey]
+	if !ok {
+		return nil, false
+	}
+	return p.AllowedIP(), true
+}
+
+func (s *Store) PeerConn(pubKey string) (*peer.Conn, bool) {
+	s.peerConnsMu.RLock()
+	defer s.peerConnsMu.RUnlock()
+
+	p, ok := s.peerConns[pubKey]
+	if !ok {
+		return nil, false
+	}
+	return p, true
+}
+
+func (s *Store) PeersPubKey() []string {
+	s.peerConnsMu.RLock()
+	defer s.peerConnsMu.RUnlock()
+
+	return maps.Keys(s.peerConns)
+}
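Editor's note: a short usage sketch for the new peerstore.Store, mirroring how the engine changes above register and remove connections; the wrapper function is illustrative:

package main

import (
	"github.com/netbirdio/netbird/client/internal/peer"
	"github.com/netbirdio/netbird/client/internal/peerstore"
)

// trackPeer registers a connection exactly once and closes it again when the
// peer is removed, as the engine does with AddPeerConn and Remove.
func trackPeer(store *peerstore.Store, pubKey string, conn *peer.Conn) {
	if ok := store.AddPeerConn(pubKey, conn); !ok {
		conn.Close() // a connection for this key is already registered
		return
	}

	// ... later, when the peer disappears from the network map:
	if c, ok := store.Remove(pubKey); ok {
		c.Close()
	}
}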
@ -13,12 +13,20 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
handlerTypeDynamic = iota
|
||||||
|
handlerTypeDomain
|
||||||
|
handlerTypeStatic
|
||||||
|
)
|
||||||
|
|
||||||
type routerPeerStatus struct {
|
type routerPeerStatus struct {
|
||||||
connected bool
|
connected bool
|
||||||
relayed bool
|
relayed bool
|
||||||
@ -53,7 +61,18 @@ type clientNetwork struct {
|
|||||||
updateSerial uint64
|
updateSerial uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface iface.IWGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork {
|
func newClientNetworkWatcher(
|
||||||
|
ctx context.Context,
|
||||||
|
dnsRouteInterval time.Duration,
|
||||||
|
wgInterface iface.IWGIface,
|
||||||
|
statusRecorder *peer.Status,
|
||||||
|
rt *route.Route,
|
||||||
|
routeRefCounter *refcounter.RouteRefCounter,
|
||||||
|
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||||
|
dnsServer nbdns.Server,
|
||||||
|
peerStore *peerstore.Store,
|
||||||
|
useNewDNSRoute bool,
|
||||||
|
) *clientNetwork {
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
|
||||||
client := &clientNetwork{
|
client := &clientNetwork{
|
||||||
@ -65,7 +84,17 @@ func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration
|
|||||||
routePeersNotifiers: make(map[string]chan struct{}),
|
routePeersNotifiers: make(map[string]chan struct{}),
|
||||||
routeUpdate: make(chan routesUpdate),
|
routeUpdate: make(chan routesUpdate),
|
||||||
peerStateUpdate: make(chan struct{}),
|
peerStateUpdate: make(chan struct{}),
|
||||||
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder, wgInterface),
|
handler: handlerFromRoute(
|
||||||
|
rt,
|
||||||
|
routeRefCounter,
|
||||||
|
allowedIPsRefCounter,
|
||||||
|
dnsRouteInterval,
|
||||||
|
statusRecorder,
|
||||||
|
wgInterface,
|
||||||
|
dnsServer,
|
||||||
|
peerStore,
|
||||||
|
useNewDNSRoute,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
return client
|
return client
|
||||||
}
|
}
|
||||||
@ -368,10 +397,50 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface iface.IWGIface) RouteHandler {
|
func handlerFromRoute(
|
||||||
if rt.IsDynamic() {
|
rt *route.Route,
|
||||||
|
routeRefCounter *refcounter.RouteRefCounter,
|
||||||
|
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||||
|
dnsRouterInteval time.Duration,
|
||||||
|
statusRecorder *peer.Status,
|
||||||
|
wgInterface iface.IWGIface,
|
||||||
|
dnsServer nbdns.Server,
|
||||||
|
peerStore *peerstore.Store,
|
||||||
|
useNewDNSRoute bool,
|
||||||
|
) RouteHandler {
|
||||||
|
switch handlerType(rt, useNewDNSRoute) {
|
||||||
|
case handlerTypeDomain:
|
||||||
|
return dnsinterceptor.New(
|
||||||
|
rt,
|
||||||
|
routeRefCounter,
|
||||||
|
allowedIPsRefCounter,
|
||||||
|
statusRecorder,
|
||||||
|
dnsServer,
|
||||||
|
peerStore,
|
||||||
|
)
|
||||||
|
case handlerTypeDynamic:
|
||||||
dns := nbdns.NewServiceViaMemory(wgInterface)
|
dns := nbdns.NewServiceViaMemory(wgInterface)
|
||||||
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()))
|
return dynamic.NewRoute(
|
||||||
|
rt,
|
||||||
|
routeRefCounter,
|
||||||
|
allowedIPsRefCounter,
|
||||||
|
dnsRouterInteval,
|
||||||
|
statusRecorder,
|
||||||
|
wgInterface,
|
||||||
|
fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()),
|
||||||
|
)
|
||||||
|
default:
|
||||||
|
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
|
||||||
}
|
}
|
||||||
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
|
}
|
||||||
|
|
||||||
|
func handlerType(rt *route.Route, useNewDNSRoute bool) int {
|
||||||
|
if !rt.IsDynamic() {
|
||||||
|
return handlerTypeStatic
|
||||||
|
}
|
||||||
|
|
||||||
|
if useNewDNSRoute {
|
||||||
|
return handlerTypeDomain
|
||||||
|
}
|
||||||
|
return handlerTypeDynamic
|
||||||
}
|
}
|
||||||
|
@@ -1,6 +1,7 @@
 package routemanager
 
 import (
+	"fmt"
 	"net/netip"
 	"testing"
 	"time"
@@ -227,6 +228,64 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
 			currentRoute:    "route1",
 			expectedRouteID: "route1",
 		},
+		{
+			name: "relayed routes with latency 0 should maintain previous choice",
+			statuses: map[route.ID]routerPeerStatus{
+				"route1": {
+					connected: true,
+					relayed:   true,
+					latency:   0 * time.Millisecond,
+				},
+				"route2": {
+					connected: true,
+					relayed:   true,
+					latency:   0 * time.Millisecond,
+				},
+			},
+			existingRoutes: map[route.ID]*route.Route{
+				"route1": {
+					ID:     "route1",
+					Metric: route.MaxMetric,
+					Peer:   "peer1",
+				},
+				"route2": {
+					ID:     "route2",
+					Metric: route.MaxMetric,
+					Peer:   "peer2",
+				},
+			},
+			currentRoute:    "route1",
+			expectedRouteID: "route1",
+		},
+		{
+			name: "p2p routes with latency 0 should maintain previous choice",
+			statuses: map[route.ID]routerPeerStatus{
+				"route1": {
+					connected: true,
+					relayed:   false,
+					latency:   0 * time.Millisecond,
+				},
+				"route2": {
+					connected: true,
+					relayed:   false,
+					latency:   0 * time.Millisecond,
+				},
+			},
+			existingRoutes: map[route.ID]*route.Route{
+				"route1": {
+					ID:     "route1",
+					Metric: route.MaxMetric,
+					Peer:   "peer1",
+				},
+				"route2": {
+					ID:     "route2",
+					Metric: route.MaxMetric,
+					Peer:   "peer2",
+				},
+			},
+			currentRoute:    "route1",
+			expectedRouteID: "route1",
+		},
 		{
 			name: "current route with bad score should be changed to route with better score",
 			statuses: map[route.ID]routerPeerStatus{
@@ -287,6 +346,45 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
 		},
 	}
 
+	// fill the test data with random routes
+	for _, tc := range testCases {
+		for i := 0; i < 50; i++ {
+			dummyRoute := &route.Route{
+				ID:     route.ID(fmt.Sprintf("dummy_p1_%d", i)),
+				Metric: route.MinMetric,
+				Peer:   fmt.Sprintf("dummy_p1_%d", i),
+			}
+			tc.existingRoutes[dummyRoute.ID] = dummyRoute
+		}
+		for i := 0; i < 50; i++ {
+			dummyRoute := &route.Route{
+				ID:     route.ID(fmt.Sprintf("dummy_p2_%d", i)),
+				Metric: route.MinMetric,
+				Peer:   fmt.Sprintf("dummy_p1_%d", i),
+			}
+			tc.existingRoutes[dummyRoute.ID] = dummyRoute
+		}
+
+		for i := 0; i < 50; i++ {
+			id := route.ID(fmt.Sprintf("dummy_p1_%d", i))
+			dummyStatus := routerPeerStatus{
+				connected: false,
+				relayed:   true,
+				latency:   0,
+			}
+			tc.statuses[id] = dummyStatus
+		}
+		for i := 0; i < 50; i++ {
+			id := route.ID(fmt.Sprintf("dummy_p2_%d", i))
+			dummyStatus := routerPeerStatus{
+				connected: false,
+				relayed:   true,
+				latency:   0,
+			}
+			tc.statuses[id] = dummyStatus
+		}
+	}
+
 	for _, tc := range testCases {
 		t.Run(tc.name, func(t *testing.T) {
 			currentRoute := &route.Route{

client/internal/routemanager/dnsinterceptor/handler.go  (new file, 356 additions)
@@ -0,0 +1,356 @@
package dnsinterceptor

import (
	"context"
	"fmt"
	"net"
	"net/netip"
	"strings"
	"sync"
	"time"

	"github.com/hashicorp/go-multierror"
	"github.com/miekg/dns"
	log "github.com/sirupsen/logrus"

	nberrors "github.com/netbirdio/netbird/client/errors"
	nbdns "github.com/netbirdio/netbird/client/internal/dns"
	"github.com/netbirdio/netbird/client/internal/dnsfwd"
	"github.com/netbirdio/netbird/client/internal/peer"
	"github.com/netbirdio/netbird/client/internal/peerstore"
	"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
	"github.com/netbirdio/netbird/management/domain"
	"github.com/netbirdio/netbird/route"
)

type domainMap map[domain.Domain][]netip.Prefix

type DnsInterceptor struct {
	mu                   sync.RWMutex
	route                *route.Route
	routeRefCounter      *refcounter.RouteRefCounter
	allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
	statusRecorder       *peer.Status
	dnsServer            nbdns.Server
	currentPeerKey       string
	interceptedDomains   domainMap
	peerStore            *peerstore.Store
}

func New(
	rt *route.Route,
	routeRefCounter *refcounter.RouteRefCounter,
	allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
	statusRecorder *peer.Status,
	dnsServer nbdns.Server,
	peerStore *peerstore.Store,
) *DnsInterceptor {
	return &DnsInterceptor{
		route:                rt,
		routeRefCounter:      routeRefCounter,
		allowedIPsRefcounter: allowedIPsRefCounter,
		statusRecorder:       statusRecorder,
		dnsServer:            dnsServer,
		interceptedDomains:   make(domainMap),
		peerStore:            peerStore,
	}
}

func (d *DnsInterceptor) String() string {
	return d.route.Domains.SafeString()
}

func (d *DnsInterceptor) AddRoute(context.Context) error {
	d.dnsServer.RegisterHandler(d.route.Domains.ToPunycodeList(), d, nbdns.PriorityDNSRoute)
	return nil
}

func (d *DnsInterceptor) RemoveRoute() error {
	d.mu.Lock()

	var merr *multierror.Error
	for domain, prefixes := range d.interceptedDomains {
		for _, prefix := range prefixes {
			if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
				merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", prefix, err))
			}
			if d.currentPeerKey != "" {
				if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
					merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
				}
			}
		}
		log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))

	}
	for _, domain := range d.route.Domains {
		d.statusRecorder.DeleteResolvedDomainsStates(domain)
	}

	clear(d.interceptedDomains)
	d.mu.Unlock()

	d.dnsServer.DeregisterHandler(d.route.Domains.ToPunycodeList(), nbdns.PriorityDNSRoute)

	return nberrors.FormatErrorOrNil(merr)
}

func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
	d.mu.Lock()
	defer d.mu.Unlock()

	var merr *multierror.Error
	for domain, prefixes := range d.interceptedDomains {
		for _, prefix := range prefixes {
			if ref, err := d.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
				merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
			} else if ref.Count > 1 && ref.Out != peerKey {
				log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
					prefix.Addr(),
					domain.SafeString(),
					ref.Out,
				)
			}
		}
	}

	d.currentPeerKey = peerKey
	return nberrors.FormatErrorOrNil(merr)
}

func (d *DnsInterceptor) RemoveAllowedIPs() error {
	d.mu.Lock()
	defer d.mu.Unlock()

	var merr *multierror.Error
	for _, prefixes := range d.interceptedDomains {
		for _, prefix := range prefixes {
			if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
				merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
			}
		}
	}

	d.currentPeerKey = ""
	return nberrors.FormatErrorOrNil(merr)
}

// ServeDNS implements the dns.Handler interface
func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
	if len(r.Question) == 0 {
		return
	}
	log.Tracef("received DNS request for domain=%s type=%v class=%v",
		r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)

	d.mu.RLock()
	peerKey := d.currentPeerKey
	d.mu.RUnlock()

	if peerKey == "" {
		log.Tracef("no current peer key set, letting next handler try for domain=%s", r.Question[0].Name)

		d.continueToNextHandler(w, r, "no current peer key")
		return
	}

	upstreamIP, err := d.getUpstreamIP(peerKey)
	if err != nil {
		log.Errorf("failed to get upstream IP: %v", err)
		d.continueToNextHandler(w, r, fmt.Sprintf("failed to get upstream IP: %v", err))
		return
	}

	client := &dns.Client{
		Timeout: 5 * time.Second,
		Net:     "udp",
	}
	upstream := fmt.Sprintf("%s:%d", upstreamIP, dnsfwd.ListenPort)
	reply, _, err := client.ExchangeContext(context.Background(), r, upstream)

	var answer []dns.RR
	if reply != nil {
		answer = reply.Answer
	}
	log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP, peerKey, r.Question[0].Name, answer)

	if err != nil {
		log.Errorf("failed to exchange DNS request with %s: %v", upstream, err)
		if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
			log.Errorf("failed writing DNS response: %v", err)
		}
		return
	}

	reply.Id = r.Id
	if err := d.writeMsg(w, reply); err != nil {
		log.Errorf("failed writing DNS response: %v", err)
	}
}

// continueToNextHandler signals the handler chain to try the next handler
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) {
	log.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)

	resp := new(dns.Msg)
	resp.SetRcode(r, dns.RcodeNameError)
	// Set Zero bit to signal handler chain to continue
	resp.MsgHdr.Zero = true
	if err := w.WriteMsg(resp); err != nil {
		log.Errorf("failed writing DNS continue response: %v", err)
	}
}

func (d *DnsInterceptor) getUpstreamIP(peerKey string) (net.IP, error) {
	peerAllowedIP, exists := d.peerStore.AllowedIP(peerKey)
	if !exists {
		return nil, fmt.Errorf("peer connection not found for key: %s", peerKey)
	}
	return peerAllowedIP, nil
}

func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
	if r == nil {
		return fmt.Errorf("received nil DNS message")
	}

	if len(r.Answer) > 0 && len(r.Question) > 0 {
		origPattern := ""
		if writer, ok := w.(*nbdns.ResponseWriterChain); ok {
			origPattern = writer.GetOrigPattern()
		}

		resolvedDomain := domain.Domain(r.Question[0].Name)

		// already punycode via RegisterHandler()
		originalDomain := domain.Domain(origPattern)
		if originalDomain == "" {
			originalDomain = resolvedDomain
		}

		var newPrefixes []netip.Prefix
		for _, answer := range r.Answer {
			var ip netip.Addr
			switch rr := answer.(type) {
			case *dns.A:
				addr, ok := netip.AddrFromSlice(rr.A)
				if !ok {
					log.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A)
					continue
				}
				ip = addr
			case *dns.AAAA:
				addr, ok := netip.AddrFromSlice(rr.AAAA)
				if !ok {
					log.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA)
					continue
				}
				ip = addr
			default:
				continue
			}

			prefix := netip.PrefixFrom(ip, ip.BitLen())
			newPrefixes = append(newPrefixes, prefix)
		}

		if len(newPrefixes) > 0 {
			if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil {
				log.Errorf("failed to update domain prefixes: %v", err)
			}
		}
	}

	if err := w.WriteMsg(r); err != nil {
		return fmt.Errorf("failed to write DNS response: %v", err)
	}

	return nil
}

func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error {
	d.mu.Lock()
	defer d.mu.Unlock()

	oldPrefixes := d.interceptedDomains[resolvedDomain]
	toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)

	var merr *multierror.Error

	// Add new prefixes
	for _, prefix := range toAdd {
		if _, err := d.routeRefCounter.Increment(prefix, struct{}{}); err != nil {
			merr = multierror.Append(merr, fmt.Errorf("add route for IP %s: %v", prefix, err))
			continue
		}

		if d.currentPeerKey == "" {
			continue
		}
		if ref, err := d.allowedIPsRefcounter.Increment(prefix, d.currentPeerKey); err != nil {
			merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
		} else if ref.Count > 1 && ref.Out != d.currentPeerKey {
			log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
				prefix.Addr(),
				resolvedDomain.SafeString(),
				ref.Out,
			)
		}
	}

	if !d.route.KeepRoute {
		// Remove old prefixes
		for _, prefix := range toRemove {
			if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
				merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", prefix, err))
			}
			if d.currentPeerKey != "" {
				if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
					merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
				}
			}
		}
	}

	// Update domain prefixes using resolved domain as key
	if len(toAdd) > 0 || len(toRemove) > 0 {
		d.interceptedDomains[resolvedDomain] = newPrefixes
		originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), "."))
		d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes)

		if len(toAdd) > 0 {
			log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
				resolvedDomain.SafeString(),
				originalDomain.SafeString(),
				toAdd)
		}
		if len(toRemove) > 0 {
			log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
				resolvedDomain.SafeString(),
				originalDomain.SafeString(),
				toRemove)
		}
	}

	return nberrors.FormatErrorOrNil(merr)
}

func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) {
	prefixSet := make(map[netip.Prefix]bool)
	for _, prefix := range oldPrefixes {
		prefixSet[prefix] = false
	}
	for _, prefix := range newPrefixes {
		if _, exists := prefixSet[prefix]; exists {
			prefixSet[prefix] = true
		} else {
			toAdd = append(toAdd, prefix)
		}
	}
	for prefix, inUse := range prefixSet {
		if !inUse {
			toRemove = append(toRemove, prefix)
		}
	}
	return
}

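determinePrefixChanges is the piece that decides which resolved IPs become new routes and which get torn down. The standalone sketch below (the example prefixes are made up) reproduces the function verbatim from the new file to show the expected add/remove split:

package main

import (
	"fmt"
	"net/netip"
)

// Copied verbatim from the new handler.go so the example is self-contained.
func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) {
	prefixSet := make(map[netip.Prefix]bool)
	for _, prefix := range oldPrefixes {
		prefixSet[prefix] = false
	}
	for _, prefix := range newPrefixes {
		if _, exists := prefixSet[prefix]; exists {
			prefixSet[prefix] = true
		} else {
			toAdd = append(toAdd, prefix)
		}
	}
	for prefix, inUse := range prefixSet {
		if !inUse {
			toRemove = append(toRemove, prefix)
		}
	}
	return
}

func main() {
	previous := []netip.Prefix{
		netip.MustParsePrefix("198.51.100.1/32"),
		netip.MustParsePrefix("198.51.100.2/32"),
	}
	resolved := []netip.Prefix{
		netip.MustParsePrefix("198.51.100.2/32"),
		netip.MustParsePrefix("198.51.100.3/32"),
	}

	toAdd, toRemove := determinePrefixChanges(previous, resolved)
	fmt.Println("to add:", toAdd)       // to add: [198.51.100.3/32]
	fmt.Println("to remove:", toRemove) // to remove: [198.51.100.1/32]
}

Note that toRemove is only acted on when the route does not have KeepRoute set, as updateDomainPrefixes shows above.
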
@@ -74,11 +74,7 @@ func NewRoute(
 }
 
 func (r *Route) String() string {
-	s, err := r.route.Domains.String()
-	if err != nil {
-		return r.route.Domains.PunycodeString()
-	}
-	return s
+	return r.route.Domains.SafeString()
 }
 
 func (r *Route) AddRoute(ctx context.Context) error {
@@ -292,7 +288,7 @@ func (r *Route) updateDynamicRoutes(ctx context.Context, newDomains domainMap) e
 		updatedPrefixes := combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes)
 		r.dynamicDomains[domain] = updatedPrefixes
 
-		r.statusRecorder.UpdateResolvedDomainsStates(domain, updatedPrefixes)
+		r.statusRecorder.UpdateResolvedDomainsStates(domain, domain, updatedPrefixes)
 	}
 
 	return nberrors.FormatErrorOrNil(merr)

@@ -12,12 +12,16 @@ import (
 	"time"
 
 	log "github.com/sirupsen/logrus"
+	"golang.org/x/exp/maps"
 
 	firewall "github.com/netbirdio/netbird/client/firewall/manager"
 	"github.com/netbirdio/netbird/client/iface"
 	"github.com/netbirdio/netbird/client/iface/configurer"
+	"github.com/netbirdio/netbird/client/iface/netstack"
+	"github.com/netbirdio/netbird/client/internal/dns"
 	"github.com/netbirdio/netbird/client/internal/listener"
 	"github.com/netbirdio/netbird/client/internal/peer"
+	"github.com/netbirdio/netbird/client/internal/peerstore"
 	"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
 	"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
 	"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
@@ -32,16 +36,33 @@ import (
 
 // Manager is a route manager interface
 type Manager interface {
-	Init(*statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
-	UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
+	Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
+	UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, useNewDNSRoute bool) error
 	TriggerSelection(route.HAMap)
 	GetRouteSelector() *routeselector.RouteSelector
+	GetClientRoutes() route.HAMap
+	GetClientRoutesWithNetID() map[route.NetID][]*route.Route
 	SetRouteChangeListener(listener listener.NetworkChangeListener)
 	InitialRouteRange() []string
 	EnableServerRouter(firewall firewall.Manager) error
 	Stop(stateManager *statemanager.Manager)
 }
 
+type ManagerConfig struct {
+	Context             context.Context
+	PublicKey           string
+	DNSRouteInterval    time.Duration
+	WGInterface         iface.IWGIface
+	StatusRecorder      *peer.Status
+	RelayManager        *relayClient.Manager
+	InitialRoutes       []*route.Route
+	StateManager        *statemanager.Manager
+	DNSServer           dns.Server
+	PeerStore           *peerstore.Store
+	DisableClientRoutes bool
+	DisableServerRoutes bool
+}
+
 // DefaultManager is the default instance of a route manager
 type DefaultManager struct {
 	ctx context.Context
@@ -49,7 +70,7 @@ type DefaultManager struct {
 	mux            sync.Mutex
 	clientNetworks map[route.HAUniqueID]*clientNetwork
 	routeSelector  *routeselector.RouteSelector
-	serverRouter   serverRouter
+	serverRouter   *serverRouter
 	sysOps         *systemops.SysOps
 	statusRecorder *peer.Status
 	relayMgr       *relayClient.Manager
@@ -59,51 +80,81 @@ type DefaultManager struct {
 	routeRefCounter      *refcounter.RouteRefCounter
 	allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
 	dnsRouteInterval     time.Duration
+	stateManager         *statemanager.Manager
+	// clientRoutes is the most recent list of clientRoutes received from the Management Service
+	clientRoutes        route.HAMap
+	dnsServer           dns.Server
+	peerStore           *peerstore.Store
+	useNewDNSRoute      bool
+	disableClientRoutes bool
+	disableServerRoutes bool
 }
 
-func NewManager(
-	ctx context.Context,
-	pubKey string,
-	dnsRouteInterval time.Duration,
-	wgInterface iface.IWGIface,
-	statusRecorder *peer.Status,
-	relayMgr *relayClient.Manager,
-	initialRoutes []*route.Route,
-) *DefaultManager {
-	mCTX, cancel := context.WithCancel(ctx)
+func NewManager(config ManagerConfig) *DefaultManager {
+	mCTX, cancel := context.WithCancel(config.Context)
 	notifier := notifier.NewNotifier()
-	sysOps := systemops.NewSysOps(wgInterface, notifier)
+	sysOps := systemops.NewSysOps(config.WGInterface, notifier)
 
 	dm := &DefaultManager{
-		ctx:              mCTX,
-		stop:             cancel,
-		dnsRouteInterval: dnsRouteInterval,
-		clientNetworks:   make(map[route.HAUniqueID]*clientNetwork),
-		relayMgr:         relayMgr,
-		routeSelector:    routeselector.NewRouteSelector(),
-		sysOps:           sysOps,
-		statusRecorder:   statusRecorder,
-		wgInterface:      wgInterface,
-		pubKey:           pubKey,
-		notifier:         notifier,
+		ctx:                 mCTX,
+		stop:                cancel,
+		dnsRouteInterval:    config.DNSRouteInterval,
+		clientNetworks:      make(map[route.HAUniqueID]*clientNetwork),
+		relayMgr:            config.RelayManager,
+		sysOps:              sysOps,
+		statusRecorder:      config.StatusRecorder,
+		wgInterface:         config.WGInterface,
+		pubKey:              config.PublicKey,
+		notifier:            notifier,
+		stateManager:        config.StateManager,
+		dnsServer:           config.DNSServer,
+		peerStore:           config.PeerStore,
+		disableClientRoutes: config.DisableClientRoutes,
+		disableServerRoutes: config.DisableServerRoutes,
 	}
 
-	dm.routeRefCounter = refcounter.New(
+	// don't proceed with client routes if it is disabled
+	if config.DisableClientRoutes {
+		return dm
+	}
+
+	dm.setupRefCounters()
+
+	if runtime.GOOS == "android" {
+		cr := dm.initialClientRoutes(config.InitialRoutes)
+		dm.notifier.SetInitialClientRoutes(cr)
+	}
+	return dm
+}
+
+func (m *DefaultManager) setupRefCounters() {
+	m.routeRefCounter = refcounter.New(
 		func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
-			return struct{}{}, sysOps.AddVPNRoute(prefix, wgInterface.ToInterface())
+			return struct{}{}, m.sysOps.AddVPNRoute(prefix, m.wgInterface.ToInterface())
 		},
 		func(prefix netip.Prefix, _ struct{}) error {
-			return sysOps.RemoveVPNRoute(prefix, wgInterface.ToInterface())
+			return m.sysOps.RemoveVPNRoute(prefix, m.wgInterface.ToInterface())
 		},
 	)
 
-	dm.allowedIPsRefCounter = refcounter.New(
+	if netstack.IsEnabled() {
+		m.routeRefCounter = refcounter.New(
+			func(netip.Prefix, struct{}) (struct{}, error) {
+				return struct{}{}, refcounter.ErrIgnore
+			},
+			func(netip.Prefix, struct{}) error {
+				return nil
+			},
+		)
+	}
+
+	m.allowedIPsRefCounter = refcounter.New(
 		func(prefix netip.Prefix, peerKey string) (string, error) {
 			// save peerKey to use it in the remove function
-			return peerKey, wgInterface.AddAllowedIP(peerKey, prefix.String())
+			return peerKey, m.wgInterface.AddAllowedIP(peerKey, prefix.String())
 		},
 		func(prefix netip.Prefix, peerKey string) error {
-			if err := wgInterface.RemoveAllowedIP(peerKey, prefix.String()); err != nil {
+			if err := m.wgInterface.RemoveAllowedIP(peerKey, prefix.String()); err != nil {
 				if !errors.Is(err, configurer.ErrPeerNotFound) && !errors.Is(err, configurer.ErrAllowedIPNotFound) {
 					return err
 				}
@@ -112,17 +163,13 @@ func NewManager(
 			return nil
 		},
 	)
-
-	if runtime.GOOS == "android" {
-		cr := dm.clientRoutes(initialRoutes)
-		dm.notifier.SetInitialClientRoutes(cr)
-	}
-	return dm
 }
 
 // Init sets up the routing
-func (m *DefaultManager) Init(stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
-	if nbnet.CustomRoutingDisabled() {
+func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
+	m.routeSelector = m.initSelector()
+
+	if nbnet.CustomRoutingDisabled() || m.disableClientRoutes {
 		return nil, nil, nil
 	}
 
@@ -137,15 +184,46 @@ func (m *DefaultManager) Init(stateManager *statemanager.Manager) (nbnet.AddHook
 
 	ips := resolveURLsToIPs(initialAddresses)
 
-	beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, stateManager)
+	beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, m.stateManager)
 	if err != nil {
 		return nil, nil, fmt.Errorf("setup routing: %w", err)
 	}
 
 	log.Info("Routing setup complete")
 	return beforePeerHook, afterPeerHook, nil
 }
 
+func (m *DefaultManager) initSelector() *routeselector.RouteSelector {
+	var state *SelectorState
+	m.stateManager.RegisterState(state)
+
+	// restore selector state if it exists
+	if err := m.stateManager.LoadState(state); err != nil {
+		log.Warnf("failed to load state: %v", err)
+		return routeselector.NewRouteSelector()
+	}
+
+	if state := m.stateManager.GetState(state); state != nil {
+		if selector, ok := state.(*SelectorState); ok {
+			return (*routeselector.RouteSelector)(selector)
+		}
+
+		log.Warnf("failed to convert state with type %T to SelectorState", state)
+	}
+
+	return routeselector.NewRouteSelector()
+}
+
 func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
+	if m.disableServerRoutes {
+		log.Info("server routes are disabled")
+		return nil
+	}
+
+	if firewall == nil {
+		return errors.New("firewall manager is not set")
+	}
+
 	var err error
 	m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
 	if err != nil {
@@ -172,7 +250,7 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
 		}
 	}
 
-	if !nbnet.CustomRoutingDisabled() {
+	if !nbnet.CustomRoutingDisabled() && !m.disableClientRoutes {
 		if err := m.sysOps.CleanupRouting(stateManager); err != nil {
 			log.Errorf("Error cleaning up routing: %v", err)
 		} else {
@@ -181,33 +259,43 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
 	}
 
 	m.ctx = nil
+
+	m.mux.Lock()
+	defer m.mux.Unlock()
+	m.clientRoutes = nil
 }
 
 // UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
-func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
+func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, useNewDNSRoute bool) error {
 	select {
 	case <-m.ctx.Done():
 		log.Infof("not updating routes as context is closed")
-		return nil, nil, m.ctx.Err()
+		return nil
 	default:
-		m.mux.Lock()
-		defer m.mux.Unlock()
+	}
 
-		newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes)
+	m.mux.Lock()
+	defer m.mux.Unlock()
+	m.useNewDNSRoute = useNewDNSRoute
 
+	newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes)
+
+	if !m.disableClientRoutes {
 		filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
 		m.updateClientNetworks(updateSerial, filteredClientRoutes)
 		m.notifier.OnNewRoutes(filteredClientRoutes)
+	}
 
-		if m.serverRouter != nil {
-			err := m.serverRouter.updateRoutes(newServerRoutesMap)
-			if err != nil {
-				return nil, nil, fmt.Errorf("update routes: %w", err)
-			}
-		}
-
-		return newServerRoutesMap, newClientRoutesIDMap, nil
+	if m.serverRouter != nil {
+		err := m.serverRouter.updateRoutes(newServerRoutesMap)
+		if err != nil {
+			return err
+		}
 	}
+
+	m.clientRoutes = newClientRoutesIDMap
+
+	return nil
 }
 
 // SetRouteChangeListener set RouteListener for route change Notifier
@@ -225,9 +313,24 @@ func (m *DefaultManager) GetRouteSelector() *routeselector.RouteSelector {
 	return m.routeSelector
 }
 
-// GetClientRoutes returns the client routes
-func (m *DefaultManager) GetClientRoutes() map[route.HAUniqueID]*clientNetwork {
-	return m.clientNetworks
+// GetClientRoutes returns most recent list of clientRoutes received from the Management Service
+func (m *DefaultManager) GetClientRoutes() route.HAMap {
+	m.mux.Lock()
+	defer m.mux.Unlock()
+
+	return maps.Clone(m.clientRoutes)
+}
+
+// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
+func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
+	m.mux.Lock()
+	defer m.mux.Unlock()
+
+	routes := make(map[route.NetID][]*route.Route, len(m.clientRoutes))
+	for id, v := range m.clientRoutes {
+		routes[id.NetID()] = v
+	}
+	return routes
 }
 
 // TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones
@@ -247,11 +350,26 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
 			continue
 		}
 
-		clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter)
+		clientNetworkWatcher := newClientNetworkWatcher(
+			m.ctx,
+			m.dnsRouteInterval,
+			m.wgInterface,
+			m.statusRecorder,
+			routes[0],
+			m.routeRefCounter,
+			m.allowedIPsRefCounter,
+			m.dnsServer,
+			m.peerStore,
+			m.useNewDNSRoute,
+		)
 		m.clientNetworks[id] = clientNetworkWatcher
 		go clientNetworkWatcher.peersStateAndUpdateWatcher()
 		clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
 	}
+
+	if err := m.stateManager.UpdateState((*SelectorState)(m.routeSelector)); err != nil {
+		log.Errorf("failed to update state: %v", err)
+	}
 }
 
 // stopObsoleteClients stops the client network watcher for the networks that are not in the new list
@@ -272,7 +390,18 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
 	for id, routes := range networks {
 		clientNetworkWatcher, found := m.clientNetworks[id]
 		if !found {
-			clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter)
+			clientNetworkWatcher = newClientNetworkWatcher(
+				m.ctx,
+				m.dnsRouteInterval,
+				m.wgInterface,
+				m.statusRecorder,
+				routes[0],
+				m.routeRefCounter,
+				m.allowedIPsRefCounter,
+				m.dnsServer,
+				m.peerStore,
+				m.useNewDNSRoute,
+			)
 			m.clientNetworks[id] = clientNetworkWatcher
 			go clientNetworkWatcher.peersStateAndUpdateWatcher()
 		}
@@ -315,7 +444,7 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
 	return newServerRoutesMap, newClientRoutesIDMap
 }
 
-func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route {
+func (m *DefaultManager) initialClientRoutes(initialRoutes []*route.Route) []*route.Route {
 	_, crMap := m.classifyRoutes(initialRoutes)
 	rs := make([]*route.Route, 0, len(crMap))
 	for _, routes := range crMap {

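The constructor now takes a single ManagerConfig instead of a long positional parameter list, and UpdateRoutes carries the DNS-route feature flag while returning only an error. A minimal wiring sketch, mirroring the updated test in the next hunk (ctx, localPeerKey, wgInterface, statusRecorder and newRoutes are assumed to exist in the caller; unset config fields fall back to their Go zero values):

routeManager := NewManager(ManagerConfig{
	Context:        ctx,
	PublicKey:      localPeerKey,
	WGInterface:    wgInterface,
	StatusRecorder: statusRecorder,
})

// Init no longer takes the state manager; it is passed via ManagerConfig.
beforePeerHook, afterPeerHook, err := routeManager.Init()
if err != nil {
	// handle error
}
_ = beforePeerHook
_ = afterPeerHook

// The third argument toggles the new DNS-interceptor based routes.
if err := routeManager.UpdateRoutes(1, newRoutes, false); err != nil {
	// handle error
}
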
@@ -424,9 +424,14 @@ func TestManagerUpdateRoutes(t *testing.T) {
 
 			statusRecorder := peer.NewRecorder("https://mgm")
 			ctx := context.TODO()
-			routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil)
+			routeManager := NewManager(ManagerConfig{
+				Context:        ctx,
+				PublicKey:      localPeerKey,
+				WGInterface:    wgInterface,
+				StatusRecorder: statusRecorder,
+			})
 
-			_, _, err = routeManager.Init(nil)
+			_, _, err = routeManager.Init()
 
 			require.NoError(t, err, "should init route manager")
 			defer routeManager.Stop(nil)
@@ -436,11 +441,11 @@ func TestManagerUpdateRoutes(t *testing.T) {
 			}
 
 			if len(testCase.inputInitRoutes) > 0 {
-				_, _, err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
+				_ = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes, false)
 				require.NoError(t, err, "should update routes with init routes")
 			}
 
-			_, _, err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
+			_ = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes, false)
 			require.NoError(t, err, "should update routes")
 
 			expectedWatchers := testCase.clientNetworkWatchersExpected
@@ -450,8 +455,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
 			require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match")
 
 			if runtime.GOOS == "linux" && routeManager.serverRouter != nil {
-				sr := routeManager.serverRouter.(*defaultServerRouter)
-				require.Len(t, sr.routes, testCase.serverRoutesExpected, "server networks size should match")
+				require.Len(t, routeManager.serverRouter.routes, testCase.serverRoutesExpected, "server networks size should match")
 			}
 		})
 	}

@@ -2,7 +2,6 @@ package routemanager
 
 import (
 	"context"
-	"fmt"
 
 	firewall "github.com/netbirdio/netbird/client/firewall/manager"
 	"github.com/netbirdio/netbird/client/iface"
@@ -15,13 +14,15 @@ import (
 
 // MockManager is the mock instance of a route manager
 type MockManager struct {
-	UpdateRoutesFunc     func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
+	UpdateRoutesFunc             func(updateSerial uint64, newRoutes []*route.Route) error
 	TriggerSelectionFunc         func(haMap route.HAMap)
 	GetRouteSelectorFunc         func() *routeselector.RouteSelector
-	StopFunc             func(manager *statemanager.Manager)
+	GetClientRoutesFunc          func() route.HAMap
+	GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route
+	StopFunc                     func(manager *statemanager.Manager)
 }
 
-func (m *MockManager) Init(*statemanager.Manager) (net.AddHookFunc, net.RemoveHookFunc, error) {
+func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) {
 	return nil, nil, nil
 }
 
@@ -31,11 +32,11 @@ func (m *MockManager) InitialRouteRange() []string {
 }
 
 // UpdateRoutes mock implementation of UpdateRoutes from Manager interface
-func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
+func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, b bool) error {
 	if m.UpdateRoutesFunc != nil {
 		return m.UpdateRoutesFunc(updateSerial, newRoutes)
 	}
-	return nil, nil, fmt.Errorf("method UpdateRoutes is not implemented")
+	return nil
 }
 
 func (m *MockManager) TriggerSelection(networks route.HAMap) {
@@ -52,6 +53,22 @@ func (m *MockManager) GetRouteSelector() *routeselector.RouteSelector {
 	return nil
 }
 
+// GetClientRoutes mock implementation of GetClientRoutes from Manager interface
+func (m *MockManager) GetClientRoutes() route.HAMap {
+	if m.GetClientRoutesFunc != nil {
+		return m.GetClientRoutesFunc()
+	}
+	return nil
+}
+
+// GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface
+func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
+	if m.GetClientRoutesWithNetIDFunc != nil {
+		return m.GetClientRoutesWithNetIDFunc()
+	}
+	return nil
+}
+
 // Start mock implementation of Start from Manager interface
 func (m *MockManager) Start(ctx context.Context, iface *iface.WGIface) {
 }

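Consumers of the Manager interface can stub the new getters per test. A hedged sketch (the test body and returned values are hypothetical; only the field and method names come from the mock above):

func TestConsumerReadsClientRoutes(t *testing.T) {
	mock := &MockManager{
		UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
			return nil
		},
		GetClientRoutesFunc: func() route.HAMap {
			return route.HAMap{}
		},
	}

	if err := mock.UpdateRoutes(1, nil, false); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if mock.GetClientRoutes() == nil {
		t.Fatal("expected the stubbed HA map")
	}
	// The unset GetClientRoutesWithNetIDFunc falls back to the nil default.
	if mock.GetClientRoutesWithNetID() != nil {
		t.Fatal("expected nil without a stub")
	}
}
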
@@ -47,10 +47,9 @@ type RemoveFunc[Key, O any] func(key Key, out O) error
 type Counter[Key comparable, I, O any] struct {
 	// refCountMap keeps track of the reference Ref for keys
 	refCountMap map[Key]Ref[O]
-	refCountMu  sync.Mutex
+	mu          sync.Mutex
 	// idMap keeps track of the keys associated with an ID for removal
 	idMap map[string][]Key
-	idMu  sync.Mutex
 	add    AddFunc[Key, I, O]
 	remove RemoveFunc[Key, O]
 }
@@ -72,13 +71,14 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key
 }
 
 // LoadData loads the data from the existing counter
+// The passed counter should not be used any longer after calling this function.
 func (rm *Counter[Key, I, O]) LoadData(
 	existingCounter *Counter[Key, I, O],
 ) {
-	rm.refCountMu.Lock()
-	defer rm.refCountMu.Unlock()
-	rm.idMu.Lock()
-	defer rm.idMu.Unlock()
+	rm.mu.Lock()
+	defer rm.mu.Unlock()
+	existingCounter.mu.Lock()
+	defer existingCounter.mu.Unlock()
 
 	rm.refCountMap = existingCounter.refCountMap
 	rm.idMap = existingCounter.idMap
@@ -87,8 +87,8 @@ func (rm *Counter[Key, I, O]) LoadData(
 // Get retrieves the current reference count and associated data for a key.
 // If the key doesn't exist, it returns a zero value Ref and false.
 func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
-	rm.refCountMu.Lock()
-	defer rm.refCountMu.Unlock()
+	rm.mu.Lock()
+	defer rm.mu.Unlock()
 
 	ref, ok := rm.refCountMap[key]
 	return ref, ok
@@ -97,9 +97,13 @@ func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
 // Increment increments the reference count for the given key.
 // If this is the first reference to the key, the AddFunc is called.
 func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
-	rm.refCountMu.Lock()
-	defer rm.refCountMu.Unlock()
+	rm.mu.Lock()
+	defer rm.mu.Unlock()
+
+	return rm.increment(key, in)
+}
+
+func (rm *Counter[Key, I, O]) increment(key Key, in I) (Ref[O], error) {
 	ref := rm.refCountMap[key]
 	logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out)
 
@@ -126,10 +130,10 @@ func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
 // IncrementWithID increments the reference count for the given key and groups it under the given ID.
 // If this is the first reference to the key, the AddFunc is called.
 func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) {
-	rm.idMu.Lock()
-	defer rm.idMu.Unlock()
+	rm.mu.Lock()
+	defer rm.mu.Unlock()
 
-	ref, err := rm.Increment(key, in)
+	ref, err := rm.increment(key, in)
 	if err != nil {
 		return ref, fmt.Errorf("with ID: %w", err)
 	}
@@ -141,9 +145,12 @@ func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O],
 // Decrement decrements the reference count for the given key.
 // If the reference count reaches 0, the RemoveFunc is called.
 func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
-	rm.refCountMu.Lock()
-	defer rm.refCountMu.Unlock()
+	rm.mu.Lock()
+	defer rm.mu.Unlock()
+	return rm.decrement(key)
+}
+
+func (rm *Counter[Key, I, O]) decrement(key Key) (Ref[O], error) {
 	ref, ok := rm.refCountMap[key]
 	if !ok {
 		logCallerF("No reference found for key %v", key)
@@ -168,12 +175,12 @@ func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
 // DecrementWithID decrements the reference count for all keys associated with the given ID.
 // If the reference count reaches 0, the RemoveFunc is called.
 func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
-	rm.idMu.Lock()
-	defer rm.idMu.Unlock()
+	rm.mu.Lock()
+	defer rm.mu.Unlock()
 
 	var merr *multierror.Error
 	for _, key := range rm.idMap[id] {
-		if _, err := rm.Decrement(key); err != nil {
+		if _, err := rm.decrement(key); err != nil {
 			merr = multierror.Append(merr, err)
 		}
 	}
@@ -184,10 +191,8 @@ func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
 
 // Flush removes all references and calls RemoveFunc for each key.
 func (rm *Counter[Key, I, O]) Flush() error {
-	rm.refCountMu.Lock()
-	defer rm.refCountMu.Unlock()
-	rm.idMu.Lock()
-	defer rm.idMu.Unlock()
+	rm.mu.Lock()
+	defer rm.mu.Unlock()
 
 	var merr *multierror.Error
 	for key := range rm.refCountMap {
@@ -206,10 +211,8 @@ func (rm *Counter[Key, I, O]) Flush() error {
 
 // Clear removes all references without calling RemoveFunc.
 func (rm *Counter[Key, I, O]) Clear() {
-	rm.refCountMu.Lock()
-	defer rm.refCountMu.Unlock()
-	rm.idMu.Lock()
-	defer rm.idMu.Unlock()
+	rm.mu.Lock()
+	defer rm.mu.Unlock()
 
 	clear(rm.refCountMap)
 	clear(rm.idMap)
@@ -217,10 +220,8 @@ func (rm *Counter[Key, I, O]) Clear() {
 
 // MarshalJSON implements the json.Marshaler interface for Counter.
 func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
-	rm.refCountMu.Lock()
-	defer rm.refCountMu.Unlock()
-	rm.idMu.Lock()
-	defer rm.idMu.Unlock()
+	rm.mu.Lock()
+	defer rm.mu.Unlock()
 
 	return json.Marshal(struct {
 		RefCountMap map[Key]Ref[O] `json:"refCountMap"`
@@ -233,6 +234,9 @@ func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
 
 // UnmarshalJSON implements the json.Unmarshaler interface for Counter.
 func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error {
+	rm.mu.Lock()
+	defer rm.mu.Unlock()
+
 	var temp struct {
 		RefCountMap map[Key]Ref[O]   `json:"refCountMap"`
 		IDMap       map[string][]Key `json:"idMap"`
@@ -243,6 +247,13 @@ func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error {
 	rm.refCountMap = temp.RefCountMap
 	rm.idMap = temp.IDMap
 
+	if temp.RefCountMap == nil {
+		temp.RefCountMap = map[Key]Ref[O]{}
+	}
+	if temp.IDMap == nil {
+		temp.IDMap = map[string][]Key{}
+	}
+
 	return nil
 }

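With a single mu guarding both maps and the exported Increment/Decrement delegating to unexported increment/decrement, the ID-grouped helpers can hold the lock once for the whole batch. A usage sketch of the generic counter (the add/remove callbacks are stand-ins, not NetBird code):

// Key=string, In=struct{}, Out=string: "open" a resource on the first
// reference, "close" it when the last reference goes away.
counter := refcounter.New(
	func(key string, _ struct{}) (string, error) {
		return "handle-for-" + key, nil // runs only on the first Increment of key
	},
	func(key string, out string) error {
		return nil // runs only when the count drops back to zero
	},
)

if _, err := counter.Increment("203.0.113.7/32", struct{}{}); err != nil {
	// handle error
}
if _, err := counter.IncrementWithID("peer-A", "203.0.113.7/32", struct{}{}); err != nil {
	// handle error
}

// Drops the reference grouped under "peer-A"; the key stays alive because
// the plain Increment above still holds one reference.
_ = counter.DecrementWithID("peer-A")

// Now the count reaches zero and the remove callback runs.
_, _ = counter.Decrement("203.0.113.7/32")
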
@@ -1,9 +0,0 @@
-package routemanager
-
-import "github.com/netbirdio/netbird/route"
-
-type serverRouter interface {
-	updateRoutes(map[route.ID]*route.Route) error
-	removeFromServerNetwork(*route.Route) error
-	cleanUp()
-}

@@ -9,8 +9,19 @@ import (
 	firewall "github.com/netbirdio/netbird/client/firewall/manager"
 	"github.com/netbirdio/netbird/client/iface"
 	"github.com/netbirdio/netbird/client/internal/peer"
+	"github.com/netbirdio/netbird/route"
 )
 
-func newServerRouter(context.Context, iface.IWGIface, firewall.Manager, *peer.Status) (serverRouter, error) {
+type serverRouter struct {
+}
+
+func (r serverRouter) cleanUp() {
+}
+
+func (r serverRouter) updateRoutes(map[route.ID]*route.Route) error {
+	return nil
+}
+
+func newServerRouter(context.Context, iface.IWGIface, firewall.Manager, *peer.Status) (*serverRouter, error) {
 	return nil, fmt.Errorf("server route not supported on this os")
 }

@@ -17,7 +17,7 @@ import (
 	"github.com/netbirdio/netbird/route"
 )
 
-type defaultServerRouter struct {
+type serverRouter struct {
 	mux            sync.Mutex
 	ctx            context.Context
 	routes         map[route.ID]*route.Route
@@ -26,8 +26,8 @@ type defaultServerRouter struct {
 	statusRecorder *peer.Status
 }
 
-func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) {
-	return &defaultServerRouter{
+func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall firewall.Manager, statusRecorder *peer.Status) (*serverRouter, error) {
+	return &serverRouter{
 		ctx:            ctx,
 		routes:         make(map[route.ID]*route.Route),
 		firewall:       firewall,
@@ -36,7 +36,7 @@ func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall f
 	}, nil
 }
 
-func (m *defaultServerRouter) updateRoutes(routesMap map[route.ID]*route.Route) error {
+func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error {
 	serverRoutesToRemove := make([]route.ID, 0)
 
 	for routeID := range m.routes {
@@ -80,74 +80,72 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[route.ID]*route.Route)
 	return nil
 }
 
-func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error {
-	select {
-	case <-m.ctx.Done():
+func (m *serverRouter) removeFromServerNetwork(route *route.Route) error {
+	if m.ctx.Err() != nil {
 		log.Infof("Not removing from server network because context is done")
 		return m.ctx.Err()
-	default:
-		m.mux.Lock()
-		defer m.mux.Unlock()
-
-		routerPair, err := routeToRouterPair(route)
-		if err != nil {
-			return fmt.Errorf("parse prefix: %w", err)
-		}
-
-		err = m.firewall.RemoveNatRule(routerPair)
-		if err != nil {
-			return fmt.Errorf("remove routing rules: %w", err)
-		}
-
-		delete(m.routes, route.ID)
-
-		state := m.statusRecorder.GetLocalPeerState()
-		delete(state.Routes, route.Network.String())
-		m.statusRecorder.UpdateLocalPeerState(state)
-
-		return nil
 	}
+
+	m.mux.Lock()
+	defer m.mux.Unlock()
+
+	routerPair, err := routeToRouterPair(route)
+	if err != nil {
+		return fmt.Errorf("parse prefix: %w", err)
+	}
+
+	err = m.firewall.RemoveNatRule(routerPair)
+	if err != nil {
+		return fmt.Errorf("remove routing rules: %w", err)
+	}
+
+	delete(m.routes, route.ID)
+
+	state := m.statusRecorder.GetLocalPeerState()
+	delete(state.Routes, route.Network.String())
+	m.statusRecorder.UpdateLocalPeerState(state)
+
+	return nil
 }
 
-func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
-	select {
-	case <-m.ctx.Done():
+func (m *serverRouter) addToServerNetwork(route *route.Route) error {
+	if m.ctx.Err() != nil {
 		log.Infof("Not adding to server network because context is done")
 		return m.ctx.Err()
-	default:
-		m.mux.Lock()
-		defer m.mux.Unlock()
-
-		routerPair, err := routeToRouterPair(route)
-		if err != nil {
-			return fmt.Errorf("parse prefix: %w", err)
-		}
-
-		err = m.firewall.AddNatRule(routerPair)
-		if err != nil {
-			return fmt.Errorf("insert routing rules: %w", err)
-		}
-
-		m.routes[route.ID] = route
-
-		state := m.statusRecorder.GetLocalPeerState()
-		if state.Routes == nil {
-			state.Routes = map[string]struct{}{}
-		}
-
-		routeStr := route.Network.String()
-		if route.IsDynamic() {
-			routeStr = route.Domains.SafeString()
-		}
-		state.Routes[routeStr] = struct{}{}
-
-		m.statusRecorder.UpdateLocalPeerState(state)
-
-		return nil
 	}
+
+	m.mux.Lock()
+	defer m.mux.Unlock()
+
+	routerPair, err := routeToRouterPair(route)
+	if err != nil {
+		return fmt.Errorf("parse prefix: %w", err)
+	}
+
+	err = m.firewall.AddNatRule(routerPair)
+	if err != nil {
+		return fmt.Errorf("insert routing rules: %w", err)
+	}
+
+	m.routes[route.ID] = route
+
+	state := m.statusRecorder.GetLocalPeerState()
+	if state.Routes == nil {
+		state.Routes = map[string]struct{}{}
+	}
+
+	routeStr := route.Network.String()
+	if route.IsDynamic() {
+		routeStr = route.Domains.SafeString()
+	}
+	state.Routes[routeStr] = struct{}{}
+
+	m.statusRecorder.UpdateLocalPeerState(state)
+
+	return nil
 }
 
-func (m *defaultServerRouter) cleanUp() {
+func (m *serverRouter) cleanUp() {
 	m.mux.Lock()
 	defer m.mux.Unlock()
 	for _, r := range m.routes {
19 client/internal/routemanager/state.go Normal file
@@ -0,0 +1,19 @@
+package routemanager
+
+import (
+	"github.com/netbirdio/netbird/client/internal/routeselector"
+)
+
+type SelectorState routeselector.RouteSelector
+
+func (s *SelectorState) Name() string {
+	return "routeselector_state"
+}
+
+func (s *SelectorState) MarshalJSON() ([]byte, error) {
+	return (*routeselector.RouteSelector)(s).MarshalJSON()
+}
+
+func (s *SelectorState) UnmarshalJSON(data []byte) error {
+	return (*routeselector.RouteSelector)(s).UnmarshalJSON(data)
+}
@@ -2,31 +2,28 @@ package systemops
 
 import (
 	"net/netip"
-	"sync"
 
 	"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
 )
 
-type ShutdownState struct {
-	Counter *ExclusionCounter `json:"counter,omitempty"`
-	mu      sync.RWMutex
-}
+type ShutdownState ExclusionCounter
 
 func (s *ShutdownState) Name() string {
 	return "route_state"
 }
 
 func (s *ShutdownState) Cleanup() error {
-	s.mu.RLock()
-	defer s.mu.RUnlock()
-
-	if s.Counter == nil {
-		return nil
-	}
-
 	sysops := NewSysOps(nil, nil)
 	sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable)
-	sysops.refCounter.LoadData(s.Counter)
+	sysops.refCounter.LoadData((*ExclusionCounter)(s))
 
 	return sysops.refCounter.Flush()
 }
+
+func (s *ShutdownState) MarshalJSON() ([]byte, error) {
+	return (*ExclusionCounter)(s).MarshalJSON()
+}
+
+func (s *ShutdownState) UnmarshalJSON(data []byte) error {
+	return (*ExclusionCounter)(s).UnmarshalJSON(data)
+}
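The ShutdownState rework above relies on Go's defined-type conversion: methods are declared on a named type and delegate by converting the receiver back to the underlying type. Below is a minimal standalone sketch of that pattern; Counter and State are illustrative names, not the real ExclusionCounter or ShutdownState.

package main

import (
	"encoding/json"
	"fmt"
)

// Counter stands in for the underlying counter type; the real
// ExclusionCounter lives in the refcounter package.
type Counter struct {
	Items map[string]int `json:"items"`
}

// State is a defined type over Counter, mirroring
// `type ShutdownState ExclusionCounter` from the diff.
type State Counter

func (s *State) Name() string { return "route_state" }

// MarshalJSON delegates to the underlying type via a pointer conversion,
// so the named type adds behaviour without changing the wire format.
func (s *State) MarshalJSON() ([]byte, error) {
	return json.Marshal((*Counter)(s))
}

func (s *State) UnmarshalJSON(data []byte) error {
	return json.Unmarshal(data, (*Counter)(s))
}

func main() {
	st := &State{Items: map[string]int{"10.0.0.0/8": 1}}
	out, err := json.Marshal(st)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(out)) // {"items":{"10.0.0.0/8":1}}
}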
@@ -17,6 +17,7 @@ import (
 
 	nberrors "github.com/netbirdio/netbird/client/errors"
 	"github.com/netbirdio/netbird/client/iface"
+	"github.com/netbirdio/netbird/client/iface/netstack"
 	"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
 	"github.com/netbirdio/netbird/client/internal/routemanager/util"
 	"github.com/netbirdio/netbird/client/internal/routemanager/vars"
@@ -57,30 +58,30 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
 				return nexthop, refcounter.ErrIgnore
 			}
 
-			r.updateState(stateManager)
-
 			return nexthop, err
 		},
-		func(prefix netip.Prefix, nexthop Nexthop) error {
-			// remove from state even if we have trouble removing it from the route table
-			// it could be already gone
-			r.updateState(stateManager)
-
-			return r.removeFromRouteTable(prefix, nexthop)
-		},
+		r.removeFromRouteTable,
 	)
 
+	if netstack.IsEnabled() {
+		refCounter = refcounter.New(
+			func(netip.Prefix, struct{}) (Nexthop, error) {
+				return Nexthop{}, refcounter.ErrIgnore
+			},
+			func(netip.Prefix, Nexthop) error {
+				return nil
+			},
+		)
+	}
+
 	r.refCounter = refCounter
 
-	return r.setupHooks(initAddresses)
+	return r.setupHooks(initAddresses, stateManager)
 }
 
+// updateState updates state on every change so it will be persisted regularly
 func (r *SysOps) updateState(stateManager *statemanager.Manager) {
-	state := getState(stateManager)
-
-	state.Counter = r.refCounter
-
-	if err := stateManager.UpdateState(state); err != nil {
+	if err := stateManager.UpdateState((*ShutdownState)(r.refCounter)); err != nil {
 		log.Errorf("failed to update state: %v", err)
 	}
 }
@@ -336,7 +337,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
 	return r.removeFromRouteTable(prefix, nextHop)
 }
 
-func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
+func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
 	beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
 		prefix, err := util.GetPrefixFromIP(ip)
 		if err != nil {
@@ -347,6 +348,8 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re
 			return fmt.Errorf("adding route reference: %v", err)
 		}
 
+		r.updateState(stateManager)
+
 		return nil
 	}
 	afterHook := func(connID nbnet.ConnectionID) error {
@@ -354,6 +357,8 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re
 			return fmt.Errorf("remove route reference: %w", err)
 		}
 
+		r.updateState(stateManager)
+
 		return nil
 	}
 
@@ -532,14 +537,3 @@ func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.P
 	// Return true if the longest matching prefix is from vpnRoutes
 	return isVpn, longestPrefix
 }
-
-func getState(stateManager *statemanager.Manager) *ShutdownState {
-	var shutdownState *ShutdownState
-	if state := stateManager.GetState(shutdownState); state != nil {
-		shutdownState = state.(*ShutdownState)
-	} else {
-		shutdownState = &ShutdownState{}
-	}
-
-	return shutdownState
-}
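The setupRefCounter changes above hang add/remove callbacks on a reference-counted route table and swap in no-op callbacks when netstack mode is enabled. The following rough, self-contained sketch illustrates only the underlying idea — callback on first reference, callback on last release — and does not match the real refcounter API.

package main

import "fmt"

// refCounter is a minimal illustration of a reference-counted resource map
// with add/remove callbacks.
type refCounter[K comparable] struct {
	counts map[K]int
	add    func(K) error
	remove func(K) error
}

func newRefCounter[K comparable](add, remove func(K) error) *refCounter[K] {
	return &refCounter[K]{counts: map[K]int{}, add: add, remove: remove}
}

// Increment calls the add callback only on the first reference.
func (r *refCounter[K]) Increment(key K) error {
	r.counts[key]++
	if r.counts[key] == 1 {
		return r.add(key)
	}
	return nil
}

// Decrement calls the remove callback when the last reference is dropped.
func (r *refCounter[K]) Decrement(key K) error {
	if r.counts[key] == 0 {
		return nil
	}
	r.counts[key]--
	if r.counts[key] == 0 {
		return r.remove(key)
	}
	return nil
}

func main() {
	rc := newRefCounter(
		func(p string) error { fmt.Println("add route", p); return nil },
		func(p string) error { fmt.Println("remove route", p); return nil },
	)
	_ = rc.Increment("10.0.0.0/8")
	_ = rc.Increment("10.0.0.0/8") // no second add
	_ = rc.Decrement("10.0.0.0/8")
	_ = rc.Decrement("10.0.0.0/8") // remove fires on the last release
}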
@@ -55,7 +55,7 @@ type ruleParams {
 
 // isLegacy determines whether to use the legacy routing setup
 func isLegacy() bool {
-	return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled()
+	return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || nbnet.SkipSocketMark()
 }
 
 // setIsLegacy sets the legacy routing setup
@@ -92,17 +92,6 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
 		return r.setupRefCounter(initAddresses, stateManager)
 	}
 
-	if err = addRoutingTableName(); err != nil {
-		log.Errorf("Error adding routing table name: %v", err)
-	}
-
-	originalValues, err := sysctl.Setup(r.wgInterface)
-	if err != nil {
-		log.Errorf("Error setting up sysctl: %v", err)
-		sysctlFailed = true
-	}
-	originalSysctl = originalValues
-
 	defer func() {
 		if err != nil {
 			if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil {
@@ -123,6 +112,17 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
 		}
 	}
 
+	if err = addRoutingTableName(); err != nil {
+		log.Errorf("Error adding routing table name: %v", err)
+	}
+
+	originalValues, err := sysctl.Setup(r.wgInterface)
+	if err != nil {
+		log.Errorf("Error setting up sysctl: %v", err)
+		sysctlFailed = true
+	}
+	originalSysctl = originalValues
+
 	return nil, nil, nil
 }
 
@@ -266,7 +266,7 @@ func addRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
 		return fmt.Errorf("add gateway and device: %w", err)
 	}
 
-	if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
+	if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !isOpErr(err) {
 		return fmt.Errorf("netlink add route: %w", err)
 	}
 
@@ -289,7 +289,7 @@ func addUnreachableRoute(prefix netip.Prefix, tableID int) error {
 		Dst:   ipNet,
 	}
 
-	if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
+	if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !isOpErr(err) {
 		return fmt.Errorf("netlink add unreachable route: %w", err)
 	}
 
@@ -312,7 +312,7 @@ func removeUnreachableRoute(prefix netip.Prefix, tableID int) error {
 	if err := netlink.RouteDel(route); err != nil &&
 		!errors.Is(err, syscall.ESRCH) &&
 		!errors.Is(err, syscall.ENOENT) &&
-		!errors.Is(err, syscall.EAFNOSUPPORT) {
+		!isOpErr(err) {
 		return fmt.Errorf("netlink remove unreachable route: %w", err)
 	}
 
@@ -338,7 +338,7 @@ func removeRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
 		return fmt.Errorf("add gateway and device: %w", err)
 	}
 
-	if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) {
+	if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !isOpErr(err) {
 		return fmt.Errorf("netlink remove route: %w", err)
 	}
 
@@ -362,7 +362,7 @@ func flushRoutes(tableID, family int) error {
 				routes[i].Dst = &net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)}
 			}
 		}
-		if err := netlink.RouteDel(&routes[i]); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
+		if err := netlink.RouteDel(&routes[i]); err != nil && !isOpErr(err) {
 			result = multierror.Append(result, fmt.Errorf("failed to delete route %v from table %d: %w", routes[i], tableID, err))
 		}
 	}
@@ -450,7 +450,7 @@ func addRule(params ruleParams) error {
 	rule.Invert = params.invert
 	rule.SuppressPrefixlen = params.suppressPrefix
 
-	if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
+	if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !isOpErr(err) {
 		return fmt.Errorf("add routing rule: %w", err)
 	}
 
@@ -467,7 +467,7 @@ func removeRule(params ruleParams) error {
 	rule.Priority = params.priority
 	rule.SuppressPrefixlen = params.suppressPrefix
 
-	if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EAFNOSUPPORT) {
+	if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !isOpErr(err) {
 		return fmt.Errorf("remove routing rule: %w", err)
 	}
 
@@ -509,3 +509,13 @@ func hasSeparateRouting() ([]netip.Prefix, error) {
 	}
 	return nil, ErrRoutingIsSeparate
 }
+
+func isOpErr(err error) bool {
+	// EAFTNOSUPPORT when ipv6 is disabled via sysctl, EOPNOTSUPP when disabled in boot options or otherwise not supported
+	if errors.Is(err, syscall.EAFNOSUPPORT) || errors.Is(err, syscall.EOPNOTSUPP) {
+		log.Debugf("route operation not supported: %v", err)
+		return true
+	}
+
+	return false
+}
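The new isOpErr helper above folds the repeated EAFNOSUPPORT checks into one place and additionally treats EOPNOTSUPP as non-fatal. A small standalone sketch of the same errors.Is pattern against wrapped syscall errnos, assuming nothing beyond the standard library:

package main

import (
	"errors"
	"fmt"
	"syscall"
)

// isUnsupported mirrors the isOpErr idea: treat "address family not
// supported" and "operation not supported" as non-fatal conditions.
func isUnsupported(err error) bool {
	return errors.Is(err, syscall.EAFNOSUPPORT) || errors.Is(err, syscall.EOPNOTSUPP)
}

func main() {
	// errors.Is unwraps fmt.Errorf("%w", ...) chains, so wrapped errnos still match.
	err := fmt.Errorf("netlink add route: %w", syscall.EAFNOSUPPORT)
	fmt.Println(isUnsupported(err))                      // true
	fmt.Println(isUnsupported(errors.New("other fail"))) // false
}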
@@ -230,10 +230,13 @@ func (rm *RouteMonitor) parseUpdate(row *MIB_IPFORWARD_ROW2, notificationType MI
 	if idx != 0 {
 		intf, err := net.InterfaceByIndex(idx)
 		if err != nil {
-			return update, fmt.Errorf("get interface name: %w", err)
+			log.Warnf("failed to get interface name for index %d: %v", idx, err)
+			update.Interface = &net.Interface{
+				Index: idx,
+			}
+		} else {
+			update.Interface = intf
 		}
-
-		update.Interface = intf
 	}
 
 	log.Tracef("Received route update with destination %v, next hop %v, interface %v", row.DestinationPrefix, row.NextHop, update.Interface)
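The Windows route-monitor change above stops failing the whole update when an interface index cannot be resolved and instead records a bare interface carrying only the index. A compact sketch of that fallback using only the standard library; routeUpdate is an illustrative stand-in for the monitor's update type:

package main

import (
	"fmt"
	"log"
	"net"
)

// routeUpdate is an illustrative stand-in for the monitor's update type.
type routeUpdate struct {
	Interface *net.Interface
}

// interfaceForIndex resolves an interface by index but falls back to a
// bare Interface with only the index set when the lookup fails.
func interfaceForIndex(idx int) *net.Interface {
	intf, err := net.InterfaceByIndex(idx)
	if err != nil {
		log.Printf("failed to get interface name for index %d: %v", idx, err)
		return &net.Interface{Index: idx}
	}
	return intf
}

func main() {
	u := routeUpdate{Interface: interfaceForIndex(1)}
	fmt.Println(u.Interface.Index, u.Interface.Name)
}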
@@ -1,8 +1,10 @@
 package routeselector
 
 import (
+	"encoding/json"
 	"fmt"
 	"slices"
+	"sync"
 
 	"github.com/hashicorp/go-multierror"
 	"golang.org/x/exp/maps"
@@ -12,6 +14,7 @@ import (
 )
 
 type RouteSelector struct {
+	mu             sync.RWMutex
 	selectedRoutes map[route.NetID]struct{}
 	selectAll      bool
 }
@@ -26,6 +29,9 @@ func NewRouteSelector() *RouteSelector {
 
 // SelectRoutes updates the selected routes based on the provided route IDs.
 func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, allRoutes []route.NetID) error {
+	rs.mu.Lock()
+	defer rs.mu.Unlock()
+
 	if !appendRoute {
 		rs.selectedRoutes = map[route.NetID]struct{}{}
 	}
@@ -46,6 +52,9 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
 
 // SelectAllRoutes sets the selector to select all routes.
 func (rs *RouteSelector) SelectAllRoutes() {
+	rs.mu.Lock()
+	defer rs.mu.Unlock()
+
 	rs.selectAll = true
 	rs.selectedRoutes = map[route.NetID]struct{}{}
 }
@@ -53,6 +62,9 @@ func (rs *RouteSelector) SelectAllRoutes() {
 // DeselectRoutes removes specific routes from the selection.
 // If the selector is in "select all" mode, it will transition to "select specific" mode.
 func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error {
+	rs.mu.Lock()
+	defer rs.mu.Unlock()
+
 	if rs.selectAll {
 		rs.selectAll = false
 		rs.selectedRoutes = map[route.NetID]struct{}{}
@@ -76,12 +88,18 @@ func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.
 
 // DeselectAllRoutes deselects all routes, effectively disabling route selection.
 func (rs *RouteSelector) DeselectAllRoutes() {
+	rs.mu.Lock()
+	defer rs.mu.Unlock()
+
 	rs.selectAll = false
 	rs.selectedRoutes = map[route.NetID]struct{}{}
 }
 
 // IsSelected checks if a specific route is selected.
 func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
+	rs.mu.RLock()
+	defer rs.mu.RUnlock()
+
 	if rs.selectAll {
 		return true
 	}
@@ -91,6 +109,9 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
 
 // FilterSelected removes unselected routes from the provided map.
 func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
+	rs.mu.RLock()
+	defer rs.mu.RUnlock()
+
 	if rs.selectAll {
 		return maps.Clone(routes)
 	}
@@ -103,3 +124,49 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
 	}
 	return filtered
 }
+
+// MarshalJSON implements the json.Marshaler interface
+func (rs *RouteSelector) MarshalJSON() ([]byte, error) {
+	rs.mu.RLock()
+	defer rs.mu.RUnlock()
+
+	return json.Marshal(struct {
+		SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
+		SelectAll      bool                      `json:"select_all"`
+	}{
+		SelectAll:      rs.selectAll,
+		SelectedRoutes: rs.selectedRoutes,
+	})
+}
+
+// UnmarshalJSON implements the json.Unmarshaler interface
+// If the JSON is empty or null, it will initialize like a NewRouteSelector.
+func (rs *RouteSelector) UnmarshalJSON(data []byte) error {
+	rs.mu.Lock()
+	defer rs.mu.Unlock()
+
+	// Check for null or empty JSON
+	if len(data) == 0 || string(data) == "null" {
+		rs.selectedRoutes = map[route.NetID]struct{}{}
+		rs.selectAll = true
+		return nil
+	}
+
+	var temp struct {
+		SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
+		SelectAll      bool                      `json:"select_all"`
+	}
+
+	if err := json.Unmarshal(data, &temp); err != nil {
+		return err
+	}
+
+	rs.selectedRoutes = temp.SelectedRoutes
+	rs.selectAll = temp.SelectAll
+
+	if rs.selectedRoutes == nil {
+		rs.selectedRoutes = map[route.NetID]struct{}{}
+	}
+
+	return nil
+}
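The MarshalJSON/UnmarshalJSON pair added above persists the selector through an anonymous temporary struct and treats empty or null input as the select-all default. Below is a stripped-down, self-contained sketch of that round trip; the selector type here is a stand-in, not the real RouteSelector.

package main

import (
	"encoding/json"
	"fmt"
	"sync"
)

// selector mirrors the persist/restore shape used above.
type selector struct {
	mu       sync.RWMutex
	selected map[string]struct{}
	all      bool
}

func (s *selector) MarshalJSON() ([]byte, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return json.Marshal(struct {
		SelectedRoutes map[string]struct{} `json:"selected_routes"`
		SelectAll      bool                `json:"select_all"`
	}{SelectedRoutes: s.selected, SelectAll: s.all})
}

func (s *selector) UnmarshalJSON(data []byte) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	// Empty or null input falls back to the "select all" default.
	if len(data) == 0 || string(data) == "null" {
		s.selected = map[string]struct{}{}
		s.all = true
		return nil
	}
	var tmp struct {
		SelectedRoutes map[string]struct{} `json:"selected_routes"`
		SelectAll      bool                `json:"select_all"`
	}
	if err := json.Unmarshal(data, &tmp); err != nil {
		return err
	}
	s.selected, s.all = tmp.SelectedRoutes, tmp.SelectAll
	if s.selected == nil {
		s.selected = map[string]struct{}{}
	}
	return nil
}

func main() {
	a := &selector{selected: map[string]struct{}{"route1": {}}}
	data, _ := json.Marshal(a)

	var b selector
	_ = json.Unmarshal(data, &b)
	fmt.Println(string(data), b.all, len(b.selected)) // serialized form, false, 1
}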
@@ -273,3 +273,88 @@ func TestRouteSelector_FilterSelected(t *testing.T) {
 		"route2|192.168.0.0/16": {},
 	}, filtered)
 }
+
+func TestRouteSelector_NewRoutesBehavior(t *testing.T) {
+	initialRoutes := []route.NetID{"route1", "route2", "route3"}
+	newRoutes := []route.NetID{"route1", "route2", "route3", "route4", "route5"}
+
+	tests := []struct {
+		name            string
+		initialState    func(rs *routeselector.RouteSelector) error // Setup initial state
+		wantNewSelected []route.NetID                               // Expected selected routes after new routes appear
+	}{
+		{
+			name: "New routes with initial selectAll state",
+			initialState: func(rs *routeselector.RouteSelector) error {
+				rs.SelectAllRoutes()
+				return nil
+			},
+			// When selectAll is true, all routes including new ones should be selected
+			wantNewSelected: []route.NetID{"route1", "route2", "route3", "route4", "route5"},
+		},
+		{
+			name: "New routes after specific selection",
+			initialState: func(rs *routeselector.RouteSelector) error {
+				return rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, initialRoutes)
+			},
+			// When specific routes were selected, new routes should remain unselected
+			wantNewSelected: []route.NetID{"route1", "route2"},
+		},
+		{
+			name: "New routes after deselect all",
+			initialState: func(rs *routeselector.RouteSelector) error {
+				rs.DeselectAllRoutes()
+				return nil
+			},
+			// After deselect all, new routes should remain unselected
+			wantNewSelected: []route.NetID{},
+		},
+		{
+			name: "New routes after deselecting specific routes",
+			initialState: func(rs *routeselector.RouteSelector) error {
+				rs.SelectAllRoutes()
+				return rs.DeselectRoutes([]route.NetID{"route1"}, initialRoutes)
+			},
+			// After deselecting specific routes, new routes should remain unselected
+			wantNewSelected: []route.NetID{"route2", "route3"},
+		},
+		{
+			name: "New routes after selecting with append",
+			initialState: func(rs *routeselector.RouteSelector) error {
+				return rs.SelectRoutes([]route.NetID{"route1"}, true, initialRoutes)
+			},
+			// When routes were appended, new routes should remain unselected
+			wantNewSelected: []route.NetID{"route1"},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			rs := routeselector.NewRouteSelector()
+
+			// Setup initial state
+			err := tt.initialState(rs)
+			require.NoError(t, err)
+
+			// Verify selection state with new routes
+			for _, id := range newRoutes {
+				assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantNewSelected, id),
+					"Route %s selection state incorrect", id)
+			}
+
+			// Additional verification using FilterSelected
+			routes := route.HAMap{
+				"route1|10.0.0.0/8":     {},
+				"route2|192.168.0.0/16": {},
+				"route3|172.16.0.0/12":  {},
+				"route4|10.10.0.0/16":   {},
+				"route5|192.168.1.0/24": {},
+			}
+
+			filtered := rs.FilterSelected(routes)
+			expectedLen := len(tt.wantNewSelected)
+			assert.Equal(t, expectedLen, len(filtered),
+				"FilterSelected returned wrong number of routes, got %d want %d", len(filtered), expectedLen)
+		})
+	}
+}
@ -16,14 +16,39 @@ import (
|
|||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
errStateNotRegistered = "state %s not registered"
|
||||||
|
errLoadStateFile = "load state file: %w"
|
||||||
)
|
)
|
||||||
|
|
||||||
// State interface defines the methods that all state types must implement
|
// State interface defines the methods that all state types must implement
|
||||||
type State interface {
|
type State interface {
|
||||||
Name() string
|
Name() string
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanableState interface extends State with cleanup capability
|
||||||
|
type CleanableState interface {
|
||||||
|
State
|
||||||
Cleanup() error
|
Cleanup() error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RawState wraps raw JSON data for unregistered states
|
||||||
|
type RawState struct {
|
||||||
|
data json.RawMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RawState) Name() string {
|
||||||
|
return "" // This is a placeholder implementation
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements json.Marshaler to preserve the original JSON
|
||||||
|
func (r *RawState) MarshalJSON() ([]byte, error) {
|
||||||
|
return r.data, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Manager handles the persistence and management of various states
|
// Manager handles the persistence and management of various states
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
@ -73,15 +98,15 @@ func (m *Manager) Stop(ctx context.Context) error {
|
|||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
if m.cancel != nil {
|
if m.cancel == nil {
|
||||||
m.cancel()
|
return nil
|
||||||
|
}
|
||||||
|
m.cancel()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
case <-m.done:
|
case <-m.done:
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -139,7 +164,7 @@ func (m *Manager) setState(name string, state State) error {
|
|||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
if _, exists := m.states[name]; !exists {
|
if _, exists := m.states[name]; !exists {
|
||||||
return fmt.Errorf("state %s not registered", name)
|
return fmt.Errorf(errStateNotRegistered, name)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.states[name] = state
|
m.states[name] = state
|
||||||
@ -148,6 +173,63 @@ func (m *Manager) setState(name string, state State) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteStateByName handles deletion of states without cleanup.
|
||||||
|
// It doesn't require the state to be registered.
|
||||||
|
func (m *Manager) DeleteStateByName(stateName string) error {
|
||||||
|
if m == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
rawStates, err := m.loadStateFile(false)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf(errLoadStateFile, err)
|
||||||
|
}
|
||||||
|
if rawStates == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, exists := rawStates[stateName]; !exists {
|
||||||
|
return fmt.Errorf("state %s not found", stateName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark state as deleted by setting it to nil and marking it dirty
|
||||||
|
m.states[stateName] = nil
|
||||||
|
m.dirty[stateName] = struct{}{}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteAllStates removes all states.
|
||||||
|
func (m *Manager) DeleteAllStates() (int, error) {
|
||||||
|
if m == nil {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
rawStates, err := m.loadStateFile(false)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf(errLoadStateFile, err)
|
||||||
|
}
|
||||||
|
if rawStates == nil {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
count := len(rawStates)
|
||||||
|
|
||||||
|
// Mark all states as deleted and dirty
|
||||||
|
for name := range rawStates {
|
||||||
|
m.states[name] = nil
|
||||||
|
m.dirty[name] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return count, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) periodicStateSave(ctx context.Context) {
|
func (m *Manager) periodicStateSave(ctx context.Context) {
|
||||||
ticker := time.NewTicker(10 * time.Second)
|
ticker := time.NewTicker(10 * time.Second)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
@ -178,25 +260,18 @@ func (m *Manager) PersistState(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
bs, err := marshalWithPanicRecovery(m.states)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("marshal states: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
done := make(chan error, 1)
|
done := make(chan error, 1)
|
||||||
|
start := time.Now()
|
||||||
go func() {
|
go func() {
|
||||||
data, err := json.MarshalIndent(m.states, "", " ")
|
done <- util.WriteBytesWithRestrictedPermission(ctx, m.filePath, bs)
|
||||||
if err != nil {
|
|
||||||
done <- fmt.Errorf("marshal states: %w", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// nolint:gosec
|
|
||||||
if err := os.WriteFile(m.filePath, data, 0640); err != nil {
|
|
||||||
done <- fmt.Errorf("write state file: %w", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
done <- nil
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@ -208,63 +283,175 @@ func (m *Manager) PersistState(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("persisted shutdown states: %v", maps.Keys(m.dirty))
|
log.Debugf("persisted states: %v, took %v", maps.Keys(m.dirty), time.Since(start))
|
||||||
|
|
||||||
clear(m.dirty)
|
clear(m.dirty)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// loadState loads the existing state from the state file
|
// loadStateFile reads and unmarshals the state file into a map of raw JSON messages
|
||||||
func (m *Manager) loadState() error {
|
func (m *Manager) loadStateFile(deleteCorrupt bool) (map[string]json.RawMessage, error) {
|
||||||
data, err := os.ReadFile(m.filePath)
|
data, err := os.ReadFile(m.filePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, fs.ErrNotExist) {
|
if errors.Is(err, fs.ErrNotExist) {
|
||||||
log.Debug("state file does not exist")
|
log.Debug("state file does not exist")
|
||||||
return nil
|
return nil, nil // nolint:nilnil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("read state file: %w", err)
|
return nil, fmt.Errorf("read state file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var rawStates map[string]json.RawMessage
|
var rawStates map[string]json.RawMessage
|
||||||
if err := json.Unmarshal(data, &rawStates); err != nil {
|
if err := json.Unmarshal(data, &rawStates); err != nil {
|
||||||
log.Warn("State file appears to be corrupted, attempting to delete it")
|
if deleteCorrupt {
|
||||||
if err := os.Remove(m.filePath); err != nil {
|
log.Warn("State file appears to be corrupted, attempting to delete it", err)
|
||||||
log.Errorf("Failed to delete corrupted state file: %v", err)
|
if err := os.Remove(m.filePath); err != nil {
|
||||||
} else {
|
log.Errorf("Failed to delete corrupted state file: %v", err)
|
||||||
log.Info("State file deleted")
|
} else {
|
||||||
|
log.Info("State file deleted")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unmarshal states: %w", err)
|
return nil, fmt.Errorf("unmarshal states: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var merr *multierror.Error
|
return rawStates, nil
|
||||||
|
}
|
||||||
|
|
||||||
for name, rawState := range rawStates {
|
// loadSingleRawState unmarshals a raw state into a concrete state object
|
||||||
stateType, ok := m.stateTypes[name]
|
func (m *Manager) loadSingleRawState(name string, rawState json.RawMessage) (State, error) {
|
||||||
if !ok {
|
stateType, ok := m.stateTypes[name]
|
||||||
merr = multierror.Append(merr, fmt.Errorf("unknown state type: %s", name))
|
if !ok {
|
||||||
continue
|
return nil, fmt.Errorf(errStateNotRegistered, name)
|
||||||
}
|
}
|
||||||
|
|
||||||
if string(rawState) == "null" {
|
if string(rawState) == "null" {
|
||||||
continue
|
return nil, nil //nolint:nilnil
|
||||||
}
|
}
|
||||||
|
|
||||||
statePtr := reflect.New(stateType).Interface().(State)
|
statePtr := reflect.New(stateType).Interface().(State)
|
||||||
if err := json.Unmarshal(rawState, statePtr); err != nil {
|
if err := json.Unmarshal(rawState, statePtr); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("unmarshal state %s: %w", name, err))
|
return nil, fmt.Errorf("unmarshal state %s: %w", name, err)
|
||||||
continue
|
}
|
||||||
}
|
|
||||||
|
|
||||||
m.states[name] = statePtr
|
return statePtr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadState loads a specific state from the state file
|
||||||
|
func (m *Manager) LoadState(state State) error {
|
||||||
|
if m == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
rawStates, err := m.loadStateFile(false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if rawStates == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
name := state.Name()
|
||||||
|
rawState, exists := rawStates[name]
|
||||||
|
if !exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
loadedState, err := m.loadSingleRawState(name, rawState)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
m.states[name] = loadedState
|
||||||
|
if loadedState != nil {
|
||||||
log.Debugf("loaded state: %s", name)
|
log.Debugf("loaded state: %s", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// PerformCleanup retrieves all states from the state file for the registered states and calls Cleanup on them.
|
// cleanupSingleState handles the cleanup of a specific state and returns any error.
|
||||||
// If the cleanup is successful, the state is marked for deletion.
|
// The caller must hold the mutex.
|
||||||
|
func (m *Manager) cleanupSingleState(name string, rawState json.RawMessage) error {
|
||||||
|
// For unregistered states, preserve the raw JSON
|
||||||
|
if _, registered := m.stateTypes[name]; !registered {
|
||||||
|
m.states[name] = &RawState{data: rawState}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load the state
|
||||||
|
loadedState, err := m.loadSingleRawState(name, rawState)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if loadedState == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if state supports cleanup
|
||||||
|
cleanableState, isCleanable := loadedState.(CleanableState)
|
||||||
|
if !isCleanable {
|
||||||
|
// If it doesn't support cleanup, keep it as-is
|
||||||
|
m.states[name] = loadedState
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform cleanup
|
||||||
|
log.Infof("cleaning up state %s", name)
|
||||||
|
if err := cleanableState.Cleanup(); err != nil {
|
||||||
|
// On cleanup error, preserve the state
|
||||||
|
m.states[name] = loadedState
|
||||||
|
return fmt.Errorf("cleanup state: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Successfully cleaned up - mark for deletion
|
||||||
|
m.states[name] = nil
|
||||||
|
m.dirty[name] = struct{}{}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanupStateByName loads and cleans up a specific state by name if it implements CleanableState.
|
||||||
|
// Returns an error if the state doesn't exist, isn't registered, or cleanup fails.
|
||||||
|
func (m *Manager) CleanupStateByName(name string) error {
|
||||||
|
if m == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
// Check if state is registered
|
||||||
|
if _, registered := m.stateTypes[name]; !registered {
|
||||||
|
return fmt.Errorf(errStateNotRegistered, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load raw states from file
|
||||||
|
rawStates, err := m.loadStateFile(false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if rawStates == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if state exists in file
|
||||||
|
rawState, exists := rawStates[name]
|
||||||
|
if !exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.cleanupSingleState(name, rawState); err != nil {
|
||||||
|
return fmt.Errorf("%s: %w", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PerformCleanup retrieves all states from the state file and calls Cleanup on registered states that support it.
|
||||||
|
// Unregistered states are preserved in their original state.
|
||||||
func (m *Manager) PerformCleanup() error {
|
func (m *Manager) PerformCleanup() error {
|
||||||
if m == nil {
|
if m == nil {
|
||||||
return nil
|
return nil
|
||||||
@ -273,26 +460,63 @@ func (m *Manager) PerformCleanup() error {
|
|||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
if err := m.loadState(); err != nil {
|
// Load raw states from file
|
||||||
log.Warnf("Failed to load state during cleanup: %v", err)
|
rawStates, err := m.loadStateFile(true)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf(errLoadStateFile, err)
|
||||||
|
}
|
||||||
|
if rawStates == nil {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
for name, state := range m.states {
|
|
||||||
if state == nil {
|
|
||||||
// If no state was found in the state file, we don't mark the state dirty nor return an error
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("client was not shut down properly, cleaning up %s", name)
|
// Process each state in the file
|
||||||
if err := state.Cleanup(); err != nil {
|
for name, rawState := range rawStates {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("cleanup state for %s: %w", name, err))
|
if err := m.cleanupSingleState(name, rawState); err != nil {
|
||||||
} else {
|
merr = multierror.Append(merr, fmt.Errorf("%s: %w", name, err))
|
||||||
// mark for deletion on cleanup success
|
|
||||||
m.states[name] = nil
|
|
||||||
m.dirty[name] = struct{}{}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetSavedStateNames returns all state names that are currently saved in the state file.
|
||||||
|
func (m *Manager) GetSavedStateNames() ([]string, error) {
|
||||||
|
if m == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rawStates, err := m.loadStateFile(false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf(errLoadStateFile, err)
|
||||||
|
}
|
||||||
|
if rawStates == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var states []string
|
||||||
|
for name, state := range rawStates {
|
||||||
|
if len(state) != 0 && string(state) != "null" {
|
||||||
|
states = append(states, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return states, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func marshalWithPanicRecovery(v any) ([]byte, error) {
|
||||||
|
var bs []byte
|
||||||
|
var err error
|
||||||
|
|
||||||
|
func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = fmt.Errorf("panic during marshal: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
bs, err = json.Marshal(v)
|
||||||
|
}()
|
||||||
|
|
||||||
|
return bs, err
|
||||||
|
}
|
||||||
|
@ -1,35 +1,16 @@
|
|||||||
package statemanager
|
package statemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/configs"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetDefaultStatePath returns the path to the state file based on the operating system
|
// GetDefaultStatePath returns the path to the state file based on the operating system
|
||||||
// It returns an empty string if the path cannot be determined. It also creates the directory if it does not exist.
|
// It returns an empty string if the path cannot be determined.
|
||||||
func GetDefaultStatePath() string {
|
func GetDefaultStatePath() string {
|
||||||
var path string
|
if path := os.Getenv("NB_DNS_STATE_FILE"); path != "" {
|
||||||
|
return path
|
||||||
switch runtime.GOOS {
|
|
||||||
case "windows":
|
|
||||||
path = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json")
|
|
||||||
case "darwin", "linux":
|
|
||||||
path = "/var/lib/netbird/state.json"
|
|
||||||
case "freebsd", "openbsd", "netbsd", "dragonfly":
|
|
||||||
path = "/var/db/netbird/state.json"
|
|
||||||
// ios/android don't need state
|
|
||||||
default:
|
|
||||||
return ""
|
|
||||||
}
|
}
|
||||||
|
return filepath.Join(configs.StateDir, "state.json")
|
||||||
dir := filepath.Dir(path)
|
|
||||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
|
||||||
log.Errorf("Error creating directory %s: %v. Continuing without state support.", dir, err)
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
return path
|
|
||||||
}
|
}
|
||||||
|
@ -59,6 +59,7 @@ func init() {
|
|||||||
// Client struct manage the life circle of background service
|
// Client struct manage the life circle of background service
|
||||||
type Client struct {
|
type Client struct {
|
||||||
cfgFile string
|
cfgFile string
|
||||||
|
stateFile string
|
||||||
recorder *peer.Status
|
recorder *peer.Status
|
||||||
ctxCancel context.CancelFunc
|
ctxCancel context.CancelFunc
|
||||||
ctxCancelLock *sync.Mutex
|
ctxCancelLock *sync.Mutex
|
||||||
@ -73,9 +74,10 @@ type Client struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewClient instantiate a new Client
|
// NewClient instantiate a new Client
|
||||||
func NewClient(cfgFile, deviceName string, osVersion string, osName string, networkChangeListener NetworkChangeListener, dnsManager DnsManager) *Client {
|
func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName string, networkChangeListener NetworkChangeListener, dnsManager DnsManager) *Client {
|
||||||
return &Client{
|
return &Client{
|
||||||
cfgFile: cfgFile,
|
cfgFile: cfgFile,
|
||||||
|
stateFile: stateFile,
|
||||||
deviceName: deviceName,
|
deviceName: deviceName,
|
||||||
osName: osName,
|
osName: osName,
|
||||||
osVersion: osVersion,
|
osVersion: osVersion,
|
||||||
@ -91,7 +93,8 @@ func (c *Client) Run(fd int32, interfaceName string) error {
|
|||||||
log.Infof("Starting NetBird client")
|
log.Infof("Starting NetBird client")
|
||||||
log.Debugf("Tunnel uses interface: %s", interfaceName)
|
log.Debugf("Tunnel uses interface: %s", interfaceName)
|
||||||
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
||||||
ConfigPath: c.cfgFile,
|
ConfigPath: c.cfgFile,
|
||||||
|
StateFilePath: c.stateFile,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -124,7 +127,7 @@ func (c *Client) Run(fd int32, interfaceName string) error {
|
|||||||
cfg.WgIface = interfaceName
|
cfg.WgIface = interfaceName
|
||||||
|
|
||||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||||
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager)
|
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop the internal client and free the resources
|
// Stop the internal client and free the resources
|
||||||
@ -269,8 +272,8 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
|
|||||||
return nil, fmt.Errorf("not connected")
|
return nil, fmt.Errorf("not connected")
|
||||||
}
|
}
|
||||||
|
|
||||||
routesMap := engine.GetClientRoutesWithNetID()
|
|
||||||
routeManager := engine.GetRouteManager()
|
routeManager := engine.GetRouteManager()
|
||||||
|
routesMap := routeManager.GetClientRoutesWithNetID()
|
||||||
if routeManager == nil {
|
if routeManager == nil {
|
||||||
return nil, fmt.Errorf("could not get route manager")
|
return nil, fmt.Errorf("could not get route manager")
|
||||||
}
|
}
|
||||||
@ -314,7 +317,7 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain][]netip.Prefix) *RoutesSelectionDetails {
|
func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) *RoutesSelectionDetails {
|
||||||
var routeSelection []RoutesSelectionInfo
|
var routeSelection []RoutesSelectionInfo
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
domainList := make([]DomainInfo, 0)
|
domainList := make([]DomainInfo, 0)
|
||||||
@ -322,9 +325,10 @@ func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[dom
|
|||||||
domainResp := DomainInfo{
|
domainResp := DomainInfo{
|
||||||
Domain: d.SafeString(),
|
Domain: d.SafeString(),
|
||||||
}
|
}
|
||||||
if prefixes, exists := resolvedDomains[d]; exists {
|
|
||||||
|
if info, exists := resolvedDomains[d]; exists {
|
||||||
var ipStrings []string
|
var ipStrings []string
|
||||||
for _, prefix := range prefixes {
|
for _, prefix := range info.Prefixes {
|
||||||
ipStrings = append(ipStrings, prefix.Addr().String())
|
ipStrings = append(ipStrings, prefix.Addr().String())
|
||||||
}
|
}
|
||||||
domainResp.ResolvedIPs = strings.Join(ipStrings, ", ")
|
domainResp.ResolvedIPs = strings.Join(ipStrings, ", ")
|
||||||
@ -362,12 +366,12 @@ func (c *Client) SelectRoute(id string) error {
|
|||||||
} else {
|
} else {
|
||||||
log.Debugf("select route with id: %s", id)
|
log.Debugf("select route with id: %s", id)
|
||||||
routes := toNetIDs([]string{id})
|
routes := toNetIDs([]string{id})
|
||||||
if err := routeSelector.SelectRoutes(routes, true, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil {
|
if err := routeSelector.SelectRoutes(routes, true, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil {
|
||||||
log.Debugf("error when selecting routes: %s", err)
|
log.Debugf("error when selecting routes: %s", err)
|
||||||
return fmt.Errorf("select routes: %w", err)
|
return fmt.Errorf("select routes: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
routeManager.TriggerSelection(engine.GetClientRoutes())
|
routeManager.TriggerSelection(routeManager.GetClientRoutes())
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
}
|
}
|
||||||
@ -389,12 +393,12 @@ func (c *Client) DeselectRoute(id string) error {
|
|||||||
} else {
|
} else {
|
||||||
log.Debugf("deselect route with id: %s", id)
|
log.Debugf("deselect route with id: %s", id)
|
||||||
routes := toNetIDs([]string{id})
|
routes := toNetIDs([]string{id})
|
||||||
if err := routeSelector.DeselectRoutes(routes, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil {
|
if err := routeSelector.DeselectRoutes(routes, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil {
|
||||||
log.Debugf("error when deselecting routes: %s", err)
|
log.Debugf("error when deselecting routes: %s", err)
|
||||||
return fmt.Errorf("deselect routes: %w", err)
|
return fmt.Errorf("deselect routes: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
routeManager.TriggerSelection(engine.GetClientRoutes())
|
routeManager.TriggerSelection(routeManager.GetClientRoutes())
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -10,9 +10,10 @@ type Preferences struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewPreferences create new Preferences instance
|
// NewPreferences create new Preferences instance
|
||||||
func NewPreferences(configPath string) *Preferences {
|
func NewPreferences(configPath string, stateFilePath string) *Preferences {
|
||||||
ci := internal.ConfigInput{
|
ci := internal.ConfigInput{
|
||||||
ConfigPath: configPath,
|
ConfigPath: configPath,
|
||||||
|
StateFilePath: stateFilePath,
|
||||||
}
|
}
|
||||||
return &Preferences{ci}
|
return &Preferences{ci}
|
||||||
}
|
}
|
||||||
|
@@ -9,7 +9,8 @@ import (
 
 func TestPreferences_DefaultValues(t *testing.T) {
 	cfgFile := filepath.Join(t.TempDir(), "netbird.json")
-	p := NewPreferences(cfgFile)
+	stateFile := filepath.Join(t.TempDir(), "state.json")
+	p := NewPreferences(cfgFile, stateFile)
 	defaultVar, err := p.GetAdminURL()
 	if err != nil {
 		t.Fatalf("failed to read default value: %s", err)
@@ -42,7 +43,8 @@ func TestPreferences_DefaultValues(t *testing.T) {
 func TestPreferences_ReadUncommitedValues(t *testing.T) {
 	exampleString := "exampleString"
 	cfgFile := filepath.Join(t.TempDir(), "netbird.json")
-	p := NewPreferences(cfgFile)
+	stateFile := filepath.Join(t.TempDir(), "state.json")
+	p := NewPreferences(cfgFile, stateFile)
 
 	p.SetAdminURL(exampleString)
 	resp, err := p.GetAdminURL()
@@ -79,7 +81,8 @@ func TestPreferences_Commit(t *testing.T) {
 	exampleURL := "https://myurl.com:443"
 	examplePresharedKey := "topsecret"
 	cfgFile := filepath.Join(t.TempDir(), "netbird.json")
-	p := NewPreferences(cfgFile)
+	stateFile := filepath.Join(t.TempDir(), "state.json")
+	p := NewPreferences(cfgFile, stateFile)
 
 	p.SetAdminURL(exampleURL)
 	p.SetManagementURL(exampleURL)
@@ -90,7 +93,7 @@ func TestPreferences_Commit(t *testing.T) {
 		t.Fatalf("failed to save changes: %s", err)
 	}
 
-	p = NewPreferences(cfgFile)
+	p = NewPreferences(cfgFile, stateFile)
 	resp, err := p.GetAdminURL()
 	if err != nil {
 		t.Fatalf("failed to read admin url: %s", err)
File diff suppressed because it is too large
@@ -28,14 +28,14 @@ service DaemonService {
   // GetConfig of the daemon.
   rpc GetConfig(GetConfigRequest) returns (GetConfigResponse) {}
 
-  // List available network routes
-  rpc ListRoutes(ListRoutesRequest) returns (ListRoutesResponse) {}
+  // List available networks
+  rpc ListNetworks(ListNetworksRequest) returns (ListNetworksResponse) {}
 
   // Select specific routes
-  rpc SelectRoutes(SelectRoutesRequest) returns (SelectRoutesResponse) {}
+  rpc SelectNetworks(SelectNetworksRequest) returns (SelectNetworksResponse) {}
 
   // Deselect specific routes
-  rpc DeselectRoutes(SelectRoutesRequest) returns (SelectRoutesResponse) {}
+  rpc DeselectNetworks(SelectNetworksRequest) returns (SelectNetworksResponse) {}
 
   // DebugBundle creates a debug bundle
   rpc DebugBundle(DebugBundleRequest) returns (DebugBundleResponse) {}
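With ListRoutes/SelectRoutes/DeselectRoutes renamed to their *Networks counterparts, daemon callers move to the new stubs. A sketch of listing and selecting networks over the daemon socket; the generated-package import path and the socket address are assumptions, while the method and message names follow from the service definition above:

package main

import (
	"context"
	"log"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"

	"github.com/netbirdio/netbird/client/proto" // assumed location of the generated daemon stubs
)

func main() {
	// Dial the daemon over its local socket (address is an assumption).
	conn, err := grpc.Dial("unix:///var/run/netbird.sock", grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		log.Fatalf("dial daemon: %v", err)
	}
	defer conn.Close()

	client := proto.NewDaemonServiceClient(conn)

	// ListNetworks replaces the former ListRoutes call.
	resp, err := client.ListNetworks(context.Background(), &proto.ListNetworksRequest{})
	if err != nil {
		log.Fatalf("list networks: %v", err)
	}
	for _, n := range resp.Routes { // the response field is still named "routes" in the proto above
		log.Printf("network %s selected=%v", n.ID, n.Selected)
	}

	// SelectNetworks replaces SelectRoutes; All selects every advertised network.
	if _, err := client.SelectNetworks(context.Background(), &proto.SelectNetworksRequest{All: true}); err != nil {
		log.Fatalf("select networks: %v", err)
	}
}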
@@ -45,7 +45,20 @@ service DaemonService {
 
   // SetLogLevel sets the log level of the daemon
   rpc SetLogLevel(SetLogLevelRequest) returns (SetLogLevelResponse) {}
-};
+
+  // List all states
+  rpc ListStates(ListStatesRequest) returns (ListStatesResponse) {}
+
+  // Clean specific state or all states
+  rpc CleanState(CleanStateRequest) returns (CleanStateResponse) {}
+
+  // Delete specific state or all states
+  rpc DeleteState(DeleteStateRequest) returns (DeleteStateResponse) {}
+
+  // SetNetworkMapPersistence enables or disables network map persistence
+  rpc SetNetworkMapPersistence(SetNetworkMapPersistenceRequest) returns (SetNetworkMapPersistenceResponse) {}
+}
+
 
 message LoginRequest {
   // setupKey wiretrustee setup key.
@@ -94,6 +107,11 @@ message LoginRequest {
   optional bool networkMonitor = 18;
 
   optional google.protobuf.Duration dnsRouteInterval = 19;
+
+  optional bool disable_client_routes = 20;
+  optional bool disable_server_routes = 21;
+  optional bool disable_dns = 22;
+  optional bool disable_firewall = 23;
 }
 
 message LoginResponse {
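The four new optional booleans let a login request turn off client routes, server routes, DNS, and firewall handling independently. Since proto3 optional scalars are generated as pointer fields in Go, a small helper keeps the construction readable; a sketch, again assuming the generated package location used above:

package main

import (
	"fmt"

	"github.com/netbirdio/netbird/client/proto" // assumed location of the generated daemon stubs
)

// boolPtr is a tiny helper; proto3 "optional bool" fields are generated as *bool.
func boolPtr(b bool) *bool { return &b }

func main() {
	// Sketch: a login request that opts out of client routes and DNS handling,
	// while leaving server routes and the firewall enabled.
	req := &proto.LoginRequest{
		DisableClientRoutes: boolPtr(true),
		DisableServerRoutes: boolPtr(false),
		DisableDns:          boolPtr(true),
		DisableFirewall:     boolPtr(false),
	}
	fmt.Println(req.GetDisableClientRoutes(), req.GetDisableDns())
}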
@@ -177,7 +195,7 @@ message PeerState {
   int64 bytesRx = 13;
   int64 bytesTx = 14;
   bool rosenpassEnabled = 15;
-  repeated string routes = 16;
+  repeated string networks = 16;
   google.protobuf.Duration latency = 17;
   string relayAddress = 18;
 }
@@ -190,7 +208,7 @@ message LocalPeerState {
   string fqdn = 4;
   bool rosenpassEnabled = 5;
   bool rosenpassPermissive = 6;
-  repeated string routes = 7;
+  repeated string networks = 7;
 }
 
 // SignalState contains the latest state of a signal connection
@@ -231,20 +249,20 @@ message FullStatus {
   repeated NSGroupState dns_servers = 6;
 }
 
-message ListRoutesRequest {
+message ListNetworksRequest {
 }
 
-message ListRoutesResponse {
-  repeated Route routes = 1;
+message ListNetworksResponse {
+  repeated Network routes = 1;
 }
 
-message SelectRoutesRequest {
-  repeated string routeIDs = 1;
+message SelectNetworksRequest {
+  repeated string networkIDs = 1;
   bool append = 2;
   bool all = 3;
 }
 
-message SelectRoutesResponse {
+message SelectNetworksResponse {
 }
 
 message IPList {
@@ -252,9 +270,9 @@ message IPList {
 }
 
 
-message Route {
+message Network {
   string ID = 1;
-  string network = 2;
+  string range = 2;
   bool selected = 3;
   repeated string domains = 4;
   map<string, IPList> resolvedIPs = 5;
@@ -293,4 +311,46 @@ message SetLogLevelRequest {
 }
 
 message SetLogLevelResponse {
 }
+
+// State represents a daemon state entry
+message State {
+  string name = 1;
+}
+
+// ListStatesRequest is empty as it requires no parameters
+message ListStatesRequest {}
+
+// ListStatesResponse contains a list of states
+message ListStatesResponse {
+  repeated State states = 1;
+}
+
+// CleanStateRequest for cleaning states
+message CleanStateRequest {
+  string state_name = 1;
+  bool all = 2;
+}
+
+// CleanStateResponse contains the result of the clean operation
+message CleanStateResponse {
+  int32 cleaned_states = 1;
+}
+
+// DeleteStateRequest for deleting states
+message DeleteStateRequest {
+  string state_name = 1;
+  bool all = 2;
+}
+
+// DeleteStateResponse contains the result of the delete operation
+message DeleteStateResponse {
+  int32 deleted_states = 1;
+}
+
+
+message SetNetworkMapPersistenceRequest {
+  bool enabled = 1;
+}
+
+message SetNetworkMapPersistenceResponse {}
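Together with the service methods added earlier, these messages give clients a small state-management API: list the named states the daemon keeps, clean one or all of them, or delete them outright. A sketch of the call pattern, reusing the client, context, and imports from the earlier listing sketch; the state name here is hypothetical:

func cleanOneState(ctx context.Context, client proto.DaemonServiceClient) error {
	// Enumerate the states the daemon currently tracks.
	states, err := client.ListStates(ctx, &proto.ListStatesRequest{})
	if err != nil {
		return fmt.Errorf("list states: %w", err)
	}
	for _, s := range states.States {
		log.Printf("state: %s", s.Name)
	}

	// Clean a single state by name; CleanStateRequest{All: true} would clean every state,
	// and DeleteState takes the same request shape for full removal.
	cleaned, err := client.CleanState(ctx, &proto.CleanStateRequest{StateName: "dns_state"}) // hypothetical name
	if err != nil {
		return fmt.Errorf("clean state: %w", err)
	}
	log.Printf("cleaned %d state(s)", cleaned.CleanedStates)
	return nil
}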
@@ -31,18 +31,26 @@ type DaemonServiceClient interface {
 	Down(ctx context.Context, in *DownRequest, opts ...grpc.CallOption) (*DownResponse, error)
 	// GetConfig of the daemon.
 	GetConfig(ctx context.Context, in *GetConfigRequest, opts ...grpc.CallOption) (*GetConfigResponse, error)
-	// List available network routes
-	ListRoutes(ctx context.Context, in *ListRoutesRequest, opts ...grpc.CallOption) (*ListRoutesResponse, error)
+	// List available networks
+	ListNetworks(ctx context.Context, in *ListNetworksRequest, opts ...grpc.CallOption) (*ListNetworksResponse, error)
 	// Select specific routes
-	SelectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error)
+	SelectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error)
 	// Deselect specific routes
-	DeselectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error)
+	DeselectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error)
 	// DebugBundle creates a debug bundle
 	DebugBundle(ctx context.Context, in *DebugBundleRequest, opts ...grpc.CallOption) (*DebugBundleResponse, error)
 	// GetLogLevel gets the log level of the daemon
 	GetLogLevel(ctx context.Context, in *GetLogLevelRequest, opts ...grpc.CallOption) (*GetLogLevelResponse, error)
 	// SetLogLevel sets the log level of the daemon
 	SetLogLevel(ctx context.Context, in *SetLogLevelRequest, opts ...grpc.CallOption) (*SetLogLevelResponse, error)
+	// List all states
+	ListStates(ctx context.Context, in *ListStatesRequest, opts ...grpc.CallOption) (*ListStatesResponse, error)
+	// Clean specific state or all states
+	CleanState(ctx context.Context, in *CleanStateRequest, opts ...grpc.CallOption) (*CleanStateResponse, error)
+	// Delete specific state or all states
+	DeleteState(ctx context.Context, in *DeleteStateRequest, opts ...grpc.CallOption) (*DeleteStateResponse, error)
+	// SetNetworkMapPersistence enables or disables network map persistence
+	SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error)
 }
 
 type daemonServiceClient struct {
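Because the client interface both renames methods and gains new ones, hand-written wrappers or mocks stop compiling until they are updated. Embedding the generated interface and keeping a blank-identifier assertion makes that explicit; a sketch assuming the same generated package as in the earlier examples:

// wrappedClient decorates the generated client (for logging, retries, etc.).
// Embedding the interface keeps it satisfying DaemonServiceClient as methods
// are added; a hand-rolled mock without the embed fails to compile until
// ListNetworks, ListStates, and the other new methods are implemented.
type wrappedClient struct {
	proto.DaemonServiceClient
}

// Compile-time assertion that the wrapper still implements the full interface.
var _ proto.DaemonServiceClient = (*wrappedClient)(nil)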
@@ -107,27 +115,27 @@ func (c *daemonServiceClient) GetConfig(ctx context.Context, in *GetConfigReques
 	return out, nil
 }
 
-func (c *daemonServiceClient) ListRoutes(ctx context.Context, in *ListRoutesRequest, opts ...grpc.CallOption) (*ListRoutesResponse, error) {
-	out := new(ListRoutesResponse)
-	err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListRoutes", in, out, opts...)
+func (c *daemonServiceClient) ListNetworks(ctx context.Context, in *ListNetworksRequest, opts ...grpc.CallOption) (*ListNetworksResponse, error) {
+	out := new(ListNetworksResponse)
+	err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListNetworks", in, out, opts...)
 	if err != nil {
 		return nil, err
 	}
 	return out, nil
 }
 
-func (c *daemonServiceClient) SelectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error) {
-	out := new(SelectRoutesResponse)
-	err := c.cc.Invoke(ctx, "/daemon.DaemonService/SelectRoutes", in, out, opts...)
+func (c *daemonServiceClient) SelectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error) {
+	out := new(SelectNetworksResponse)
+	err := c.cc.Invoke(ctx, "/daemon.DaemonService/SelectNetworks", in, out, opts...)
 	if err != nil {
 		return nil, err
 	}
 	return out, nil
 }
 
-func (c *daemonServiceClient) DeselectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error) {
-	out := new(SelectRoutesResponse)
-	err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeselectRoutes", in, out, opts...)
+func (c *daemonServiceClient) DeselectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error) {
+	out := new(SelectNetworksResponse)
+	err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeselectNetworks", in, out, opts...)
 	if err != nil {
 		return nil, err
 	}
@@ -161,6 +169,42 @@ func (c *daemonServiceClient) SetLogLevel(ctx context.Context, in *SetLogLevelRe
 	return out, nil
 }
 
+func (c *daemonServiceClient) ListStates(ctx context.Context, in *ListStatesRequest, opts ...grpc.CallOption) (*ListStatesResponse, error) {
+	out := new(ListStatesResponse)
+	err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListStates", in, out, opts...)
+	if err != nil {
+		return nil, err
+	}
+	return out, nil
+}
+
+func (c *daemonServiceClient) CleanState(ctx context.Context, in *CleanStateRequest, opts ...grpc.CallOption) (*CleanStateResponse, error) {
+	out := new(CleanStateResponse)
+	err := c.cc.Invoke(ctx, "/daemon.DaemonService/CleanState", in, out, opts...)
+	if err != nil {
+		return nil, err
+	}
+	return out, nil
+}
+
+func (c *daemonServiceClient) DeleteState(ctx context.Context, in *DeleteStateRequest, opts ...grpc.CallOption) (*DeleteStateResponse, error) {
+	out := new(DeleteStateResponse)
+	err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeleteState", in, out, opts...)
+	if err != nil {
+		return nil, err
+	}
+	return out, nil
+}
+
+func (c *daemonServiceClient) SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error) {
+	out := new(SetNetworkMapPersistenceResponse)
+	err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetNetworkMapPersistence", in, out, opts...)
+	if err != nil {
+		return nil, err
+	}
+	return out, nil
+}
+
 // DaemonServiceServer is the server API for DaemonService service.
 // All implementations must embed UnimplementedDaemonServiceServer
 // for forward compatibility
@@ -178,18 +222,26 @@ type DaemonServiceServer interface {
 	Down(context.Context, *DownRequest) (*DownResponse, error)
 	// GetConfig of the daemon.
 	GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error)
-	// List available network routes
-	ListRoutes(context.Context, *ListRoutesRequest) (*ListRoutesResponse, error)
+	// List available networks
+	ListNetworks(context.Context, *ListNetworksRequest) (*ListNetworksResponse, error)
 	// Select specific routes
-	SelectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error)
+	SelectNetworks(context.Context, *SelectNetworksRequest) (*SelectNetworksResponse, error)
 	// Deselect specific routes
-	DeselectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error)
+	DeselectNetworks(context.Context, *SelectNetworksRequest) (*SelectNetworksResponse, error)
 	// DebugBundle creates a debug bundle
 	DebugBundle(context.Context, *DebugBundleRequest) (*DebugBundleResponse, error)
 	// GetLogLevel gets the log level of the daemon
 	GetLogLevel(context.Context, *GetLogLevelRequest) (*GetLogLevelResponse, error)
 	// SetLogLevel sets the log level of the daemon
 	SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error)
+	// List all states
+	ListStates(context.Context, *ListStatesRequest) (*ListStatesResponse, error)
+	// Clean specific state or all states
+	CleanState(context.Context, *CleanStateRequest) (*CleanStateResponse, error)
+	// Delete specific state or all states
+	DeleteState(context.Context, *DeleteStateRequest) (*DeleteStateResponse, error)
+	// SetNetworkMapPersistence enables or disables network map persistence
+	SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error)
 	mustEmbedUnimplementedDaemonServiceServer()
 }
 
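On the server side, implementations that embed UnimplementedDaemonServiceServer keep compiling as RPCs such as ListStates are added and only override what they actually serve. A minimal sketch under that assumption (context and the generated proto package are assumed imports):

// stateServer is a sketch of a partial DaemonService implementation.
type stateServer struct {
	proto.UnimplementedDaemonServiceServer // unimplemented RPCs answer codes.Unimplemented by default
}

// ListStates overrides just one of the newly added state RPCs.
func (s *stateServer) ListStates(ctx context.Context, _ *proto.ListStatesRequest) (*proto.ListStatesResponse, error) {
	return &proto.ListStatesResponse{
		States: []*proto.State{{Name: "example_state"}}, // illustrative entry
	}, nil
}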
@@ -215,14 +267,14 @@ func (UnimplementedDaemonServiceServer) Down(context.Context, *DownRequest) (*Do
 func (UnimplementedDaemonServiceServer) GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error) {
 	return nil, status.Errorf(codes.Unimplemented, "method GetConfig not implemented")
 }
-func (UnimplementedDaemonServiceServer) ListRoutes(context.Context, *ListRoutesRequest) (*ListRoutesResponse, error) {
-	return nil, status.Errorf(codes.Unimplemented, "method ListRoutes not implemented")
+func (UnimplementedDaemonServiceServer) ListNetworks(context.Context, *ListNetworksRequest) (*ListNetworksResponse, error) {
+	return nil, status.Errorf(codes.Unimplemented, "method ListNetworks not implemented")
 }
-func (UnimplementedDaemonServiceServer) SelectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error) {
-	return nil, status.Errorf(codes.Unimplemented, "method SelectRoutes not implemented")
+func (UnimplementedDaemonServiceServer) SelectNetworks(context.Context, *SelectNetworksRequest) (*SelectNetworksResponse, error) {
+	return nil, status.Errorf(codes.Unimplemented, "method SelectNetworks not implemented")
 }
-func (UnimplementedDaemonServiceServer) DeselectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error) {
-	return nil, status.Errorf(codes.Unimplemented, "method DeselectRoutes not implemented")
+func (UnimplementedDaemonServiceServer) DeselectNetworks(context.Context, *SelectNetworksRequest) (*SelectNetworksResponse, error) {
+	return nil, status.Errorf(codes.Unimplemented, "method DeselectNetworks not implemented")
 }
 func (UnimplementedDaemonServiceServer) DebugBundle(context.Context, *DebugBundleRequest) (*DebugBundleResponse, error) {
 	return nil, status.Errorf(codes.Unimplemented, "method DebugBundle not implemented")
@@ -233,6 +285,18 @@ func (UnimplementedDaemonServiceServer) GetLogLevel(context.Context, *GetLogLeve
 func (UnimplementedDaemonServiceServer) SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error) {
 	return nil, status.Errorf(codes.Unimplemented, "method SetLogLevel not implemented")
 }
+func (UnimplementedDaemonServiceServer) ListStates(context.Context, *ListStatesRequest) (*ListStatesResponse, error) {
+	return nil, status.Errorf(codes.Unimplemented, "method ListStates not implemented")
+}
+func (UnimplementedDaemonServiceServer) CleanState(context.Context, *CleanStateRequest) (*CleanStateResponse, error) {
+	return nil, status.Errorf(codes.Unimplemented, "method CleanState not implemented")
+}
+func (UnimplementedDaemonServiceServer) DeleteState(context.Context, *DeleteStateRequest) (*DeleteStateResponse, error) {
+	return nil, status.Errorf(codes.Unimplemented, "method DeleteState not implemented")
+}
+func (UnimplementedDaemonServiceServer) SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error) {
+	return nil, status.Errorf(codes.Unimplemented, "method SetNetworkMapPersistence not implemented")
+}
 func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
 
 // UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
@@ -354,56 +418,56 @@ func _DaemonService_GetConfig_Handler(srv interface{}, ctx context.Context, dec
 	return interceptor(ctx, in, info, handler)
 }
 
-func _DaemonService_ListRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
-	in := new(ListRoutesRequest)
+func _DaemonService_ListNetworks_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
+	in := new(ListNetworksRequest)
 	if err := dec(in); err != nil {
 		return nil, err
 	}
 	if interceptor == nil {
-		return srv.(DaemonServiceServer).ListRoutes(ctx, in)
+		return srv.(DaemonServiceServer).ListNetworks(ctx, in)
 	}
 	info := &grpc.UnaryServerInfo{
 		Server:     srv,
-		FullMethod: "/daemon.DaemonService/ListRoutes",
+		FullMethod: "/daemon.DaemonService/ListNetworks",
 	}
 	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
-		return srv.(DaemonServiceServer).ListRoutes(ctx, req.(*ListRoutesRequest))
+		return srv.(DaemonServiceServer).ListNetworks(ctx, req.(*ListNetworksRequest))
 	}
 	return interceptor(ctx, in, info, handler)
 }
 
-func _DaemonService_SelectRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
-	in := new(SelectRoutesRequest)
+func _DaemonService_SelectNetworks_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
+	in := new(SelectNetworksRequest)
 	if err := dec(in); err != nil {
 		return nil, err
 	}
 	if interceptor == nil {
-		return srv.(DaemonServiceServer).SelectRoutes(ctx, in)
+		return srv.(DaemonServiceServer).SelectNetworks(ctx, in)
 	}
 	info := &grpc.UnaryServerInfo{
 		Server:     srv,
-		FullMethod: "/daemon.DaemonService/SelectRoutes",
+		FullMethod: "/daemon.DaemonService/SelectNetworks",
 	}
 	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
-		return srv.(DaemonServiceServer).SelectRoutes(ctx, req.(*SelectRoutesRequest))
+		return srv.(DaemonServiceServer).SelectNetworks(ctx, req.(*SelectNetworksRequest))
 	}
 	return interceptor(ctx, in, info, handler)
 }
 
-func _DaemonService_DeselectRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
-	in := new(SelectRoutesRequest)
+func _DaemonService_DeselectNetworks_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
+	in := new(SelectNetworksRequest)
 	if err := dec(in); err != nil {
 		return nil, err
 	}
 	if interceptor == nil {
-		return srv.(DaemonServiceServer).DeselectRoutes(ctx, in)
+		return srv.(DaemonServiceServer).DeselectNetworks(ctx, in)
 	}
 	info := &grpc.UnaryServerInfo{
 		Server:     srv,
-		FullMethod: "/daemon.DaemonService/DeselectRoutes",
+		FullMethod: "/daemon.DaemonService/DeselectNetworks",
 	}
 	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
-		return srv.(DaemonServiceServer).DeselectRoutes(ctx, req.(*SelectRoutesRequest))
+		return srv.(DaemonServiceServer).DeselectNetworks(ctx, req.(*SelectNetworksRequest))
 	}
 	return interceptor(ctx, in, info, handler)
 }
@@ -462,6 +526,78 @@ func _DaemonService_SetLogLevel_Handler(srv interface{}, ctx context.Context, de
 	return interceptor(ctx, in, info, handler)
 }
 
+func _DaemonService_ListStates_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
+	in := new(ListStatesRequest)
+	if err := dec(in); err != nil {
+		return nil, err
+	}
+	if interceptor == nil {
+		return srv.(DaemonServiceServer).ListStates(ctx, in)
+	}
+	info := &grpc.UnaryServerInfo{
+		Server:     srv,
+		FullMethod: "/daemon.DaemonService/ListStates",
+	}
+	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
+		return srv.(DaemonServiceServer).ListStates(ctx, req.(*ListStatesRequest))
+	}
+	return interceptor(ctx, in, info, handler)
+}
+
+func _DaemonService_CleanState_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
+	in := new(CleanStateRequest)
+	if err := dec(in); err != nil {
+		return nil, err
+	}
+	if interceptor == nil {
+		return srv.(DaemonServiceServer).CleanState(ctx, in)
+	}
+	info := &grpc.UnaryServerInfo{
+		Server:     srv,
+		FullMethod: "/daemon.DaemonService/CleanState",
+	}
+	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
+		return srv.(DaemonServiceServer).CleanState(ctx, req.(*CleanStateRequest))
+	}
+	return interceptor(ctx, in, info, handler)
+}
+
+func _DaemonService_DeleteState_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
+	in := new(DeleteStateRequest)
+	if err := dec(in); err != nil {
+		return nil, err
+	}
+	if interceptor == nil {
+		return srv.(DaemonServiceServer).DeleteState(ctx, in)
+	}
+	info := &grpc.UnaryServerInfo{
+		Server:     srv,
+		FullMethod: "/daemon.DaemonService/DeleteState",
+	}
+	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
+		return srv.(DaemonServiceServer).DeleteState(ctx, req.(*DeleteStateRequest))
+	}
+	return interceptor(ctx, in, info, handler)
+}
+
+func _DaemonService_SetNetworkMapPersistence_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
+	in := new(SetNetworkMapPersistenceRequest)
+	if err := dec(in); err != nil {
+		return nil, err
+	}
+	if interceptor == nil {
+		return srv.(DaemonServiceServer).SetNetworkMapPersistence(ctx, in)
+	}
+	info := &grpc.UnaryServerInfo{
+		Server:     srv,
+		FullMethod: "/daemon.DaemonService/SetNetworkMapPersistence",
+	}
+	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
+		return srv.(DaemonServiceServer).SetNetworkMapPersistence(ctx, req.(*SetNetworkMapPersistenceRequest))
+	}
+	return interceptor(ctx, in, info, handler)
+}
+
 // DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
 // It's only intended for direct use with grpc.RegisterService,
 // and not to be introspected or modified (even as a copy)
@@ -494,16 +630,16 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
 			Handler:    _DaemonService_GetConfig_Handler,
 		},
 		{
-			MethodName: "ListRoutes",
-			Handler:    _DaemonService_ListRoutes_Handler,
+			MethodName: "ListNetworks",
+			Handler:    _DaemonService_ListNetworks_Handler,
 		},
 		{
-			MethodName: "SelectRoutes",
-			Handler:    _DaemonService_SelectRoutes_Handler,
+			MethodName: "SelectNetworks",
+			Handler:    _DaemonService_SelectNetworks_Handler,
 		},
 		{
-			MethodName: "DeselectRoutes",
-			Handler:    _DaemonService_DeselectRoutes_Handler,
+			MethodName: "DeselectNetworks",
+			Handler:    _DaemonService_DeselectNetworks_Handler,
 		},
 		{
 			MethodName: "DebugBundle",
@@ -517,6 +653,22 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
 			MethodName: "SetLogLevel",
 			Handler:    _DaemonService_SetLogLevel_Handler,
 		},
+		{
+			MethodName: "ListStates",
+			Handler:    _DaemonService_ListStates_Handler,
+		},
+		{
+			MethodName: "CleanState",
+			Handler:    _DaemonService_CleanState_Handler,
+		},
+		{
+			MethodName: "DeleteState",
+			Handler:    _DaemonService_DeleteState_Handler,
+		},
+		{
+			MethodName: "SetNetworkMapPersistence",
+			Handler:    _DaemonService_SetNetworkMapPersistence_Handler,
+		},
 	},
 	Streams:  []grpc.StreamDesc{},
 	Metadata: "daemon.proto",
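The extended ServiceDesc is consumed through the generated registration helper rather than used directly. A sketch of serving DaemonService on a unix socket; the socket path and import path are illustrative, while the registration call is the standard protoc-gen-go-grpc output:

package main

import (
	"context"
	"log"
	"net"

	"google.golang.org/grpc"

	"github.com/netbirdio/netbird/client/proto" // assumed location of the generated daemon stubs
)

type demoDaemon struct {
	proto.UnimplementedDaemonServiceServer
}

// SetNetworkMapPersistence overrides one of the new RPCs as an example.
func (d *demoDaemon) SetNetworkMapPersistence(ctx context.Context, req *proto.SetNetworkMapPersistenceRequest) (*proto.SetNetworkMapPersistenceResponse, error) {
	log.Printf("network map persistence enabled: %v", req.Enabled)
	return &proto.SetNetworkMapPersistenceResponse{}, nil
}

func main() {
	lis, err := net.Listen("unix", "/tmp/netbird-demo.sock") // illustrative socket path
	if err != nil {
		log.Fatalf("listen: %v", err)
	}
	srv := grpc.NewServer()
	proto.RegisterDaemonServiceServer(srv, &demoDaemon{}) // wires up DaemonService_ServiceDesc
	if err := srv.Serve(lis); err != nil {
		log.Fatalf("serve: %v", err)
	}
}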
Some files were not shown because too many files have changed in this diff