mirror of
https://github.com/netbirdio/netbird.git
synced 2025-03-03 17:31:13 +01:00
Merge branch 'main' into routes-get-account-refactoring
# Conflicts: # .github/workflows/golang-test-linux.yml # go.mod # go.sum # management/server/account.go # management/server/account_test.go # management/server/http/handler.go # management/server/http/handlers/peers/peers_handler.go # management/server/http/middleware/auth_middleware.go # management/server/http/middleware/auth_middleware_test.go # management/server/http/testing/benchmarks/peers_handler_benchmark_test.go # management/server/http/testing/benchmarks/users_handler_benchmark_test.go # management/server/integrated_validator.go # management/server/mock_server/account_mock.go # management/server/peer.go # management/server/status/error.go # management/server/store/sql_store.go # management/server/store/sql_store_test.go # management/server/user.go
This commit is contained in:
commit
e326cc7412
9
.github/workflows/golang-test-darwin.yml
vendored
9
.github/workflows/golang-test-darwin.yml
vendored
@ -1,4 +1,4 @@
|
|||||||
name: Test Code Darwin
|
name: "Darwin"
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
@ -12,9 +12,7 @@ concurrency:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test:
|
test:
|
||||||
strategy:
|
name: "Client / Unit"
|
||||||
matrix:
|
|
||||||
store: ['sqlite']
|
|
||||||
runs-on: macos-latest
|
runs-on: macos-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
@ -44,4 +42,5 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)
|
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)
|
||||||
|
|
||||||
|
8
.github/workflows/golang-test-freebsd.yml
vendored
8
.github/workflows/golang-test-freebsd.yml
vendored
@ -1,5 +1,4 @@
|
|||||||
|
name: "FreeBSD"
|
||||||
name: Test Code FreeBSD
|
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
@ -13,6 +12,7 @@ concurrency:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test:
|
test:
|
||||||
|
name: "Client / Unit"
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
@ -24,7 +24,7 @@ jobs:
|
|||||||
copyback: false
|
copyback: false
|
||||||
release: "14.1"
|
release: "14.1"
|
||||||
prepare: |
|
prepare: |
|
||||||
pkg install -y go
|
pkg install -y go pkgconf xorg
|
||||||
|
|
||||||
# -x - to print all executed commands
|
# -x - to print all executed commands
|
||||||
# -e - to faile on first error
|
# -e - to faile on first error
|
||||||
@ -33,7 +33,7 @@ jobs:
|
|||||||
time go build -o netbird client/main.go
|
time go build -o netbird client/main.go
|
||||||
# check all component except management, since we do not support management server on freebsd
|
# check all component except management, since we do not support management server on freebsd
|
||||||
time go test -timeout 1m -failfast ./base62/...
|
time go test -timeout 1m -failfast ./base62/...
|
||||||
# NOTE: without -p1 `client/internal/dns` will fail becasue of `listen udp4 :33100: bind: address already in use`
|
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
|
||||||
time go test -timeout 8m -failfast -p 1 ./client/...
|
time go test -timeout 8m -failfast -p 1 ./client/...
|
||||||
time go test -timeout 1m -failfast ./dns/...
|
time go test -timeout 1m -failfast ./dns/...
|
||||||
time go test -timeout 1m -failfast ./encryption/...
|
time go test -timeout 1m -failfast ./encryption/...
|
||||||
|
160
.github/workflows/golang-test-linux.yml
vendored
160
.github/workflows/golang-test-linux.yml
vendored
@ -1,4 +1,4 @@
|
|||||||
name: Test Code Linux
|
name: Linux
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
@ -12,11 +12,21 @@ concurrency:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build-cache:
|
build-cache:
|
||||||
|
name: "Build Cache"
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
|
outputs:
|
||||||
|
management: ${{ steps.filter.outputs.management }}
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- uses: dorny/paths-filter@v3
|
||||||
|
id: filter
|
||||||
|
with:
|
||||||
|
filters: |
|
||||||
|
management:
|
||||||
|
- 'management/**'
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
@ -39,7 +49,6 @@ jobs:
|
|||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||||
|
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
if: steps.cache.outputs.cache-hit != 'true'
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||||
@ -89,6 +98,7 @@ jobs:
|
|||||||
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
|
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
|
||||||
|
|
||||||
test:
|
test:
|
||||||
|
name: "Client / Unit"
|
||||||
needs: [build-cache]
|
needs: [build-cache]
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
@ -134,9 +144,116 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management)
|
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
|
||||||
|
|
||||||
|
test_relay:
|
||||||
|
name: "Relay / Unit"
|
||||||
|
needs: [build-cache]
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
arch: [ '386','amd64' ]
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: "1.23.x"
|
||||||
|
cache: false
|
||||||
|
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Get Go environment
|
||||||
|
run: |
|
||||||
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Cache Go modules
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
${{ env.cache }}
|
||||||
|
${{ env.modcache }}
|
||||||
|
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||||
|
|
||||||
|
- name: Install 32-bit libpcap
|
||||||
|
if: matrix.arch == '386'
|
||||||
|
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||||
|
|
||||||
|
- name: Install modules
|
||||||
|
run: go mod tidy
|
||||||
|
|
||||||
|
- name: check git status
|
||||||
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
|
- name: Test
|
||||||
|
run: |
|
||||||
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
|
go test \
|
||||||
|
-exec 'sudo' \
|
||||||
|
-timeout 10m ./signal/...
|
||||||
|
|
||||||
|
test_signal:
|
||||||
|
name: "Signal / Unit"
|
||||||
|
needs: [build-cache]
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
arch: [ '386','amd64' ]
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: "1.23.x"
|
||||||
|
cache: false
|
||||||
|
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Get Go environment
|
||||||
|
run: |
|
||||||
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Cache Go modules
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
${{ env.cache }}
|
||||||
|
${{ env.modcache }}
|
||||||
|
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||||
|
|
||||||
|
- name: Install 32-bit libpcap
|
||||||
|
if: matrix.arch == '386'
|
||||||
|
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||||
|
|
||||||
|
- name: Install modules
|
||||||
|
run: go mod tidy
|
||||||
|
|
||||||
|
- name: check git status
|
||||||
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
|
- name: Test
|
||||||
|
run: |
|
||||||
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
|
go test \
|
||||||
|
-exec 'sudo' \
|
||||||
|
-timeout 10m ./signal/...
|
||||||
|
|
||||||
test_management:
|
test_management:
|
||||||
|
name: "Management / Unit"
|
||||||
needs: [ build-cache ]
|
needs: [ build-cache ]
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
@ -194,10 +311,17 @@ jobs:
|
|||||||
run: docker pull mlsmaycon/warmed-mysql:8
|
run: docker pull mlsmaycon/warmed-mysql:8
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management)
|
run: |
|
||||||
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||||
|
go test -tags=devcert \
|
||||||
|
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
|
||||||
|
-timeout 10m ./management/...
|
||||||
|
|
||||||
benchmark:
|
benchmark:
|
||||||
|
name: "Management / Benchmark"
|
||||||
needs: [ build-cache ]
|
needs: [ build-cache ]
|
||||||
|
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
@ -254,10 +378,17 @@ jobs:
|
|||||||
run: docker pull mlsmaycon/warmed-mysql:8
|
run: docker pull mlsmaycon/warmed-mysql:8
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m ./...
|
run: |
|
||||||
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
||||||
|
go test -tags devcert -run=^$ -bench=. \
|
||||||
|
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||||
|
-timeout 20m ./...
|
||||||
|
|
||||||
api_benchmark:
|
api_benchmark:
|
||||||
|
name: "Management / Benchmark (API)"
|
||||||
needs: [ build-cache ]
|
needs: [ build-cache ]
|
||||||
|
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
@ -314,10 +445,19 @@ jobs:
|
|||||||
run: docker pull mlsmaycon/warmed-mysql:8
|
run: docker pull mlsmaycon/warmed-mysql:8
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -tags=benchmark -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 30m $(go list -tags=benchmark ./... | grep /management)
|
run: |
|
||||||
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
||||||
|
go test -tags=benchmark \
|
||||||
|
-run=^$ \
|
||||||
|
-bench=. \
|
||||||
|
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||||
|
-timeout 20m ./management/...
|
||||||
|
|
||||||
api_integration_test:
|
api_integration_test:
|
||||||
|
name: "Management / Integration"
|
||||||
needs: [ build-cache ]
|
needs: [ build-cache ]
|
||||||
|
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
@ -363,9 +503,15 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=integration -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 30m $(go list -tags=integration ./... | grep /management)
|
run: |
|
||||||
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
||||||
|
go test -tags=integration \
|
||||||
|
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||||
|
-timeout 10m ./management/...
|
||||||
|
|
||||||
test_client_on_docker:
|
test_client_on_docker:
|
||||||
|
name: "Client (Docker) / Unit"
|
||||||
needs: [ build-cache ]
|
needs: [ build-cache ]
|
||||||
runs-on: ubuntu-20.04
|
runs-on: ubuntu-20.04
|
||||||
steps:
|
steps:
|
||||||
|
5
.github/workflows/golang-test-windows.yml
vendored
5
.github/workflows/golang-test-windows.yml
vendored
@ -1,4 +1,4 @@
|
|||||||
name: Test Code Windows
|
name: "Windows"
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
@ -14,6 +14,7 @@ concurrency:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test:
|
test:
|
||||||
|
name: "Client / Unit"
|
||||||
runs-on: windows-latest
|
runs-on: windows-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
@ -65,7 +66,7 @@ jobs:
|
|||||||
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV
|
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV
|
||||||
|
|
||||||
- name: test
|
- name: test
|
||||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
|
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
|
||||||
- name: test output
|
- name: test output
|
||||||
if: ${{ always() }}
|
if: ${{ always() }}
|
||||||
run: Get-Content test-out.txt
|
run: Get-Content test-out.txt
|
||||||
|
11
.github/workflows/golangci-lint.yml
vendored
11
.github/workflows/golangci-lint.yml
vendored
@ -1,4 +1,4 @@
|
|||||||
name: golangci-lint
|
name: Lint
|
||||||
on: [pull_request]
|
on: [pull_request]
|
||||||
|
|
||||||
permissions:
|
permissions:
|
||||||
@ -27,7 +27,14 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
os: [macos-latest, windows-latest, ubuntu-latest]
|
os: [macos-latest, windows-latest, ubuntu-latest]
|
||||||
name: lint
|
include:
|
||||||
|
- os: macos-latest
|
||||||
|
display_name: Darwin
|
||||||
|
- os: windows-latest
|
||||||
|
display_name: Windows
|
||||||
|
- os: ubuntu-latest
|
||||||
|
display_name: Linux
|
||||||
|
name: ${{ matrix.display_name }}
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
timeout-minutes: 15
|
timeout-minutes: 15
|
||||||
steps:
|
steps:
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
name: Mobile build validation
|
name: Mobile
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
@ -12,6 +12,7 @@ concurrency:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
android_build:
|
android_build:
|
||||||
|
name: "Android / Build"
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
@ -47,6 +48,7 @@ jobs:
|
|||||||
CGO_ENABLED: 0
|
CGO_ENABLED: 0
|
||||||
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620
|
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620
|
||||||
ios_build:
|
ios_build:
|
||||||
|
name: "iOS / Build"
|
||||||
runs-on: macos-latest
|
runs-on: macos-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
|
4
.github/workflows/release.yml
vendored
4
.github/workflows/release.yml
vendored
@ -9,10 +9,10 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.0.17"
|
SIGN_PIPE_VER: "v0.0.18"
|
||||||
GORELEASER_VER: "v2.3.2"
|
GORELEASER_VER: "v2.3.2"
|
||||||
PRODUCT_NAME: "NetBird"
|
PRODUCT_NAME: "NetBird"
|
||||||
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
|
COPYRIGHT: "NetBird GmbH"
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||||
|
1
.gitignore
vendored
1
.gitignore
vendored
@ -29,3 +29,4 @@ infrastructure_files/setup.env
|
|||||||
infrastructure_files/setup-*.env
|
infrastructure_files/setup-*.env
|
||||||
.vscode
|
.vscode
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
vendor/
|
||||||
|
@ -103,7 +103,7 @@ linters:
|
|||||||
- predeclared # predeclared finds code that shadows one of Go's predeclared identifiers
|
- predeclared # predeclared finds code that shadows one of Go's predeclared identifiers
|
||||||
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
|
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
|
||||||
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
|
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
|
||||||
- thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
|
# - thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
|
||||||
- wastedassign # wastedassign finds wasted assignment statements
|
- wastedassign # wastedassign finds wasted assignment statements
|
||||||
issues:
|
issues:
|
||||||
# Maximum count of issues with the same text.
|
# Maximum count of issues with the same text.
|
||||||
|
@ -50,10 +50,12 @@ nfpms:
|
|||||||
- netbird-ui
|
- netbird-ui
|
||||||
formats:
|
formats:
|
||||||
- deb
|
- deb
|
||||||
|
scripts:
|
||||||
|
postinstall: "release_files/ui-post-install.sh"
|
||||||
contents:
|
contents:
|
||||||
- src: client/ui/netbird.desktop
|
- src: client/ui/netbird.desktop
|
||||||
dst: /usr/share/applications/netbird.desktop
|
dst: /usr/share/applications/netbird.desktop
|
||||||
- src: client/ui/netbird-systemtray-connected.png
|
- src: client/ui/netbird.png
|
||||||
dst: /usr/share/pixmaps/netbird.png
|
dst: /usr/share/pixmaps/netbird.png
|
||||||
dependencies:
|
dependencies:
|
||||||
- netbird
|
- netbird
|
||||||
@ -67,10 +69,12 @@ nfpms:
|
|||||||
- netbird-ui
|
- netbird-ui
|
||||||
formats:
|
formats:
|
||||||
- rpm
|
- rpm
|
||||||
|
scripts:
|
||||||
|
postinstall: "release_files/ui-post-install.sh"
|
||||||
contents:
|
contents:
|
||||||
- src: client/ui/netbird.desktop
|
- src: client/ui/netbird.desktop
|
||||||
dst: /usr/share/applications/netbird.desktop
|
dst: /usr/share/applications/netbird.desktop
|
||||||
- src: client/ui/netbird-systemtray-connected.png
|
- src: client/ui/netbird.png
|
||||||
dst: /usr/share/pixmaps/netbird.png
|
dst: /usr/share/pixmaps/netbird.png
|
||||||
dependencies:
|
dependencies:
|
||||||
- netbird
|
- netbird
|
||||||
|
2
AUTHORS
2
AUTHORS
@ -1,3 +1,3 @@
|
|||||||
Mikhail Bragin (https://github.com/braginini)
|
Mikhail Bragin (https://github.com/braginini)
|
||||||
Maycon Santos (https://github.com/mlsmaycon)
|
Maycon Santos (https://github.com/mlsmaycon)
|
||||||
Wiretrustee UG (haftungsbeschränkt)
|
NetBird GmbH
|
||||||
|
@ -3,10 +3,10 @@
|
|||||||
We are incredibly thankful for the contributions we receive from the community.
|
We are incredibly thankful for the contributions we receive from the community.
|
||||||
We require our external contributors to sign a Contributor License Agreement ("CLA") in
|
We require our external contributors to sign a Contributor License Agreement ("CLA") in
|
||||||
order to ensure that our projects remain licensed under Free and Open Source licenses such
|
order to ensure that our projects remain licensed under Free and Open Source licenses such
|
||||||
as BSD-3 while allowing Wiretrustee to build a sustainable business.
|
as BSD-3 while allowing NetBird to build a sustainable business.
|
||||||
|
|
||||||
Wiretrustee is committed to having a true Open Source Software ("OSS") license for
|
NetBird is committed to having a true Open Source Software ("OSS") license for
|
||||||
our software. A CLA enables Wiretrustee to safely commercialize our products
|
our software. A CLA enables NetBird to safely commercialize our products
|
||||||
while keeping a standard OSS license with all the rights that license grants to users: the
|
while keeping a standard OSS license with all the rights that license grants to users: the
|
||||||
ability to use the project in their own projects or businesses, to republish modified
|
ability to use the project in their own projects or businesses, to republish modified
|
||||||
source, or to completely fork the project.
|
source, or to completely fork the project.
|
||||||
@ -20,11 +20,11 @@ This is a human-readable summary of (and not a substitute for) the full agreemen
|
|||||||
This highlights only some of key terms of the CLA. It has no legal value and you should
|
This highlights only some of key terms of the CLA. It has no legal value and you should
|
||||||
carefully review all the terms of the actual CLA before agreeing.
|
carefully review all the terms of the actual CLA before agreeing.
|
||||||
|
|
||||||
<li>Grant of copyright license. You give Wiretrustee permission to use your copyrighted work
|
<li>Grant of copyright license. You give NetBird permission to use your copyrighted work
|
||||||
in commercial products.
|
in commercial products.
|
||||||
</li>
|
</li>
|
||||||
|
|
||||||
<li>Grant of patent license. If your contributed work uses a patent, you give Wiretrustee a
|
<li>Grant of patent license. If your contributed work uses a patent, you give NetBird a
|
||||||
license to use that patent including within commercial products. You also agree that you
|
license to use that patent including within commercial products. You also agree that you
|
||||||
have permission to grant this license.
|
have permission to grant this license.
|
||||||
</li>
|
</li>
|
||||||
@ -45,7 +45,7 @@ more.
|
|||||||
# Why require a CLA?
|
# Why require a CLA?
|
||||||
|
|
||||||
Agreeing to a CLA explicitly states that you are entitled to provide a contribution, that you cannot withdraw permission
|
Agreeing to a CLA explicitly states that you are entitled to provide a contribution, that you cannot withdraw permission
|
||||||
to use your contribution at a later date, and that Wiretrustee has permission to use your contribution in our commercial
|
to use your contribution at a later date, and that NetBird has permission to use your contribution in our commercial
|
||||||
products.
|
products.
|
||||||
|
|
||||||
This removes any ambiguities or uncertainties caused by not having a CLA and allows users and customers to confidently
|
This removes any ambiguities or uncertainties caused by not having a CLA and allows users and customers to confidently
|
||||||
@ -65,25 +65,25 @@ Follow the steps given by the bot to sign the CLA. This will require you to log
|
|||||||
information from your account) and to fill in a few additional details such as your name and email address. We will only
|
information from your account) and to fill in a few additional details such as your name and email address. We will only
|
||||||
use this information for CLA tracking; none of your submitted information will be used for marketing purposes.
|
use this information for CLA tracking; none of your submitted information will be used for marketing purposes.
|
||||||
|
|
||||||
You only have to sign the CLA once. Once you've signed the CLA, future contributions to any Wiretrustee project will not
|
You only have to sign the CLA once. Once you've signed the CLA, future contributions to any NetBird project will not
|
||||||
require you to sign again.
|
require you to sign again.
|
||||||
|
|
||||||
# Legal Terms and Agreement
|
# Legal Terms and Agreement
|
||||||
|
|
||||||
In order to clarify the intellectual property license granted with Contributions from any person or entity, Wiretrustee
|
In order to clarify the intellectual property license granted with Contributions from any person or entity, NetBird
|
||||||
UG (haftungsbeschränkt) ("Wiretrustee") must have a Contributor License Agreement ("CLA") on file that has been signed
|
GmbH ("NetBird") must have a Contributor License Agreement ("CLA") on file that has been signed
|
||||||
by each Contributor, indicating agreement to the license terms below. This license does not change your rights to use
|
by each Contributor, indicating agreement to the license terms below. This license does not change your rights to use
|
||||||
your own Contributions for any other purpose.
|
your own Contributions for any other purpose.
|
||||||
|
|
||||||
You accept and agree to the following terms and conditions for Your present and future Contributions submitted to
|
You accept and agree to the following terms and conditions for Your present and future Contributions submitted to
|
||||||
Wiretrustee. Except for the license granted herein to Wiretrustee and recipients of software distributed by Wiretrustee,
|
NetBird. Except for the license granted herein to NetBird and recipients of software distributed by NetBird,
|
||||||
You reserve all right, title, and interest in and to Your Contributions.
|
You reserve all right, title, and interest in and to Your Contributions.
|
||||||
|
|
||||||
1. Definitions.
|
1. Definitions.
|
||||||
|
|
||||||
```
|
```
|
||||||
"You" (or "Your") shall mean the copyright owner or legal entity authorized by the copyright owner
|
"You" (or "Your") shall mean the copyright owner or legal entity authorized by the copyright owner
|
||||||
that is making this Agreement with Wiretrustee. For legal entities, the entity making a Contribution and all other
|
that is making this Agreement with NetBird. For legal entities, the entity making a Contribution and all other
|
||||||
entities that control, are controlled by, or are under common control with that entity are considered
|
entities that control, are controlled by, or are under common control with that entity are considered
|
||||||
to be a single Contributor. For the purposes of this definition, "control" means (i) the power, direct or indirect,
|
to be a single Contributor. For the purposes of this definition, "control" means (i) the power, direct or indirect,
|
||||||
to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty
|
to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty
|
||||||
@ -91,23 +91,23 @@ You reserve all right, title, and interest in and to Your Contributions.
|
|||||||
```
|
```
|
||||||
```
|
```
|
||||||
"Contribution" shall mean any original work of authorship, including any modifications or additions to
|
"Contribution" shall mean any original work of authorship, including any modifications or additions to
|
||||||
an existing work, that is or previously has been intentionally submitted by You to Wiretrustee for inclusion in,
|
an existing work, that is or previously has been intentionally submitted by You to NetBird for inclusion in,
|
||||||
or documentation of, any of the products owned or managed by Wiretrustee (the "Work").
|
or documentation of, any of the products owned or managed by NetBird (the "Work").
|
||||||
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication
|
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication
|
||||||
sent to Wiretrustee or its representatives, including but not limited to communication on electronic mailing lists,
|
sent to NetBird or its representatives, including but not limited to communication on electronic mailing lists,
|
||||||
source code control systems, and issue tracking systems that are managed by, or on behalf of,
|
source code control systems, and issue tracking systems that are managed by, or on behalf of,
|
||||||
Wiretrustee for the purpose of discussing and improving the Work, but excluding communication that is conspicuously
|
NetBird for the purpose of discussing and improving the Work, but excluding communication that is conspicuously
|
||||||
marked or otherwise designated in writing by You as "Not a Contribution."
|
marked or otherwise designated in writing by You as "Not a Contribution."
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of this Agreement, You hereby grant to Wiretrustee
|
2. Grant of Copyright License. Subject to the terms and conditions of this Agreement, You hereby grant to NetBird
|
||||||
and to recipients of software distributed by Wiretrustee a perpetual, worldwide, non-exclusive, no-charge,
|
and to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge,
|
||||||
royalty-free, irrevocable copyright license to reproduce, prepare derivative works of, publicly display, publicly
|
royalty-free, irrevocable copyright license to reproduce, prepare derivative works of, publicly display, publicly
|
||||||
perform, sublicense, and distribute Your Contributions and such derivative works.
|
perform, sublicense, and distribute Your Contributions and such derivative works.
|
||||||
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of this Agreement, You hereby grant to Wiretrustee and
|
3. Grant of Patent License. Subject to the terms and conditions of this Agreement, You hereby grant to NetBird and
|
||||||
to recipients of software distributed by Wiretrustee a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
|
to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
|
||||||
irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import,
|
irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import,
|
||||||
and otherwise transfer the Work, where such license applies only to those patent claims licensable by You that are
|
and otherwise transfer the Work, where such license applies only to those patent claims licensable by You that are
|
||||||
necessarily infringed by Your Contribution(s) alone or by combination of Your Contribution(s) with the Work to which
|
necessarily infringed by Your Contribution(s) alone or by combination of Your Contribution(s) with the Work to which
|
||||||
@ -121,8 +121,8 @@ You reserve all right, title, and interest in and to Your Contributions.
|
|||||||
intellectual property that you create that includes your Contributions, you represent that you have received
|
intellectual property that you create that includes your Contributions, you represent that you have received
|
||||||
permission to make Contributions on behalf of that employer, that you will have received permission from your current
|
permission to make Contributions on behalf of that employer, that you will have received permission from your current
|
||||||
and future employers for all future Contributions, that your applicable employer has waived such rights for all of
|
and future employers for all future Contributions, that your applicable employer has waived such rights for all of
|
||||||
your current and future Contributions to Wiretrustee, or that your employer has executed a separate Corporate CLA
|
your current and future Contributions to NetBird, or that your employer has executed a separate Corporate CLA
|
||||||
with Wiretrustee.
|
with NetBird.
|
||||||
|
|
||||||
|
|
||||||
5. You represent that each of Your Contributions is Your original creation (see section 7 for submissions on behalf of
|
5. You represent that each of Your Contributions is Your original creation (see section 7 for submissions on behalf of
|
||||||
@ -138,11 +138,11 @@ You reserve all right, title, and interest in and to Your Contributions.
|
|||||||
MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE.
|
MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
|
||||||
|
|
||||||
7. Should You wish to submit work that is not Your original creation, You may submit it to Wiretrustee separately from
|
7. Should You wish to submit work that is not Your original creation, You may submit it to NetBird separately from
|
||||||
any Contribution, identifying the complete details of its source and of any license or other restriction (including,
|
any Contribution, identifying the complete details of its source and of any license or other restriction (including,
|
||||||
but not limited to, related patents, trademarks, and license agreements) of which you are personally aware, and
|
but not limited to, related patents, trademarks, and license agreements) of which you are personally aware, and
|
||||||
conspicuously marking the work as "Submitted on behalf of a third-party: [named here]".
|
conspicuously marking the work as "Submitted on behalf of a third-party: [named here]".
|
||||||
|
|
||||||
|
|
||||||
8. You agree to notify Wiretrustee of any facts or circumstances of which you become aware that would make these
|
8. You agree to notify NetBird of any facts or circumstances of which you become aware that would make these
|
||||||
representations inaccurate in any respect.
|
representations inaccurate in any respect.
|
2
LICENSE
2
LICENSE
@ -1,6 +1,6 @@
|
|||||||
BSD 3-Clause License
|
BSD 3-Clause License
|
||||||
|
|
||||||
Copyright (c) 2022 Wiretrustee UG (haftungsbeschränkt) & AUTHORS
|
Copyright (c) 2022 NetBird GmbH & AUTHORS
|
||||||
|
|
||||||
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
||||||
|
|
||||||
|
@ -1,4 +1,9 @@
|
|||||||
<div align="center">
|
<div align="center">
|
||||||
|
<a href="https://netbird.io/webinars/achieve-zero-trust-access-to-k8s?utm_source=github&utm_campaign=2502%20-%20webinar%20-%20How%20to%20Achieve%20Zero%20Trust%20Access%20to%20Kubernetes%20-%20Effortlessly&utm_medium=github">
|
||||||
|
Webinar: How to Achieve Zero Trust Access to Kubernetes — Effortlessly
|
||||||
|
</a>
|
||||||
|
<br/>
|
||||||
|
<br/>
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<img width="234" src="docs/media/logo-full.png"/>
|
<img width="234" src="docs/media/logo-full.png"/>
|
||||||
</p>
|
</p>
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
FROM alpine:3.21.0
|
FROM alpine:3.21.3
|
||||||
RUN apk add --no-cache ca-certificates iptables ip6tables
|
RUN apk add --no-cache ca-certificates iptables ip6tables
|
||||||
ENV NB_FOREGROUND_MODE=true
|
ENV NB_FOREGROUND_MODE=true
|
||||||
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
|
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
|
||||||
|
@ -9,6 +9,7 @@ USER netbird:netbird
|
|||||||
|
|
||||||
ENV NB_FOREGROUND_MODE=true
|
ENV NB_FOREGROUND_MODE=true
|
||||||
ENV NB_USE_NETSTACK_MODE=true
|
ENV NB_USE_NETSTACK_MODE=true
|
||||||
|
ENV NB_ENABLE_NETSTACK_LOCAL_FORWARDING=true
|
||||||
ENV NB_CONFIG=config.json
|
ENV NB_CONFIG=config.json
|
||||||
ENV NB_DAEMON_ADDR=unix://netbird.sock
|
ENV NB_DAEMON_ADDR=unix://netbird.sock
|
||||||
ENV NB_DISABLE_DNS=true
|
ENV NB_DISABLE_DNS=true
|
||||||
|
@ -162,7 +162,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
|||||||
|
|
||||||
// check if we need to generate JWT token
|
// check if we need to generate JWT token
|
||||||
err := a.withBackOff(a.ctx, func() (err error) {
|
err := a.withBackOff(a.ctx, func() (err error) {
|
||||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config.SSHKey)
|
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
|
||||||
return
|
return
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -13,6 +13,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/client/server"
|
"github.com/netbirdio/netbird/client/server"
|
||||||
|
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
const errCloseConnection = "Failed to close connection: %v"
|
const errCloseConnection = "Failed to close connection: %v"
|
||||||
@ -85,7 +86,7 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
|
|||||||
client := proto.NewDaemonServiceClient(conn)
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
||||||
Anonymize: anonymizeFlag,
|
Anonymize: anonymizeFlag,
|
||||||
Status: getStatusOutput(cmd),
|
Status: getStatusOutput(cmd, anonymizeFlag),
|
||||||
SystemInfo: debugSystemInfoFlag,
|
SystemInfo: debugSystemInfoFlag,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -196,7 +197,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
time.Sleep(3 * time.Second)
|
time.Sleep(3 * time.Second)
|
||||||
|
|
||||||
headerPostUp := fmt.Sprintf("----- Netbird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
headerPostUp := fmt.Sprintf("----- Netbird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
||||||
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd))
|
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd, anonymizeFlag))
|
||||||
|
|
||||||
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
|
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
|
||||||
return waitErr
|
return waitErr
|
||||||
@ -206,7 +207,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
cmd.Println("Creating debug bundle...")
|
cmd.Println("Creating debug bundle...")
|
||||||
|
|
||||||
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
|
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
|
||||||
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd))
|
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
|
||||||
|
|
||||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
||||||
Anonymize: anonymizeFlag,
|
Anonymize: anonymizeFlag,
|
||||||
@ -271,13 +272,15 @@ func setNetworkMapPersistence(cmd *cobra.Command, args []string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getStatusOutput(cmd *cobra.Command) string {
|
func getStatusOutput(cmd *cobra.Command, anon bool) string {
|
||||||
var statusOutputString string
|
var statusOutputString string
|
||||||
statusResp, err := getStatus(cmd.Context())
|
statusResp, err := getStatus(cmd.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cmd.PrintErrf("Failed to get status: %v\n", err)
|
cmd.PrintErrf("Failed to get status: %v\n", err)
|
||||||
} else {
|
} else {
|
||||||
statusOutputString = parseToFullDetailSummary(convertToStatusOutputOverview(statusResp))
|
statusOutputString = nbstatus.ParseToFullDetailSummary(
|
||||||
|
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return statusOutputString
|
return statusOutputString
|
||||||
}
|
}
|
||||||
|
@ -85,11 +85,17 @@ var loginCmd = &cobra.Command{
|
|||||||
|
|
||||||
client := proto.NewDaemonServiceClient(conn)
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
|
|
||||||
|
var dnsLabelsReq []string
|
||||||
|
if dnsLabelsValidated != nil {
|
||||||
|
dnsLabelsReq = dnsLabelsValidated.ToSafeStringList()
|
||||||
|
}
|
||||||
|
|
||||||
loginRequest := proto.LoginRequest{
|
loginRequest := proto.LoginRequest{
|
||||||
SetupKey: providedSetupKey,
|
SetupKey: providedSetupKey,
|
||||||
ManagementUrl: managementURL,
|
ManagementUrl: managementURL,
|
||||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
||||||
Hostname: hostName,
|
Hostname: hostName,
|
||||||
|
DnsLabels: dnsLabelsReq,
|
||||||
}
|
}
|
||||||
|
|
||||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||||
|
@ -38,6 +38,7 @@ const (
|
|||||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||||
dnsRouteIntervalFlag = "dns-router-interval"
|
dnsRouteIntervalFlag = "dns-router-interval"
|
||||||
systemInfoFlag = "system-info"
|
systemInfoFlag = "system-info"
|
||||||
|
blockLANAccessFlag = "block-lan-access"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -73,6 +74,7 @@ var (
|
|||||||
anonymizeFlag bool
|
anonymizeFlag bool
|
||||||
debugSystemInfoFlag bool
|
debugSystemInfoFlag bool
|
||||||
dnsRouteInterval time.Duration
|
dnsRouteInterval time.Duration
|
||||||
|
blockLANAccess bool
|
||||||
|
|
||||||
rootCmd = &cobra.Command{
|
rootCmd = &cobra.Command{
|
||||||
Use: "netbird",
|
Use: "netbird",
|
||||||
|
@ -9,7 +9,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
@ -73,7 +72,7 @@ var sshCmd = &cobra.Command{
|
|||||||
go func() {
|
go func() {
|
||||||
// blocking
|
// blocking
|
||||||
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
|
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
|
||||||
log.Debug(err)
|
cmd.Printf("Error: %v\n", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
cancel()
|
cancel()
|
||||||
|
@ -2,107 +2,20 @@ package cmd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
|
||||||
"runtime"
|
|
||||||
"sort"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
"gopkg.in/yaml.v3"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/anonymize"
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
"github.com/netbirdio/netbird/version"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type peerStateDetailOutput struct {
|
|
||||||
FQDN string `json:"fqdn" yaml:"fqdn"`
|
|
||||||
IP string `json:"netbirdIp" yaml:"netbirdIp"`
|
|
||||||
PubKey string `json:"publicKey" yaml:"publicKey"`
|
|
||||||
Status string `json:"status" yaml:"status"`
|
|
||||||
LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"`
|
|
||||||
ConnType string `json:"connectionType" yaml:"connectionType"`
|
|
||||||
IceCandidateType iceCandidateType `json:"iceCandidateType" yaml:"iceCandidateType"`
|
|
||||||
IceCandidateEndpoint iceCandidateType `json:"iceCandidateEndpoint" yaml:"iceCandidateEndpoint"`
|
|
||||||
RelayAddress string `json:"relayAddress" yaml:"relayAddress"`
|
|
||||||
LastWireguardHandshake time.Time `json:"lastWireguardHandshake" yaml:"lastWireguardHandshake"`
|
|
||||||
TransferReceived int64 `json:"transferReceived" yaml:"transferReceived"`
|
|
||||||
TransferSent int64 `json:"transferSent" yaml:"transferSent"`
|
|
||||||
Latency time.Duration `json:"latency" yaml:"latency"`
|
|
||||||
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
|
|
||||||
Routes []string `json:"routes" yaml:"routes"`
|
|
||||||
Networks []string `json:"networks" yaml:"networks"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type peersStateOutput struct {
|
|
||||||
Total int `json:"total" yaml:"total"`
|
|
||||||
Connected int `json:"connected" yaml:"connected"`
|
|
||||||
Details []peerStateDetailOutput `json:"details" yaml:"details"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type signalStateOutput struct {
|
|
||||||
URL string `json:"url" yaml:"url"`
|
|
||||||
Connected bool `json:"connected" yaml:"connected"`
|
|
||||||
Error string `json:"error" yaml:"error"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type managementStateOutput struct {
|
|
||||||
URL string `json:"url" yaml:"url"`
|
|
||||||
Connected bool `json:"connected" yaml:"connected"`
|
|
||||||
Error string `json:"error" yaml:"error"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type relayStateOutputDetail struct {
|
|
||||||
URI string `json:"uri" yaml:"uri"`
|
|
||||||
Available bool `json:"available" yaml:"available"`
|
|
||||||
Error string `json:"error" yaml:"error"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type relayStateOutput struct {
|
|
||||||
Total int `json:"total" yaml:"total"`
|
|
||||||
Available int `json:"available" yaml:"available"`
|
|
||||||
Details []relayStateOutputDetail `json:"details" yaml:"details"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type iceCandidateType struct {
|
|
||||||
Local string `json:"local" yaml:"local"`
|
|
||||||
Remote string `json:"remote" yaml:"remote"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type nsServerGroupStateOutput struct {
|
|
||||||
Servers []string `json:"servers" yaml:"servers"`
|
|
||||||
Domains []string `json:"domains" yaml:"domains"`
|
|
||||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
|
||||||
Error string `json:"error" yaml:"error"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type statusOutputOverview struct {
|
|
||||||
Peers peersStateOutput `json:"peers" yaml:"peers"`
|
|
||||||
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
|
|
||||||
DaemonVersion string `json:"daemonVersion" yaml:"daemonVersion"`
|
|
||||||
ManagementState managementStateOutput `json:"management" yaml:"management"`
|
|
||||||
SignalState signalStateOutput `json:"signal" yaml:"signal"`
|
|
||||||
Relays relayStateOutput `json:"relays" yaml:"relays"`
|
|
||||||
IP string `json:"netbirdIp" yaml:"netbirdIp"`
|
|
||||||
PubKey string `json:"publicKey" yaml:"publicKey"`
|
|
||||||
KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"`
|
|
||||||
FQDN string `json:"fqdn" yaml:"fqdn"`
|
|
||||||
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
|
|
||||||
RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"`
|
|
||||||
Routes []string `json:"routes" yaml:"routes"`
|
|
||||||
Networks []string `json:"networks" yaml:"networks"`
|
|
||||||
NSServerGroups []nsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
detailFlag bool
|
detailFlag bool
|
||||||
ipv4Flag bool
|
ipv4Flag bool
|
||||||
@ -173,18 +86,17 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
outputInformationHolder := convertToStatusOutputOverview(resp)
|
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap)
|
||||||
|
|
||||||
var statusOutputString string
|
var statusOutputString string
|
||||||
switch {
|
switch {
|
||||||
case detailFlag:
|
case detailFlag:
|
||||||
statusOutputString = parseToFullDetailSummary(outputInformationHolder)
|
statusOutputString = nbstatus.ParseToFullDetailSummary(outputInformationHolder)
|
||||||
case jsonFlag:
|
case jsonFlag:
|
||||||
statusOutputString, err = parseToJSON(outputInformationHolder)
|
statusOutputString, err = nbstatus.ParseToJSON(outputInformationHolder)
|
||||||
case yamlFlag:
|
case yamlFlag:
|
||||||
statusOutputString, err = parseToYAML(outputInformationHolder)
|
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
|
||||||
default:
|
default:
|
||||||
statusOutputString = parseGeneralSummary(outputInformationHolder, false, false, false)
|
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -214,7 +126,6 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func parseFilters() error {
|
func parseFilters() error {
|
||||||
|
|
||||||
switch strings.ToLower(statusFilter) {
|
switch strings.ToLower(statusFilter) {
|
||||||
case "", "disconnected", "connected":
|
case "", "disconnected", "connected":
|
||||||
if strings.ToLower(statusFilter) != "" {
|
if strings.ToLower(statusFilter) != "" {
|
||||||
@ -251,175 +162,6 @@ func enableDetailFlagWhenFilterFlag() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverview {
|
|
||||||
pbFullStatus := resp.GetFullStatus()
|
|
||||||
|
|
||||||
managementState := pbFullStatus.GetManagementState()
|
|
||||||
managementOverview := managementStateOutput{
|
|
||||||
URL: managementState.GetURL(),
|
|
||||||
Connected: managementState.GetConnected(),
|
|
||||||
Error: managementState.Error,
|
|
||||||
}
|
|
||||||
|
|
||||||
signalState := pbFullStatus.GetSignalState()
|
|
||||||
signalOverview := signalStateOutput{
|
|
||||||
URL: signalState.GetURL(),
|
|
||||||
Connected: signalState.GetConnected(),
|
|
||||||
Error: signalState.Error,
|
|
||||||
}
|
|
||||||
|
|
||||||
relayOverview := mapRelays(pbFullStatus.GetRelays())
|
|
||||||
peersOverview := mapPeers(resp.GetFullStatus().GetPeers())
|
|
||||||
|
|
||||||
overview := statusOutputOverview{
|
|
||||||
Peers: peersOverview,
|
|
||||||
CliVersion: version.NetbirdVersion(),
|
|
||||||
DaemonVersion: resp.GetDaemonVersion(),
|
|
||||||
ManagementState: managementOverview,
|
|
||||||
SignalState: signalOverview,
|
|
||||||
Relays: relayOverview,
|
|
||||||
IP: pbFullStatus.GetLocalPeerState().GetIP(),
|
|
||||||
PubKey: pbFullStatus.GetLocalPeerState().GetPubKey(),
|
|
||||||
KernelInterface: pbFullStatus.GetLocalPeerState().GetKernelInterface(),
|
|
||||||
FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
|
|
||||||
RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(),
|
|
||||||
RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(),
|
|
||||||
Routes: pbFullStatus.GetLocalPeerState().GetNetworks(),
|
|
||||||
Networks: pbFullStatus.GetLocalPeerState().GetNetworks(),
|
|
||||||
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
|
|
||||||
}
|
|
||||||
|
|
||||||
if anonymizeFlag {
|
|
||||||
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
|
||||||
anonymizeOverview(anonymizer, &overview)
|
|
||||||
}
|
|
||||||
|
|
||||||
return overview
|
|
||||||
}
|
|
||||||
|
|
||||||
func mapRelays(relays []*proto.RelayState) relayStateOutput {
|
|
||||||
var relayStateDetail []relayStateOutputDetail
|
|
||||||
|
|
||||||
var relaysAvailable int
|
|
||||||
for _, relay := range relays {
|
|
||||||
available := relay.GetAvailable()
|
|
||||||
relayStateDetail = append(relayStateDetail,
|
|
||||||
relayStateOutputDetail{
|
|
||||||
URI: relay.URI,
|
|
||||||
Available: available,
|
|
||||||
Error: relay.GetError(),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
if available {
|
|
||||||
relaysAvailable++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return relayStateOutput{
|
|
||||||
Total: len(relays),
|
|
||||||
Available: relaysAvailable,
|
|
||||||
Details: relayStateDetail,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func mapNSGroups(servers []*proto.NSGroupState) []nsServerGroupStateOutput {
|
|
||||||
mappedNSGroups := make([]nsServerGroupStateOutput, 0, len(servers))
|
|
||||||
for _, pbNsGroupServer := range servers {
|
|
||||||
mappedNSGroups = append(mappedNSGroups, nsServerGroupStateOutput{
|
|
||||||
Servers: pbNsGroupServer.GetServers(),
|
|
||||||
Domains: pbNsGroupServer.GetDomains(),
|
|
||||||
Enabled: pbNsGroupServer.GetEnabled(),
|
|
||||||
Error: pbNsGroupServer.GetError(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return mappedNSGroups
|
|
||||||
}
|
|
||||||
|
|
||||||
func mapPeers(peers []*proto.PeerState) peersStateOutput {
|
|
||||||
var peersStateDetail []peerStateDetailOutput
|
|
||||||
peersConnected := 0
|
|
||||||
for _, pbPeerState := range peers {
|
|
||||||
localICE := ""
|
|
||||||
remoteICE := ""
|
|
||||||
localICEEndpoint := ""
|
|
||||||
remoteICEEndpoint := ""
|
|
||||||
relayServerAddress := ""
|
|
||||||
connType := ""
|
|
||||||
lastHandshake := time.Time{}
|
|
||||||
transferReceived := int64(0)
|
|
||||||
transferSent := int64(0)
|
|
||||||
|
|
||||||
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
|
|
||||||
if skipDetailByFilters(pbPeerState, isPeerConnected) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if isPeerConnected {
|
|
||||||
peersConnected++
|
|
||||||
|
|
||||||
localICE = pbPeerState.GetLocalIceCandidateType()
|
|
||||||
remoteICE = pbPeerState.GetRemoteIceCandidateType()
|
|
||||||
localICEEndpoint = pbPeerState.GetLocalIceCandidateEndpoint()
|
|
||||||
remoteICEEndpoint = pbPeerState.GetRemoteIceCandidateEndpoint()
|
|
||||||
connType = "P2P"
|
|
||||||
if pbPeerState.Relayed {
|
|
||||||
connType = "Relayed"
|
|
||||||
}
|
|
||||||
relayServerAddress = pbPeerState.GetRelayAddress()
|
|
||||||
lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local()
|
|
||||||
transferReceived = pbPeerState.GetBytesRx()
|
|
||||||
transferSent = pbPeerState.GetBytesTx()
|
|
||||||
}
|
|
||||||
|
|
||||||
timeLocal := pbPeerState.GetConnStatusUpdate().AsTime().Local()
|
|
||||||
peerState := peerStateDetailOutput{
|
|
||||||
IP: pbPeerState.GetIP(),
|
|
||||||
PubKey: pbPeerState.GetPubKey(),
|
|
||||||
Status: pbPeerState.GetConnStatus(),
|
|
||||||
LastStatusUpdate: timeLocal,
|
|
||||||
ConnType: connType,
|
|
||||||
IceCandidateType: iceCandidateType{
|
|
||||||
Local: localICE,
|
|
||||||
Remote: remoteICE,
|
|
||||||
},
|
|
||||||
IceCandidateEndpoint: iceCandidateType{
|
|
||||||
Local: localICEEndpoint,
|
|
||||||
Remote: remoteICEEndpoint,
|
|
||||||
},
|
|
||||||
RelayAddress: relayServerAddress,
|
|
||||||
FQDN: pbPeerState.GetFqdn(),
|
|
||||||
LastWireguardHandshake: lastHandshake,
|
|
||||||
TransferReceived: transferReceived,
|
|
||||||
TransferSent: transferSent,
|
|
||||||
Latency: pbPeerState.GetLatency().AsDuration(),
|
|
||||||
RosenpassEnabled: pbPeerState.GetRosenpassEnabled(),
|
|
||||||
Routes: pbPeerState.GetNetworks(),
|
|
||||||
Networks: pbPeerState.GetNetworks(),
|
|
||||||
}
|
|
||||||
|
|
||||||
peersStateDetail = append(peersStateDetail, peerState)
|
|
||||||
}
|
|
||||||
|
|
||||||
sortPeersByIP(peersStateDetail)
|
|
||||||
|
|
||||||
peersOverview := peersStateOutput{
|
|
||||||
Total: len(peersStateDetail),
|
|
||||||
Connected: peersConnected,
|
|
||||||
Details: peersStateDetail,
|
|
||||||
}
|
|
||||||
return peersOverview
|
|
||||||
}
|
|
||||||
|
|
||||||
func sortPeersByIP(peersStateDetail []peerStateDetailOutput) {
|
|
||||||
if len(peersStateDetail) > 0 {
|
|
||||||
sort.SliceStable(peersStateDetail, func(i, j int) bool {
|
|
||||||
iAddr, _ := netip.ParseAddr(peersStateDetail[i].IP)
|
|
||||||
jAddr, _ := netip.ParseAddr(peersStateDetail[j].IP)
|
|
||||||
return iAddr.Compare(jAddr) == -1
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseInterfaceIP(interfaceIP string) string {
|
func parseInterfaceIP(interfaceIP string) string {
|
||||||
ip, _, err := net.ParseCIDR(interfaceIP)
|
ip, _, err := net.ParseCIDR(interfaceIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -427,452 +169,3 @@ func parseInterfaceIP(interfaceIP string) string {
|
|||||||
}
|
}
|
||||||
return fmt.Sprintf("%s\n", ip)
|
return fmt.Sprintf("%s\n", ip)
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseToJSON(overview statusOutputOverview) (string, error) {
|
|
||||||
jsonBytes, err := json.Marshal(overview)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("json marshal failed")
|
|
||||||
}
|
|
||||||
return string(jsonBytes), err
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseToYAML(overview statusOutputOverview) (string, error) {
|
|
||||||
yamlBytes, err := yaml.Marshal(overview)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("yaml marshal failed")
|
|
||||||
}
|
|
||||||
return string(yamlBytes), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays bool, showNameServers bool) string {
|
|
||||||
var managementConnString string
|
|
||||||
if overview.ManagementState.Connected {
|
|
||||||
managementConnString = "Connected"
|
|
||||||
if showURL {
|
|
||||||
managementConnString = fmt.Sprintf("%s to %s", managementConnString, overview.ManagementState.URL)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
managementConnString = "Disconnected"
|
|
||||||
if overview.ManagementState.Error != "" {
|
|
||||||
managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, overview.ManagementState.Error)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var signalConnString string
|
|
||||||
if overview.SignalState.Connected {
|
|
||||||
signalConnString = "Connected"
|
|
||||||
if showURL {
|
|
||||||
signalConnString = fmt.Sprintf("%s to %s", signalConnString, overview.SignalState.URL)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
signalConnString = "Disconnected"
|
|
||||||
if overview.SignalState.Error != "" {
|
|
||||||
signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, overview.SignalState.Error)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
interfaceTypeString := "Userspace"
|
|
||||||
interfaceIP := overview.IP
|
|
||||||
if overview.KernelInterface {
|
|
||||||
interfaceTypeString = "Kernel"
|
|
||||||
} else if overview.IP == "" {
|
|
||||||
interfaceTypeString = "N/A"
|
|
||||||
interfaceIP = "N/A"
|
|
||||||
}
|
|
||||||
|
|
||||||
var relaysString string
|
|
||||||
if showRelays {
|
|
||||||
for _, relay := range overview.Relays.Details {
|
|
||||||
available := "Available"
|
|
||||||
reason := ""
|
|
||||||
if !relay.Available {
|
|
||||||
available = "Unavailable"
|
|
||||||
reason = fmt.Sprintf(", reason: %s", relay.Error)
|
|
||||||
}
|
|
||||||
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
|
|
||||||
}
|
|
||||||
|
|
||||||
networks := "-"
|
|
||||||
if len(overview.Networks) > 0 {
|
|
||||||
sort.Strings(overview.Networks)
|
|
||||||
networks = strings.Join(overview.Networks, ", ")
|
|
||||||
}
|
|
||||||
|
|
||||||
var dnsServersString string
|
|
||||||
if showNameServers {
|
|
||||||
for _, nsServerGroup := range overview.NSServerGroups {
|
|
||||||
enabled := "Available"
|
|
||||||
if !nsServerGroup.Enabled {
|
|
||||||
enabled = "Unavailable"
|
|
||||||
}
|
|
||||||
errorString := ""
|
|
||||||
if nsServerGroup.Error != "" {
|
|
||||||
errorString = fmt.Sprintf(", reason: %s", nsServerGroup.Error)
|
|
||||||
errorString = strings.TrimSpace(errorString)
|
|
||||||
}
|
|
||||||
|
|
||||||
domainsString := strings.Join(nsServerGroup.Domains, ", ")
|
|
||||||
if domainsString == "" {
|
|
||||||
domainsString = "." // Show "." for the default zone
|
|
||||||
}
|
|
||||||
dnsServersString += fmt.Sprintf(
|
|
||||||
"\n [%s] for [%s] is %s%s",
|
|
||||||
strings.Join(nsServerGroup.Servers, ", "),
|
|
||||||
domainsString,
|
|
||||||
enabled,
|
|
||||||
errorString,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(overview.NSServerGroups), len(overview.NSServerGroups))
|
|
||||||
}
|
|
||||||
|
|
||||||
rosenpassEnabledStatus := "false"
|
|
||||||
if overview.RosenpassEnabled {
|
|
||||||
rosenpassEnabledStatus = "true"
|
|
||||||
if overview.RosenpassPermissive {
|
|
||||||
rosenpassEnabledStatus = "true (permissive)" //nolint:gosec
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
|
|
||||||
|
|
||||||
goos := runtime.GOOS
|
|
||||||
goarch := runtime.GOARCH
|
|
||||||
goarm := ""
|
|
||||||
if goarch == "arm" {
|
|
||||||
goarm = fmt.Sprintf(" (ARMv%s)", os.Getenv("GOARM"))
|
|
||||||
}
|
|
||||||
|
|
||||||
summary := fmt.Sprintf(
|
|
||||||
"OS: %s\n"+
|
|
||||||
"Daemon version: %s\n"+
|
|
||||||
"CLI version: %s\n"+
|
|
||||||
"Management: %s\n"+
|
|
||||||
"Signal: %s\n"+
|
|
||||||
"Relays: %s\n"+
|
|
||||||
"Nameservers: %s\n"+
|
|
||||||
"FQDN: %s\n"+
|
|
||||||
"NetBird IP: %s\n"+
|
|
||||||
"Interface type: %s\n"+
|
|
||||||
"Quantum resistance: %s\n"+
|
|
||||||
"Routes: %s\n"+
|
|
||||||
"Networks: %s\n"+
|
|
||||||
"Peers count: %s\n",
|
|
||||||
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
|
|
||||||
overview.DaemonVersion,
|
|
||||||
version.NetbirdVersion(),
|
|
||||||
managementConnString,
|
|
||||||
signalConnString,
|
|
||||||
relaysString,
|
|
||||||
dnsServersString,
|
|
||||||
overview.FQDN,
|
|
||||||
interfaceIP,
|
|
||||||
interfaceTypeString,
|
|
||||||
rosenpassEnabledStatus,
|
|
||||||
networks,
|
|
||||||
networks,
|
|
||||||
peersCountString,
|
|
||||||
)
|
|
||||||
return summary
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseToFullDetailSummary(overview statusOutputOverview) string {
|
|
||||||
parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive)
|
|
||||||
summary := parseGeneralSummary(overview, true, true, true)
|
|
||||||
|
|
||||||
return fmt.Sprintf(
|
|
||||||
"Peers detail:"+
|
|
||||||
"%s\n"+
|
|
||||||
"%s",
|
|
||||||
parsedPeersString,
|
|
||||||
summary,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bool) string {
|
|
||||||
var (
|
|
||||||
peersString = ""
|
|
||||||
)
|
|
||||||
|
|
||||||
for _, peerState := range peers.Details {
|
|
||||||
|
|
||||||
localICE := "-"
|
|
||||||
if peerState.IceCandidateType.Local != "" {
|
|
||||||
localICE = peerState.IceCandidateType.Local
|
|
||||||
}
|
|
||||||
|
|
||||||
remoteICE := "-"
|
|
||||||
if peerState.IceCandidateType.Remote != "" {
|
|
||||||
remoteICE = peerState.IceCandidateType.Remote
|
|
||||||
}
|
|
||||||
|
|
||||||
localICEEndpoint := "-"
|
|
||||||
if peerState.IceCandidateEndpoint.Local != "" {
|
|
||||||
localICEEndpoint = peerState.IceCandidateEndpoint.Local
|
|
||||||
}
|
|
||||||
|
|
||||||
remoteICEEndpoint := "-"
|
|
||||||
if peerState.IceCandidateEndpoint.Remote != "" {
|
|
||||||
remoteICEEndpoint = peerState.IceCandidateEndpoint.Remote
|
|
||||||
}
|
|
||||||
|
|
||||||
rosenpassEnabledStatus := "false"
|
|
||||||
if rosenpassEnabled {
|
|
||||||
if peerState.RosenpassEnabled {
|
|
||||||
rosenpassEnabledStatus = "true"
|
|
||||||
} else {
|
|
||||||
if rosenpassPermissive {
|
|
||||||
rosenpassEnabledStatus = "false (remote didn't enable quantum resistance)"
|
|
||||||
} else {
|
|
||||||
rosenpassEnabledStatus = "false (connection won't work without a permissive mode)"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if peerState.RosenpassEnabled {
|
|
||||||
rosenpassEnabledStatus = "false (connection might not work without a remote permissive mode)"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
networks := "-"
|
|
||||||
if len(peerState.Networks) > 0 {
|
|
||||||
sort.Strings(peerState.Networks)
|
|
||||||
networks = strings.Join(peerState.Networks, ", ")
|
|
||||||
}
|
|
||||||
|
|
||||||
peerString := fmt.Sprintf(
|
|
||||||
"\n %s:\n"+
|
|
||||||
" NetBird IP: %s\n"+
|
|
||||||
" Public key: %s\n"+
|
|
||||||
" Status: %s\n"+
|
|
||||||
" -- detail --\n"+
|
|
||||||
" Connection type: %s\n"+
|
|
||||||
" ICE candidate (Local/Remote): %s/%s\n"+
|
|
||||||
" ICE candidate endpoints (Local/Remote): %s/%s\n"+
|
|
||||||
" Relay server address: %s\n"+
|
|
||||||
" Last connection update: %s\n"+
|
|
||||||
" Last WireGuard handshake: %s\n"+
|
|
||||||
" Transfer status (received/sent) %s/%s\n"+
|
|
||||||
" Quantum resistance: %s\n"+
|
|
||||||
" Routes: %s\n"+
|
|
||||||
" Networks: %s\n"+
|
|
||||||
" Latency: %s\n",
|
|
||||||
peerState.FQDN,
|
|
||||||
peerState.IP,
|
|
||||||
peerState.PubKey,
|
|
||||||
peerState.Status,
|
|
||||||
peerState.ConnType,
|
|
||||||
localICE,
|
|
||||||
remoteICE,
|
|
||||||
localICEEndpoint,
|
|
||||||
remoteICEEndpoint,
|
|
||||||
peerState.RelayAddress,
|
|
||||||
timeAgo(peerState.LastStatusUpdate),
|
|
||||||
timeAgo(peerState.LastWireguardHandshake),
|
|
||||||
toIEC(peerState.TransferReceived),
|
|
||||||
toIEC(peerState.TransferSent),
|
|
||||||
rosenpassEnabledStatus,
|
|
||||||
networks,
|
|
||||||
networks,
|
|
||||||
peerState.Latency.String(),
|
|
||||||
)
|
|
||||||
|
|
||||||
peersString += peerString
|
|
||||||
}
|
|
||||||
return peersString
|
|
||||||
}
|
|
||||||
|
|
||||||
func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
|
|
||||||
statusEval := false
|
|
||||||
ipEval := false
|
|
||||||
nameEval := true
|
|
||||||
|
|
||||||
if statusFilter != "" {
|
|
||||||
lowerStatusFilter := strings.ToLower(statusFilter)
|
|
||||||
if lowerStatusFilter == "disconnected" && isConnected {
|
|
||||||
statusEval = true
|
|
||||||
} else if lowerStatusFilter == "connected" && !isConnected {
|
|
||||||
statusEval = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(ipsFilter) > 0 {
|
|
||||||
_, ok := ipsFilterMap[peerState.IP]
|
|
||||||
if !ok {
|
|
||||||
ipEval = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(prefixNamesFilter) > 0 {
|
|
||||||
for prefixNameFilter := range prefixNamesFilterMap {
|
|
||||||
if strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
|
|
||||||
nameEval = false
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
nameEval = false
|
|
||||||
}
|
|
||||||
|
|
||||||
return statusEval || ipEval || nameEval
|
|
||||||
}
|
|
||||||
|
|
||||||
func toIEC(b int64) string {
|
|
||||||
const unit = 1024
|
|
||||||
if b < unit {
|
|
||||||
return fmt.Sprintf("%d B", b)
|
|
||||||
}
|
|
||||||
div, exp := int64(unit), 0
|
|
||||||
for n := b / unit; n >= unit; n /= unit {
|
|
||||||
div *= unit
|
|
||||||
exp++
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("%.1f %ciB",
|
|
||||||
float64(b)/float64(div), "KMGTPE"[exp])
|
|
||||||
}
|
|
||||||
|
|
||||||
func countEnabled(dnsServers []nsServerGroupStateOutput) int {
|
|
||||||
count := 0
|
|
||||||
for _, server := range dnsServers {
|
|
||||||
if server.Enabled {
|
|
||||||
count++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return count
|
|
||||||
}
|
|
||||||
|
|
||||||
// timeAgo returns a string representing the duration since the provided time in a human-readable format.
|
|
||||||
func timeAgo(t time.Time) string {
|
|
||||||
if t.IsZero() || t.Equal(time.Unix(0, 0)) {
|
|
||||||
return "-"
|
|
||||||
}
|
|
||||||
duration := time.Since(t)
|
|
||||||
switch {
|
|
||||||
case duration < time.Second:
|
|
||||||
return "Now"
|
|
||||||
case duration < time.Minute:
|
|
||||||
seconds := int(duration.Seconds())
|
|
||||||
if seconds == 1 {
|
|
||||||
return "1 second ago"
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("%d seconds ago", seconds)
|
|
||||||
case duration < time.Hour:
|
|
||||||
minutes := int(duration.Minutes())
|
|
||||||
seconds := int(duration.Seconds()) % 60
|
|
||||||
if minutes == 1 {
|
|
||||||
if seconds == 1 {
|
|
||||||
return "1 minute, 1 second ago"
|
|
||||||
} else if seconds > 0 {
|
|
||||||
return fmt.Sprintf("1 minute, %d seconds ago", seconds)
|
|
||||||
}
|
|
||||||
return "1 minute ago"
|
|
||||||
}
|
|
||||||
if seconds > 0 {
|
|
||||||
return fmt.Sprintf("%d minutes, %d seconds ago", minutes, seconds)
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("%d minutes ago", minutes)
|
|
||||||
case duration < 24*time.Hour:
|
|
||||||
hours := int(duration.Hours())
|
|
||||||
minutes := int(duration.Minutes()) % 60
|
|
||||||
if hours == 1 {
|
|
||||||
if minutes == 1 {
|
|
||||||
return "1 hour, 1 minute ago"
|
|
||||||
} else if minutes > 0 {
|
|
||||||
return fmt.Sprintf("1 hour, %d minutes ago", minutes)
|
|
||||||
}
|
|
||||||
return "1 hour ago"
|
|
||||||
}
|
|
||||||
if minutes > 0 {
|
|
||||||
return fmt.Sprintf("%d hours, %d minutes ago", hours, minutes)
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("%d hours ago", hours)
|
|
||||||
}
|
|
||||||
|
|
||||||
days := int(duration.Hours()) / 24
|
|
||||||
hours := int(duration.Hours()) % 24
|
|
||||||
if days == 1 {
|
|
||||||
if hours == 1 {
|
|
||||||
return "1 day, 1 hour ago"
|
|
||||||
} else if hours > 0 {
|
|
||||||
return fmt.Sprintf("1 day, %d hours ago", hours)
|
|
||||||
}
|
|
||||||
return "1 day ago"
|
|
||||||
}
|
|
||||||
if hours > 0 {
|
|
||||||
return fmt.Sprintf("%d days, %d hours ago", days, hours)
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("%d days ago", days)
|
|
||||||
}
|
|
||||||
|
|
||||||
func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
|
|
||||||
peer.FQDN = a.AnonymizeDomain(peer.FQDN)
|
|
||||||
if localIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Local); err == nil {
|
|
||||||
peer.IceCandidateEndpoint.Local = fmt.Sprintf("%s:%s", a.AnonymizeIPString(localIP), port)
|
|
||||||
}
|
|
||||||
if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
|
|
||||||
peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
|
|
||||||
}
|
|
||||||
|
|
||||||
peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress)
|
|
||||||
|
|
||||||
for i, route := range peer.Networks {
|
|
||||||
peer.Networks[i] = a.AnonymizeIPString(route)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, route := range peer.Networks {
|
|
||||||
peer.Networks[i] = a.AnonymizeRoute(route)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, route := range peer.Routes {
|
|
||||||
peer.Routes[i] = a.AnonymizeIPString(route)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, route := range peer.Routes {
|
|
||||||
peer.Routes[i] = a.AnonymizeRoute(route)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview) {
|
|
||||||
for i, peer := range overview.Peers.Details {
|
|
||||||
peer := peer
|
|
||||||
anonymizePeerDetail(a, &peer)
|
|
||||||
overview.Peers.Details[i] = peer
|
|
||||||
}
|
|
||||||
|
|
||||||
overview.ManagementState.URL = a.AnonymizeURI(overview.ManagementState.URL)
|
|
||||||
overview.ManagementState.Error = a.AnonymizeString(overview.ManagementState.Error)
|
|
||||||
overview.SignalState.URL = a.AnonymizeURI(overview.SignalState.URL)
|
|
||||||
overview.SignalState.Error = a.AnonymizeString(overview.SignalState.Error)
|
|
||||||
|
|
||||||
overview.IP = a.AnonymizeIPString(overview.IP)
|
|
||||||
for i, detail := range overview.Relays.Details {
|
|
||||||
detail.URI = a.AnonymizeURI(detail.URI)
|
|
||||||
detail.Error = a.AnonymizeString(detail.Error)
|
|
||||||
overview.Relays.Details[i] = detail
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, nsGroup := range overview.NSServerGroups {
|
|
||||||
for j, domain := range nsGroup.Domains {
|
|
||||||
overview.NSServerGroups[i].Domains[j] = a.AnonymizeDomain(domain)
|
|
||||||
}
|
|
||||||
for j, ns := range nsGroup.Servers {
|
|
||||||
host, port, err := net.SplitHostPort(ns)
|
|
||||||
if err == nil {
|
|
||||||
overview.NSServerGroups[i].Servers[j] = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, route := range overview.Networks {
|
|
||||||
overview.Networks[i] = a.AnonymizeRoute(route)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, route := range overview.Routes {
|
|
||||||
overview.Routes[i] = a.AnonymizeRoute(route)
|
|
||||||
}
|
|
||||||
|
|
||||||
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
|
|
||||||
}
|
|
||||||
|
@ -1,597 +1,11 @@
|
|||||||
package cmd
|
package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"runtime"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"google.golang.org/protobuf/types/known/durationpb"
|
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
|
||||||
"github.com/netbirdio/netbird/version"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
|
||||||
loc, err := time.LoadLocation("UTC")
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
time.Local = loc
|
|
||||||
}
|
|
||||||
|
|
||||||
var resp = &proto.StatusResponse{
|
|
||||||
Status: "Connected",
|
|
||||||
FullStatus: &proto.FullStatus{
|
|
||||||
Peers: []*proto.PeerState{
|
|
||||||
{
|
|
||||||
IP: "192.168.178.101",
|
|
||||||
PubKey: "Pubkey1",
|
|
||||||
Fqdn: "peer-1.awesome-domain.com",
|
|
||||||
ConnStatus: "Connected",
|
|
||||||
ConnStatusUpdate: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 1, 0, time.UTC)),
|
|
||||||
Relayed: false,
|
|
||||||
LocalIceCandidateType: "",
|
|
||||||
RemoteIceCandidateType: "",
|
|
||||||
LocalIceCandidateEndpoint: "",
|
|
||||||
RemoteIceCandidateEndpoint: "",
|
|
||||||
LastWireguardHandshake: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 2, 0, time.UTC)),
|
|
||||||
BytesRx: 200,
|
|
||||||
BytesTx: 100,
|
|
||||||
Networks: []string{
|
|
||||||
"10.1.0.0/24",
|
|
||||||
},
|
|
||||||
Latency: durationpb.New(time.Duration(10000000)),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
IP: "192.168.178.102",
|
|
||||||
PubKey: "Pubkey2",
|
|
||||||
Fqdn: "peer-2.awesome-domain.com",
|
|
||||||
ConnStatus: "Connected",
|
|
||||||
ConnStatusUpdate: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 2, 0, time.UTC)),
|
|
||||||
Relayed: true,
|
|
||||||
LocalIceCandidateType: "relay",
|
|
||||||
RemoteIceCandidateType: "prflx",
|
|
||||||
LocalIceCandidateEndpoint: "10.0.0.1:10001",
|
|
||||||
RemoteIceCandidateEndpoint: "10.0.10.1:10002",
|
|
||||||
LastWireguardHandshake: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 3, 0, time.UTC)),
|
|
||||||
BytesRx: 2000,
|
|
||||||
BytesTx: 1000,
|
|
||||||
Latency: durationpb.New(time.Duration(10000000)),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
ManagementState: &proto.ManagementState{
|
|
||||||
URL: "my-awesome-management.com:443",
|
|
||||||
Connected: true,
|
|
||||||
Error: "",
|
|
||||||
},
|
|
||||||
SignalState: &proto.SignalState{
|
|
||||||
URL: "my-awesome-signal.com:443",
|
|
||||||
Connected: true,
|
|
||||||
Error: "",
|
|
||||||
},
|
|
||||||
Relays: []*proto.RelayState{
|
|
||||||
{
|
|
||||||
URI: "stun:my-awesome-stun.com:3478",
|
|
||||||
Available: true,
|
|
||||||
Error: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
URI: "turns:my-awesome-turn.com:443?transport=tcp",
|
|
||||||
Available: false,
|
|
||||||
Error: "context: deadline exceeded",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
LocalPeerState: &proto.LocalPeerState{
|
|
||||||
IP: "192.168.178.100/16",
|
|
||||||
PubKey: "Some-Pub-Key",
|
|
||||||
KernelInterface: true,
|
|
||||||
Fqdn: "some-localhost.awesome-domain.com",
|
|
||||||
Networks: []string{
|
|
||||||
"10.10.0.0/24",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
DnsServers: []*proto.NSGroupState{
|
|
||||||
{
|
|
||||||
Servers: []string{
|
|
||||||
"8.8.8.8:53",
|
|
||||||
},
|
|
||||||
Domains: nil,
|
|
||||||
Enabled: true,
|
|
||||||
Error: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Servers: []string{
|
|
||||||
"1.1.1.1:53",
|
|
||||||
"2.2.2.2:53",
|
|
||||||
},
|
|
||||||
Domains: []string{
|
|
||||||
"example.com",
|
|
||||||
"example.net",
|
|
||||||
},
|
|
||||||
Enabled: false,
|
|
||||||
Error: "timeout",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
DaemonVersion: "0.14.1",
|
|
||||||
}
|
|
||||||
|
|
||||||
var overview = statusOutputOverview{
|
|
||||||
Peers: peersStateOutput{
|
|
||||||
Total: 2,
|
|
||||||
Connected: 2,
|
|
||||||
Details: []peerStateDetailOutput{
|
|
||||||
{
|
|
||||||
IP: "192.168.178.101",
|
|
||||||
PubKey: "Pubkey1",
|
|
||||||
FQDN: "peer-1.awesome-domain.com",
|
|
||||||
Status: "Connected",
|
|
||||||
LastStatusUpdate: time.Date(2001, 1, 1, 1, 1, 1, 0, time.UTC),
|
|
||||||
ConnType: "P2P",
|
|
||||||
IceCandidateType: iceCandidateType{
|
|
||||||
Local: "",
|
|
||||||
Remote: "",
|
|
||||||
},
|
|
||||||
IceCandidateEndpoint: iceCandidateType{
|
|
||||||
Local: "",
|
|
||||||
Remote: "",
|
|
||||||
},
|
|
||||||
LastWireguardHandshake: time.Date(2001, 1, 1, 1, 1, 2, 0, time.UTC),
|
|
||||||
TransferReceived: 200,
|
|
||||||
TransferSent: 100,
|
|
||||||
Routes: []string{
|
|
||||||
"10.1.0.0/24",
|
|
||||||
},
|
|
||||||
Networks: []string{
|
|
||||||
"10.1.0.0/24",
|
|
||||||
},
|
|
||||||
Latency: time.Duration(10000000),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
IP: "192.168.178.102",
|
|
||||||
PubKey: "Pubkey2",
|
|
||||||
FQDN: "peer-2.awesome-domain.com",
|
|
||||||
Status: "Connected",
|
|
||||||
LastStatusUpdate: time.Date(2002, 2, 2, 2, 2, 2, 0, time.UTC),
|
|
||||||
ConnType: "Relayed",
|
|
||||||
IceCandidateType: iceCandidateType{
|
|
||||||
Local: "relay",
|
|
||||||
Remote: "prflx",
|
|
||||||
},
|
|
||||||
IceCandidateEndpoint: iceCandidateType{
|
|
||||||
Local: "10.0.0.1:10001",
|
|
||||||
Remote: "10.0.10.1:10002",
|
|
||||||
},
|
|
||||||
LastWireguardHandshake: time.Date(2002, 2, 2, 2, 2, 3, 0, time.UTC),
|
|
||||||
TransferReceived: 2000,
|
|
||||||
TransferSent: 1000,
|
|
||||||
Latency: time.Duration(10000000),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
CliVersion: version.NetbirdVersion(),
|
|
||||||
DaemonVersion: "0.14.1",
|
|
||||||
ManagementState: managementStateOutput{
|
|
||||||
URL: "my-awesome-management.com:443",
|
|
||||||
Connected: true,
|
|
||||||
Error: "",
|
|
||||||
},
|
|
||||||
SignalState: signalStateOutput{
|
|
||||||
URL: "my-awesome-signal.com:443",
|
|
||||||
Connected: true,
|
|
||||||
Error: "",
|
|
||||||
},
|
|
||||||
Relays: relayStateOutput{
|
|
||||||
Total: 2,
|
|
||||||
Available: 1,
|
|
||||||
Details: []relayStateOutputDetail{
|
|
||||||
{
|
|
||||||
URI: "stun:my-awesome-stun.com:3478",
|
|
||||||
Available: true,
|
|
||||||
Error: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
URI: "turns:my-awesome-turn.com:443?transport=tcp",
|
|
||||||
Available: false,
|
|
||||||
Error: "context: deadline exceeded",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
IP: "192.168.178.100/16",
|
|
||||||
PubKey: "Some-Pub-Key",
|
|
||||||
KernelInterface: true,
|
|
||||||
FQDN: "some-localhost.awesome-domain.com",
|
|
||||||
NSServerGroups: []nsServerGroupStateOutput{
|
|
||||||
{
|
|
||||||
Servers: []string{
|
|
||||||
"8.8.8.8:53",
|
|
||||||
},
|
|
||||||
Domains: nil,
|
|
||||||
Enabled: true,
|
|
||||||
Error: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Servers: []string{
|
|
||||||
"1.1.1.1:53",
|
|
||||||
"2.2.2.2:53",
|
|
||||||
},
|
|
||||||
Domains: []string{
|
|
||||||
"example.com",
|
|
||||||
"example.net",
|
|
||||||
},
|
|
||||||
Enabled: false,
|
|
||||||
Error: "timeout",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Routes: []string{
|
|
||||||
"10.10.0.0/24",
|
|
||||||
},
|
|
||||||
Networks: []string{
|
|
||||||
"10.10.0.0/24",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
|
|
||||||
convertedResult := convertToStatusOutputOverview(resp)
|
|
||||||
|
|
||||||
assert.Equal(t, overview, convertedResult)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSortingOfPeers(t *testing.T) {
|
|
||||||
peers := []peerStateDetailOutput{
|
|
||||||
{
|
|
||||||
IP: "192.168.178.104",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
IP: "192.168.178.102",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
IP: "192.168.178.101",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
IP: "192.168.178.105",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
IP: "192.168.178.103",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
sortPeersByIP(peers)
|
|
||||||
|
|
||||||
assert.Equal(t, peers[3].IP, "192.168.178.104")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParsingToJSON(t *testing.T) {
|
|
||||||
jsonString, _ := parseToJSON(overview)
|
|
||||||
|
|
||||||
//@formatter:off
|
|
||||||
expectedJSONString := `
|
|
||||||
{
|
|
||||||
"peers": {
|
|
||||||
"total": 2,
|
|
||||||
"connected": 2,
|
|
||||||
"details": [
|
|
||||||
{
|
|
||||||
"fqdn": "peer-1.awesome-domain.com",
|
|
||||||
"netbirdIp": "192.168.178.101",
|
|
||||||
"publicKey": "Pubkey1",
|
|
||||||
"status": "Connected",
|
|
||||||
"lastStatusUpdate": "2001-01-01T01:01:01Z",
|
|
||||||
"connectionType": "P2P",
|
|
||||||
"iceCandidateType": {
|
|
||||||
"local": "",
|
|
||||||
"remote": ""
|
|
||||||
},
|
|
||||||
"iceCandidateEndpoint": {
|
|
||||||
"local": "",
|
|
||||||
"remote": ""
|
|
||||||
},
|
|
||||||
"relayAddress": "",
|
|
||||||
"lastWireguardHandshake": "2001-01-01T01:01:02Z",
|
|
||||||
"transferReceived": 200,
|
|
||||||
"transferSent": 100,
|
|
||||||
"latency": 10000000,
|
|
||||||
"quantumResistance": false,
|
|
||||||
"routes": [
|
|
||||||
"10.1.0.0/24"
|
|
||||||
],
|
|
||||||
"networks": [
|
|
||||||
"10.1.0.0/24"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"fqdn": "peer-2.awesome-domain.com",
|
|
||||||
"netbirdIp": "192.168.178.102",
|
|
||||||
"publicKey": "Pubkey2",
|
|
||||||
"status": "Connected",
|
|
||||||
"lastStatusUpdate": "2002-02-02T02:02:02Z",
|
|
||||||
"connectionType": "Relayed",
|
|
||||||
"iceCandidateType": {
|
|
||||||
"local": "relay",
|
|
||||||
"remote": "prflx"
|
|
||||||
},
|
|
||||||
"iceCandidateEndpoint": {
|
|
||||||
"local": "10.0.0.1:10001",
|
|
||||||
"remote": "10.0.10.1:10002"
|
|
||||||
},
|
|
||||||
"relayAddress": "",
|
|
||||||
"lastWireguardHandshake": "2002-02-02T02:02:03Z",
|
|
||||||
"transferReceived": 2000,
|
|
||||||
"transferSent": 1000,
|
|
||||||
"latency": 10000000,
|
|
||||||
"quantumResistance": false,
|
|
||||||
"routes": null,
|
|
||||||
"networks": null
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"cliVersion": "development",
|
|
||||||
"daemonVersion": "0.14.1",
|
|
||||||
"management": {
|
|
||||||
"url": "my-awesome-management.com:443",
|
|
||||||
"connected": true,
|
|
||||||
"error": ""
|
|
||||||
},
|
|
||||||
"signal": {
|
|
||||||
"url": "my-awesome-signal.com:443",
|
|
||||||
"connected": true,
|
|
||||||
"error": ""
|
|
||||||
},
|
|
||||||
"relays": {
|
|
||||||
"total": 2,
|
|
||||||
"available": 1,
|
|
||||||
"details": [
|
|
||||||
{
|
|
||||||
"uri": "stun:my-awesome-stun.com:3478",
|
|
||||||
"available": true,
|
|
||||||
"error": ""
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"uri": "turns:my-awesome-turn.com:443?transport=tcp",
|
|
||||||
"available": false,
|
|
||||||
"error": "context: deadline exceeded"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"netbirdIp": "192.168.178.100/16",
|
|
||||||
"publicKey": "Some-Pub-Key",
|
|
||||||
"usesKernelInterface": true,
|
|
||||||
"fqdn": "some-localhost.awesome-domain.com",
|
|
||||||
"quantumResistance": false,
|
|
||||||
"quantumResistancePermissive": false,
|
|
||||||
"routes": [
|
|
||||||
"10.10.0.0/24"
|
|
||||||
],
|
|
||||||
"networks": [
|
|
||||||
"10.10.0.0/24"
|
|
||||||
],
|
|
||||||
"dnsServers": [
|
|
||||||
{
|
|
||||||
"servers": [
|
|
||||||
"8.8.8.8:53"
|
|
||||||
],
|
|
||||||
"domains": null,
|
|
||||||
"enabled": true,
|
|
||||||
"error": ""
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"servers": [
|
|
||||||
"1.1.1.1:53",
|
|
||||||
"2.2.2.2:53"
|
|
||||||
],
|
|
||||||
"domains": [
|
|
||||||
"example.com",
|
|
||||||
"example.net"
|
|
||||||
],
|
|
||||||
"enabled": false,
|
|
||||||
"error": "timeout"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}`
|
|
||||||
// @formatter:on
|
|
||||||
|
|
||||||
var expectedJSON bytes.Buffer
|
|
||||||
require.NoError(t, json.Compact(&expectedJSON, []byte(expectedJSONString)))
|
|
||||||
|
|
||||||
assert.Equal(t, expectedJSON.String(), jsonString)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParsingToYAML(t *testing.T) {
|
|
||||||
yaml, _ := parseToYAML(overview)
|
|
||||||
|
|
||||||
expectedYAML :=
|
|
||||||
`peers:
|
|
||||||
total: 2
|
|
||||||
connected: 2
|
|
||||||
details:
|
|
||||||
- fqdn: peer-1.awesome-domain.com
|
|
||||||
netbirdIp: 192.168.178.101
|
|
||||||
publicKey: Pubkey1
|
|
||||||
status: Connected
|
|
||||||
lastStatusUpdate: 2001-01-01T01:01:01Z
|
|
||||||
connectionType: P2P
|
|
||||||
iceCandidateType:
|
|
||||||
local: ""
|
|
||||||
remote: ""
|
|
||||||
iceCandidateEndpoint:
|
|
||||||
local: ""
|
|
||||||
remote: ""
|
|
||||||
relayAddress: ""
|
|
||||||
lastWireguardHandshake: 2001-01-01T01:01:02Z
|
|
||||||
transferReceived: 200
|
|
||||||
transferSent: 100
|
|
||||||
latency: 10ms
|
|
||||||
quantumResistance: false
|
|
||||||
routes:
|
|
||||||
- 10.1.0.0/24
|
|
||||||
networks:
|
|
||||||
- 10.1.0.0/24
|
|
||||||
- fqdn: peer-2.awesome-domain.com
|
|
||||||
netbirdIp: 192.168.178.102
|
|
||||||
publicKey: Pubkey2
|
|
||||||
status: Connected
|
|
||||||
lastStatusUpdate: 2002-02-02T02:02:02Z
|
|
||||||
connectionType: Relayed
|
|
||||||
iceCandidateType:
|
|
||||||
local: relay
|
|
||||||
remote: prflx
|
|
||||||
iceCandidateEndpoint:
|
|
||||||
local: 10.0.0.1:10001
|
|
||||||
remote: 10.0.10.1:10002
|
|
||||||
relayAddress: ""
|
|
||||||
lastWireguardHandshake: 2002-02-02T02:02:03Z
|
|
||||||
transferReceived: 2000
|
|
||||||
transferSent: 1000
|
|
||||||
latency: 10ms
|
|
||||||
quantumResistance: false
|
|
||||||
routes: []
|
|
||||||
networks: []
|
|
||||||
cliVersion: development
|
|
||||||
daemonVersion: 0.14.1
|
|
||||||
management:
|
|
||||||
url: my-awesome-management.com:443
|
|
||||||
connected: true
|
|
||||||
error: ""
|
|
||||||
signal:
|
|
||||||
url: my-awesome-signal.com:443
|
|
||||||
connected: true
|
|
||||||
error: ""
|
|
||||||
relays:
|
|
||||||
total: 2
|
|
||||||
available: 1
|
|
||||||
details:
|
|
||||||
- uri: stun:my-awesome-stun.com:3478
|
|
||||||
available: true
|
|
||||||
error: ""
|
|
||||||
- uri: turns:my-awesome-turn.com:443?transport=tcp
|
|
||||||
available: false
|
|
||||||
error: 'context: deadline exceeded'
|
|
||||||
netbirdIp: 192.168.178.100/16
|
|
||||||
publicKey: Some-Pub-Key
|
|
||||||
usesKernelInterface: true
|
|
||||||
fqdn: some-localhost.awesome-domain.com
|
|
||||||
quantumResistance: false
|
|
||||||
quantumResistancePermissive: false
|
|
||||||
routes:
|
|
||||||
- 10.10.0.0/24
|
|
||||||
networks:
|
|
||||||
- 10.10.0.0/24
|
|
||||||
dnsServers:
|
|
||||||
- servers:
|
|
||||||
- 8.8.8.8:53
|
|
||||||
domains: []
|
|
||||||
enabled: true
|
|
||||||
error: ""
|
|
||||||
- servers:
|
|
||||||
- 1.1.1.1:53
|
|
||||||
- 2.2.2.2:53
|
|
||||||
domains:
|
|
||||||
- example.com
|
|
||||||
- example.net
|
|
||||||
enabled: false
|
|
||||||
error: timeout
|
|
||||||
`
|
|
||||||
|
|
||||||
assert.Equal(t, expectedYAML, yaml)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParsingToDetail(t *testing.T) {
|
|
||||||
// Calculate time ago based on the fixture dates
|
|
||||||
lastConnectionUpdate1 := timeAgo(overview.Peers.Details[0].LastStatusUpdate)
|
|
||||||
lastHandshake1 := timeAgo(overview.Peers.Details[0].LastWireguardHandshake)
|
|
||||||
lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate)
|
|
||||||
lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake)
|
|
||||||
|
|
||||||
detail := parseToFullDetailSummary(overview)
|
|
||||||
|
|
||||||
expectedDetail := fmt.Sprintf(
|
|
||||||
`Peers detail:
|
|
||||||
peer-1.awesome-domain.com:
|
|
||||||
NetBird IP: 192.168.178.101
|
|
||||||
Public key: Pubkey1
|
|
||||||
Status: Connected
|
|
||||||
-- detail --
|
|
||||||
Connection type: P2P
|
|
||||||
ICE candidate (Local/Remote): -/-
|
|
||||||
ICE candidate endpoints (Local/Remote): -/-
|
|
||||||
Relay server address:
|
|
||||||
Last connection update: %s
|
|
||||||
Last WireGuard handshake: %s
|
|
||||||
Transfer status (received/sent) 200 B/100 B
|
|
||||||
Quantum resistance: false
|
|
||||||
Routes: 10.1.0.0/24
|
|
||||||
Networks: 10.1.0.0/24
|
|
||||||
Latency: 10ms
|
|
||||||
|
|
||||||
peer-2.awesome-domain.com:
|
|
||||||
NetBird IP: 192.168.178.102
|
|
||||||
Public key: Pubkey2
|
|
||||||
Status: Connected
|
|
||||||
-- detail --
|
|
||||||
Connection type: Relayed
|
|
||||||
ICE candidate (Local/Remote): relay/prflx
|
|
||||||
ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002
|
|
||||||
Relay server address:
|
|
||||||
Last connection update: %s
|
|
||||||
Last WireGuard handshake: %s
|
|
||||||
Transfer status (received/sent) 2.0 KiB/1000 B
|
|
||||||
Quantum resistance: false
|
|
||||||
Routes: -
|
|
||||||
Networks: -
|
|
||||||
Latency: 10ms
|
|
||||||
|
|
||||||
OS: %s/%s
|
|
||||||
Daemon version: 0.14.1
|
|
||||||
CLI version: %s
|
|
||||||
Management: Connected to my-awesome-management.com:443
|
|
||||||
Signal: Connected to my-awesome-signal.com:443
|
|
||||||
Relays:
|
|
||||||
[stun:my-awesome-stun.com:3478] is Available
|
|
||||||
[turns:my-awesome-turn.com:443?transport=tcp] is Unavailable, reason: context: deadline exceeded
|
|
||||||
Nameservers:
|
|
||||||
[8.8.8.8:53] for [.] is Available
|
|
||||||
[1.1.1.1:53, 2.2.2.2:53] for [example.com, example.net] is Unavailable, reason: timeout
|
|
||||||
FQDN: some-localhost.awesome-domain.com
|
|
||||||
NetBird IP: 192.168.178.100/16
|
|
||||||
Interface type: Kernel
|
|
||||||
Quantum resistance: false
|
|
||||||
Routes: 10.10.0.0/24
|
|
||||||
Networks: 10.10.0.0/24
|
|
||||||
Peers count: 2/2 Connected
|
|
||||||
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
|
|
||||||
|
|
||||||
assert.Equal(t, expectedDetail, detail)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParsingToShortVersion(t *testing.T) {
|
|
||||||
shortVersion := parseGeneralSummary(overview, false, false, false)
|
|
||||||
|
|
||||||
expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
|
|
||||||
Daemon version: 0.14.1
|
|
||||||
CLI version: development
|
|
||||||
Management: Connected
|
|
||||||
Signal: Connected
|
|
||||||
Relays: 1/2 Available
|
|
||||||
Nameservers: 1/2 Available
|
|
||||||
FQDN: some-localhost.awesome-domain.com
|
|
||||||
NetBird IP: 192.168.178.100/16
|
|
||||||
Interface type: Kernel
|
|
||||||
Quantum resistance: false
|
|
||||||
Routes: 10.10.0.0/24
|
|
||||||
Networks: 10.10.0.0/24
|
|
||||||
Peers count: 2/2 Connected
|
|
||||||
`
|
|
||||||
|
|
||||||
assert.Equal(t, expectedString, shortVersion)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParsingOfIP(t *testing.T) {
|
func TestParsingOfIP(t *testing.T) {
|
||||||
InterfaceIP := "192.168.178.123/16"
|
InterfaceIP := "192.168.178.123/16"
|
||||||
|
|
||||||
@ -599,31 +13,3 @@ func TestParsingOfIP(t *testing.T) {
|
|||||||
|
|
||||||
assert.Equal(t, "192.168.178.123\n", parsedIP)
|
assert.Equal(t, "192.168.178.123\n", parsedIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTimeAgo(t *testing.T) {
|
|
||||||
now := time.Now()
|
|
||||||
|
|
||||||
cases := []struct {
|
|
||||||
name string
|
|
||||||
input time.Time
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{"Now", now, "Now"},
|
|
||||||
{"Seconds ago", now.Add(-10 * time.Second), "10 seconds ago"},
|
|
||||||
{"One minute ago", now.Add(-1 * time.Minute), "1 minute ago"},
|
|
||||||
{"Minutes and seconds ago", now.Add(-(1*time.Minute + 30*time.Second)), "1 minute, 30 seconds ago"},
|
|
||||||
{"One hour ago", now.Add(-1 * time.Hour), "1 hour ago"},
|
|
||||||
{"Hours and minutes ago", now.Add(-(2*time.Hour + 15*time.Minute)), "2 hours, 15 minutes ago"},
|
|
||||||
{"One day ago", now.Add(-24 * time.Hour), "1 day ago"},
|
|
||||||
{"Multiple days ago", now.Add(-(72*time.Hour + 20*time.Minute)), "3 days ago"},
|
|
||||||
{"Zero time", time.Time{}, "-"},
|
|
||||||
{"Unix zero time", time.Unix(0, 0), "-"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range cases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
result := timeAgo(tc.input)
|
|
||||||
assert.Equal(t, tc.expected, result, "Failed %s", tc.name)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -95,7 +95,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
|
|||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
||||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil)
|
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
137
client/cmd/trace.go
Normal file
137
client/cmd/trace.go
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
var traceCmd = &cobra.Command{
|
||||||
|
Use: "trace <direction> <source-ip> <dest-ip>",
|
||||||
|
Short: "Trace a packet through the firewall",
|
||||||
|
Example: `
|
||||||
|
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
|
||||||
|
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
|
||||||
|
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0
|
||||||
|
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
|
||||||
|
Args: cobra.ExactArgs(3),
|
||||||
|
RunE: tracePacket,
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
debugCmd.AddCommand(traceCmd)
|
||||||
|
|
||||||
|
traceCmd.Flags().StringP("protocol", "p", "tcp", "Protocol (tcp/udp/icmp)")
|
||||||
|
traceCmd.Flags().Uint16("sport", 0, "Source port")
|
||||||
|
traceCmd.Flags().Uint16("dport", 0, "Destination port")
|
||||||
|
traceCmd.Flags().Uint8("icmp-type", 0, "ICMP type")
|
||||||
|
traceCmd.Flags().Uint8("icmp-code", 0, "ICMP code")
|
||||||
|
traceCmd.Flags().Bool("syn", false, "TCP SYN flag")
|
||||||
|
traceCmd.Flags().Bool("ack", false, "TCP ACK flag")
|
||||||
|
traceCmd.Flags().Bool("fin", false, "TCP FIN flag")
|
||||||
|
traceCmd.Flags().Bool("rst", false, "TCP RST flag")
|
||||||
|
traceCmd.Flags().Bool("psh", false, "TCP PSH flag")
|
||||||
|
traceCmd.Flags().Bool("urg", false, "TCP URG flag")
|
||||||
|
}
|
||||||
|
|
||||||
|
func tracePacket(cmd *cobra.Command, args []string) error {
|
||||||
|
direction := strings.ToLower(args[0])
|
||||||
|
if direction != "in" && direction != "out" {
|
||||||
|
return fmt.Errorf("invalid direction: use 'in' or 'out'")
|
||||||
|
}
|
||||||
|
|
||||||
|
protocol := cmd.Flag("protocol").Value.String()
|
||||||
|
if protocol != "tcp" && protocol != "udp" && protocol != "icmp" {
|
||||||
|
return fmt.Errorf("invalid protocol: use tcp/udp/icmp")
|
||||||
|
}
|
||||||
|
|
||||||
|
sport, err := cmd.Flags().GetUint16("sport")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid source port: %v", err)
|
||||||
|
}
|
||||||
|
dport, err := cmd.Flags().GetUint16("dport")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid destination port: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// For TCP/UDP, generate random ephemeral port (49152-65535) if not specified
|
||||||
|
if protocol != "icmp" {
|
||||||
|
if sport == 0 {
|
||||||
|
sport = uint16(rand.Intn(16383) + 49152)
|
||||||
|
}
|
||||||
|
if dport == 0 {
|
||||||
|
dport = uint16(rand.Intn(16383) + 49152)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var tcpFlags *proto.TCPFlags
|
||||||
|
if protocol == "tcp" {
|
||||||
|
syn, _ := cmd.Flags().GetBool("syn")
|
||||||
|
ack, _ := cmd.Flags().GetBool("ack")
|
||||||
|
fin, _ := cmd.Flags().GetBool("fin")
|
||||||
|
rst, _ := cmd.Flags().GetBool("rst")
|
||||||
|
psh, _ := cmd.Flags().GetBool("psh")
|
||||||
|
urg, _ := cmd.Flags().GetBool("urg")
|
||||||
|
|
||||||
|
tcpFlags = &proto.TCPFlags{
|
||||||
|
Syn: syn,
|
||||||
|
Ack: ack,
|
||||||
|
Fin: fin,
|
||||||
|
Rst: rst,
|
||||||
|
Psh: psh,
|
||||||
|
Urg: urg,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
icmpType, _ := cmd.Flags().GetUint32("icmp-type")
|
||||||
|
icmpCode, _ := cmd.Flags().GetUint32("icmp-code")
|
||||||
|
|
||||||
|
conn, err := getClient(cmd)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
|
resp, err := client.TracePacket(cmd.Context(), &proto.TracePacketRequest{
|
||||||
|
SourceIp: args[1],
|
||||||
|
DestinationIp: args[2],
|
||||||
|
Protocol: protocol,
|
||||||
|
SourcePort: uint32(sport),
|
||||||
|
DestinationPort: uint32(dport),
|
||||||
|
Direction: direction,
|
||||||
|
TcpFlags: tcpFlags,
|
||||||
|
IcmpType: &icmpType,
|
||||||
|
IcmpCode: &icmpCode,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("trace failed: %v", status.Convert(err).Message())
|
||||||
|
}
|
||||||
|
|
||||||
|
printTrace(cmd, args[1], args[2], protocol, sport, dport, resp)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) {
|
||||||
|
cmd.Printf("Packet trace %s:%d -> %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
|
||||||
|
|
||||||
|
for _, stage := range resp.Stages {
|
||||||
|
if stage.ForwardingDetails != nil {
|
||||||
|
cmd.Printf("%s: %s [%s]\n", stage.Name, stage.Message, *stage.ForwardingDetails)
|
||||||
|
} else {
|
||||||
|
cmd.Printf("%s: %s\n", stage.Name, stage.Message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
disposition := map[bool]string{
|
||||||
|
true: "\033[32mALLOWED\033[0m", // Green
|
||||||
|
false: "\033[31mDENIED\033[0m", // Red
|
||||||
|
}[resp.FinalDisposition]
|
||||||
|
|
||||||
|
cmd.Printf("\nFinal disposition: %s\n", disposition)
|
||||||
|
}
|
@ -20,6 +20,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -29,8 +30,15 @@ const (
|
|||||||
interfaceInputType
|
interfaceInputType
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
dnsLabelsFlag = "extra-dns-labels"
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
foregroundMode bool
|
foregroundMode bool
|
||||||
|
dnsLabels []string
|
||||||
|
dnsLabelsValidated domain.List
|
||||||
|
|
||||||
upCmd = &cobra.Command{
|
upCmd = &cobra.Command{
|
||||||
Use: "up",
|
Use: "up",
|
||||||
Short: "install, login and start Netbird client",
|
Short: "install, login and start Netbird client",
|
||||||
@ -48,6 +56,15 @@ func init() {
|
|||||||
)
|
)
|
||||||
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
|
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
|
||||||
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
|
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
|
||||||
|
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false, "Block access to local networks (LAN) when using this peer as a router or exit node")
|
||||||
|
|
||||||
|
upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil,
|
||||||
|
`Sets DNS labels`+
|
||||||
|
`You can specify a comma-separated list of up to 32 labels. `+
|
||||||
|
`An empty string "" clears the previous configuration. `+
|
||||||
|
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
|
||||||
|
`or --extra-dns-labels ""`,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func upFunc(cmd *cobra.Command, args []string) error {
|
func upFunc(cmd *cobra.Command, args []string) error {
|
||||||
@ -66,6 +83,11 @@ func upFunc(cmd *cobra.Command, args []string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dnsLabelsValidated, err = validateDnsLabels(dnsLabels)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
ctx := internal.CtxInitState(cmd.Context())
|
ctx := internal.CtxInitState(cmd.Context())
|
||||||
|
|
||||||
if hostName != "" {
|
if hostName != "" {
|
||||||
@ -97,6 +119,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
NATExternalIPs: natExternalIPs,
|
NATExternalIPs: natExternalIPs,
|
||||||
CustomDNSAddress: customDNSAddressConverted,
|
CustomDNSAddress: customDNSAddressConverted,
|
||||||
ExtraIFaceBlackList: extraIFaceBlackList,
|
ExtraIFaceBlackList: extraIFaceBlackList,
|
||||||
|
DNSLabels: dnsLabelsValidated,
|
||||||
}
|
}
|
||||||
|
|
||||||
if cmd.Flag(enableRosenpassFlag).Changed {
|
if cmd.Flag(enableRosenpassFlag).Changed {
|
||||||
@ -160,6 +183,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
ic.DisableFirewall = &disableFirewall
|
ic.DisableFirewall = &disableFirewall
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(blockLANAccessFlag).Changed {
|
||||||
|
ic.BlockLANAccess = &blockLANAccess
|
||||||
|
}
|
||||||
|
|
||||||
providedSetupKey, err := getSetupKey()
|
providedSetupKey, err := getSetupKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -185,7 +212,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
r.GetFullStatus()
|
r.GetFullStatus()
|
||||||
|
|
||||||
connectClient := internal.NewConnectClient(ctx, config, r)
|
connectClient := internal.NewConnectClient(ctx, config, r)
|
||||||
return connectClient.Run()
|
return connectClient.Run(nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||||
@ -235,6 +262,8 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
||||||
Hostname: hostName,
|
Hostname: hostName,
|
||||||
ExtraIFaceBlacklist: extraIFaceBlackList,
|
ExtraIFaceBlacklist: extraIFaceBlackList,
|
||||||
|
DnsLabels: dnsLabels,
|
||||||
|
CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||||
@ -290,6 +319,10 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
loginRequest.DisableFirewall = &disableFirewall
|
loginRequest.DisableFirewall = &disableFirewall
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(blockLANAccessFlag).Changed {
|
||||||
|
loginRequest.BlockLanAccess = &blockLANAccess
|
||||||
|
}
|
||||||
|
|
||||||
var loginErr error
|
var loginErr error
|
||||||
|
|
||||||
var loginResp *proto.LoginResponse
|
var loginResp *proto.LoginResponse
|
||||||
@ -421,6 +454,24 @@ func parseCustomDNSAddress(modified bool) ([]byte, error) {
|
|||||||
return parsed, nil
|
return parsed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateDnsLabels(labels []string) (domain.List, error) {
|
||||||
|
var (
|
||||||
|
domains domain.List
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(labels) == 0 {
|
||||||
|
return domains, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
domains, err = domain.ValidateDomains(labels)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to validate dns labels: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return domains, nil
|
||||||
|
}
|
||||||
|
|
||||||
func isValidAddrPort(input string) bool {
|
func isValidAddrPort(input string) bool {
|
||||||
if input == "" {
|
if input == "" {
|
||||||
return true
|
return true
|
||||||
|
167
client/embed/doc.go
Normal file
167
client/embed/doc.go
Normal file
@ -0,0 +1,167 @@
|
|||||||
|
// Package embed provides a way to embed the NetBird client directly
|
||||||
|
// into Go programs without requiring a separate NetBird client installation.
|
||||||
|
package embed
|
||||||
|
|
||||||
|
// Basic Usage:
|
||||||
|
//
|
||||||
|
// client, err := embed.New(embed.Options{
|
||||||
|
// DeviceName: "my-service",
|
||||||
|
// SetupKey: os.Getenv("NB_SETUP_KEY"),
|
||||||
|
// ManagementURL: os.Getenv("NB_MANAGEMENT_URL"),
|
||||||
|
// })
|
||||||
|
// if err != nil {
|
||||||
|
// log.Fatal(err)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
// defer cancel()
|
||||||
|
// if err := client.Start(ctx); err != nil {
|
||||||
|
// log.Fatal(err)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// Complete HTTP Server Example:
|
||||||
|
//
|
||||||
|
// package main
|
||||||
|
//
|
||||||
|
// import (
|
||||||
|
// "context"
|
||||||
|
// "fmt"
|
||||||
|
// "log"
|
||||||
|
// "net/http"
|
||||||
|
// "os"
|
||||||
|
// "os/signal"
|
||||||
|
// "syscall"
|
||||||
|
// "time"
|
||||||
|
//
|
||||||
|
// netbird "github.com/netbirdio/netbird/client/embed"
|
||||||
|
// )
|
||||||
|
//
|
||||||
|
// func main() {
|
||||||
|
// // Create client with setup key and device name
|
||||||
|
// client, err := netbird.New(netbird.Options{
|
||||||
|
// DeviceName: "http-server",
|
||||||
|
// SetupKey: os.Getenv("NB_SETUP_KEY"),
|
||||||
|
// ManagementURL: os.Getenv("NB_MANAGEMENT_URL"),
|
||||||
|
// LogOutput: io.Discard,
|
||||||
|
// })
|
||||||
|
// if err != nil {
|
||||||
|
// log.Fatal(err)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// // Start with timeout
|
||||||
|
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
// defer cancel()
|
||||||
|
// if err := client.Start(ctx); err != nil {
|
||||||
|
// log.Fatal(err)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// // Create HTTP server
|
||||||
|
// mux := http.NewServeMux()
|
||||||
|
// mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// fmt.Printf("Request from %s: %s %s\n", r.RemoteAddr, r.Method, r.URL.Path)
|
||||||
|
// fmt.Fprintf(w, "Hello from netbird!")
|
||||||
|
// })
|
||||||
|
//
|
||||||
|
// // Listen on netbird network
|
||||||
|
// l, err := client.ListenTCP(":8080")
|
||||||
|
// if err != nil {
|
||||||
|
// log.Fatal(err)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// server := &http.Server{Handler: mux}
|
||||||
|
// go func() {
|
||||||
|
// if err := server.Serve(l); !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
// log.Printf("HTTP server error: %v", err)
|
||||||
|
// }
|
||||||
|
// }()
|
||||||
|
//
|
||||||
|
// log.Printf("HTTP server listening on netbird network port 8080")
|
||||||
|
//
|
||||||
|
// // Handle shutdown
|
||||||
|
// stop := make(chan os.Signal, 1)
|
||||||
|
// signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
// <-stop
|
||||||
|
//
|
||||||
|
// shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
// defer cancel()
|
||||||
|
//
|
||||||
|
// if err := server.Shutdown(shutdownCtx); err != nil {
|
||||||
|
// log.Printf("HTTP shutdown error: %v", err)
|
||||||
|
// }
|
||||||
|
// if err := client.Stop(shutdownCtx); err != nil {
|
||||||
|
// log.Printf("Netbird shutdown error: %v", err)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// Complete HTTP Client Example:
|
||||||
|
//
|
||||||
|
// package main
|
||||||
|
//
|
||||||
|
// import (
|
||||||
|
// "context"
|
||||||
|
// "fmt"
|
||||||
|
// "io"
|
||||||
|
// "log"
|
||||||
|
// "os"
|
||||||
|
// "time"
|
||||||
|
//
|
||||||
|
// netbird "github.com/netbirdio/netbird/client/embed"
|
||||||
|
// )
|
||||||
|
//
|
||||||
|
// func main() {
|
||||||
|
// // Create client with setup key and device name
|
||||||
|
// client, err := netbird.New(netbird.Options{
|
||||||
|
// DeviceName: "http-client",
|
||||||
|
// SetupKey: os.Getenv("NB_SETUP_KEY"),
|
||||||
|
// ManagementURL: os.Getenv("NB_MANAGEMENT_URL"),
|
||||||
|
// LogOutput: io.Discard,
|
||||||
|
// })
|
||||||
|
// if err != nil {
|
||||||
|
// log.Fatal(err)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// // Start with timeout
|
||||||
|
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
// defer cancel()
|
||||||
|
//
|
||||||
|
// if err := client.Start(ctx); err != nil {
|
||||||
|
// log.Fatal(err)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// // Create HTTP client that uses netbird network
|
||||||
|
// httpClient := client.NewHTTPClient()
|
||||||
|
// httpClient.Timeout = 10 * time.Second
|
||||||
|
//
|
||||||
|
// // Make request to server in netbird network
|
||||||
|
// target := os.Getenv("NB_TARGET")
|
||||||
|
// resp, err := httpClient.Get(target)
|
||||||
|
// if err != nil {
|
||||||
|
// log.Fatal(err)
|
||||||
|
// }
|
||||||
|
// defer resp.Body.Close()
|
||||||
|
//
|
||||||
|
// // Read and print response
|
||||||
|
// body, err := io.ReadAll(resp.Body)
|
||||||
|
// if err != nil {
|
||||||
|
// log.Fatal(err)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// fmt.Printf("Response from server: %s\n", string(body))
|
||||||
|
//
|
||||||
|
// // Clean shutdown
|
||||||
|
// shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
// defer cancel()
|
||||||
|
//
|
||||||
|
// if err := client.Stop(shutdownCtx); err != nil {
|
||||||
|
// log.Printf("Netbird shutdown error: %v", err)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// The package provides several methods for network operations:
|
||||||
|
// - Dial: Creates outbound connections
|
||||||
|
// - ListenTCP: Creates TCP listeners
|
||||||
|
// - ListenUDP: Creates UDP listeners
|
||||||
|
//
|
||||||
|
// By default, the embed package uses userspace networking mode, which doesn't
|
||||||
|
// require root/admin privileges. For production deployments, consider setting
|
||||||
|
// appropriate config and state paths for persistence.
|
296
client/embed/embed.go
Normal file
296
client/embed/embed.go
Normal file
@ -0,0 +1,296 @@
|
|||||||
|
package embed
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/system"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrClientAlreadyStarted = errors.New("client already started")
|
||||||
|
var ErrClientNotStarted = errors.New("client not started")
|
||||||
|
|
||||||
|
// Client manages a netbird embedded client instance
|
||||||
|
type Client struct {
|
||||||
|
deviceName string
|
||||||
|
config *internal.Config
|
||||||
|
mu sync.Mutex
|
||||||
|
cancel context.CancelFunc
|
||||||
|
setupKey string
|
||||||
|
connect *internal.ConnectClient
|
||||||
|
}
|
||||||
|
|
||||||
|
// Options configures a new Client
|
||||||
|
type Options struct {
|
||||||
|
// DeviceName is this peer's name in the network
|
||||||
|
DeviceName string
|
||||||
|
// SetupKey is used for authentication
|
||||||
|
SetupKey string
|
||||||
|
// ManagementURL overrides the default management server URL
|
||||||
|
ManagementURL string
|
||||||
|
// PreSharedKey is the pre-shared key for the WireGuard interface
|
||||||
|
PreSharedKey string
|
||||||
|
// LogOutput is the output destination for logs (defaults to os.Stderr if nil)
|
||||||
|
LogOutput io.Writer
|
||||||
|
// LogLevel sets the logging level (defaults to info if empty)
|
||||||
|
LogLevel string
|
||||||
|
// NoUserspace disables the userspace networking mode. Needs admin/root privileges
|
||||||
|
NoUserspace bool
|
||||||
|
// ConfigPath is the path to the netbird config file. If empty, the config will be stored in memory and not persisted.
|
||||||
|
ConfigPath string
|
||||||
|
// StatePath is the path to the netbird state file
|
||||||
|
StatePath string
|
||||||
|
// DisableClientRoutes disables the client routes
|
||||||
|
DisableClientRoutes bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new netbird embedded client
|
||||||
|
func New(opts Options) (*Client, error) {
|
||||||
|
if opts.LogOutput != nil {
|
||||||
|
logrus.SetOutput(opts.LogOutput)
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts.LogLevel != "" {
|
||||||
|
level, err := logrus.ParseLevel(opts.LogLevel)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse log level: %w", err)
|
||||||
|
}
|
||||||
|
logrus.SetLevel(level)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !opts.NoUserspace {
|
||||||
|
if err := os.Setenv(netstack.EnvUseNetstackMode, "true"); err != nil {
|
||||||
|
return nil, fmt.Errorf("setenv: %w", err)
|
||||||
|
}
|
||||||
|
if err := os.Setenv(netstack.EnvSkipProxy, "true"); err != nil {
|
||||||
|
return nil, fmt.Errorf("setenv: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts.StatePath != "" {
|
||||||
|
// TODO: Disable state if path not provided
|
||||||
|
if err := os.Setenv("NB_DNS_STATE_FILE", opts.StatePath); err != nil {
|
||||||
|
return nil, fmt.Errorf("setenv: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t := true
|
||||||
|
var config *internal.Config
|
||||||
|
var err error
|
||||||
|
input := internal.ConfigInput{
|
||||||
|
ConfigPath: opts.ConfigPath,
|
||||||
|
ManagementURL: opts.ManagementURL,
|
||||||
|
PreSharedKey: &opts.PreSharedKey,
|
||||||
|
DisableServerRoutes: &t,
|
||||||
|
DisableClientRoutes: &opts.DisableClientRoutes,
|
||||||
|
}
|
||||||
|
if opts.ConfigPath != "" {
|
||||||
|
config, err = internal.UpdateOrCreateConfig(input)
|
||||||
|
} else {
|
||||||
|
config, err = internal.CreateInMemoryConfig(input)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Client{
|
||||||
|
deviceName: opts.DeviceName,
|
||||||
|
setupKey: opts.SetupKey,
|
||||||
|
config: config,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start begins client operation and blocks until the engine has been started successfully or a startup error occurs.
|
||||||
|
// Pass a context with a deadline to limit the time spent waiting for the engine to start.
|
||||||
|
func (c *Client) Start(startCtx context.Context) error {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
if c.cancel != nil {
|
||||||
|
return ErrClientAlreadyStarted
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := internal.CtxInitState(context.Background())
|
||||||
|
// nolint:staticcheck
|
||||||
|
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
||||||
|
if err := internal.Login(ctx, c.config, c.setupKey, ""); err != nil {
|
||||||
|
return fmt.Errorf("login: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
recorder := peer.NewRecorder(c.config.ManagementURL.String())
|
||||||
|
client := internal.NewConnectClient(ctx, c.config, recorder)
|
||||||
|
|
||||||
|
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||||
|
// TODO: make after-startup backoff err available
|
||||||
|
run := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
if err := client.Run(run); err != nil {
|
||||||
|
run <- err
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-startCtx.Done():
|
||||||
|
if stopErr := client.Stop(); stopErr != nil {
|
||||||
|
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
|
||||||
|
}
|
||||||
|
return startCtx.Err()
|
||||||
|
case err := <-run:
|
||||||
|
if err != nil {
|
||||||
|
if stopErr := client.Stop(); stopErr != nil {
|
||||||
|
return fmt.Errorf("stop error after failed to startup. Stop error: %w. Start error: %w", stopErr, err)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("startup: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.connect = client
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop gracefully stops the client.
|
||||||
|
// Pass a context with a deadline to limit the time spent waiting for the engine to stop.
|
||||||
|
func (c *Client) Stop(ctx context.Context) error {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
if c.connect == nil {
|
||||||
|
return ErrClientNotStarted
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
done <- c.connect.Stop()
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
c.cancel = nil
|
||||||
|
return ctx.Err()
|
||||||
|
case err := <-done:
|
||||||
|
c.cancel = nil
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("stop: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial dials a network address in the netbird network.
|
||||||
|
// Not applicable if the userspace networking mode is disabled.
|
||||||
|
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
connect := c.connect
|
||||||
|
if connect == nil {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return nil, ErrClientNotStarted
|
||||||
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
engine := connect.Engine()
|
||||||
|
if engine == nil {
|
||||||
|
return nil, errors.New("engine not started")
|
||||||
|
}
|
||||||
|
|
||||||
|
nsnet, err := engine.GetNet()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get net: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nsnet.DialContext(ctx, network, address)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListenTCP listens on the given address in the netbird network
|
||||||
|
// Not applicable if the userspace networking mode is disabled.
|
||||||
|
func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
||||||
|
nsnet, addr, err := c.getNet()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, port, err := net.SplitHostPort(address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("split host port: %w", err)
|
||||||
|
}
|
||||||
|
listenAddr := fmt.Sprintf("%s:%s", addr, port)
|
||||||
|
|
||||||
|
tcpAddr, err := net.ResolveTCPAddr("tcp", listenAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("resolve: %w", err)
|
||||||
|
}
|
||||||
|
return nsnet.ListenTCP(tcpAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListenUDP listens on the given address in the netbird network
|
||||||
|
// Not applicable if the userspace networking mode is disabled.
|
||||||
|
func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
|
||||||
|
nsnet, addr, err := c.getNet()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, port, err := net.SplitHostPort(address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("split host port: %w", err)
|
||||||
|
}
|
||||||
|
listenAddr := fmt.Sprintf("%s:%s", addr, port)
|
||||||
|
|
||||||
|
udpAddr, err := net.ResolveUDPAddr("udp", listenAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("resolve: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nsnet.ListenUDP(udpAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHTTPClient returns a configured http.Client that uses the netbird network for requests.
|
||||||
|
// Not applicable if the userspace networking mode is disabled.
|
||||||
|
func (c *Client) NewHTTPClient() *http.Client {
|
||||||
|
transport := &http.Transport{
|
||||||
|
DialContext: c.Dial,
|
||||||
|
}
|
||||||
|
|
||||||
|
return &http.Client{
|
||||||
|
Transport: transport,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
connect := c.connect
|
||||||
|
if connect == nil {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return nil, netip.Addr{}, errors.New("client not started")
|
||||||
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
engine := connect.Engine()
|
||||||
|
if engine == nil {
|
||||||
|
return nil, netip.Addr{}, errors.New("engine not started")
|
||||||
|
}
|
||||||
|
|
||||||
|
addr, err := engine.Address()
|
||||||
|
if err != nil {
|
||||||
|
return nil, netip.Addr{}, fmt.Errorf("engine address: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
nsnet, err := engine.GetNet()
|
||||||
|
if err != nil {
|
||||||
|
return nil, netip.Addr{}, fmt.Errorf("get net: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nsnet, addr, nil
|
||||||
|
}
|
@ -14,13 +14,13 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// NewFirewall creates a firewall manager instance
|
// NewFirewall creates a firewall manager instance
|
||||||
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) {
|
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
||||||
if !iface.IsUserspaceBind() {
|
if !iface.IsUserspaceBind() {
|
||||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||||
}
|
}
|
||||||
|
|
||||||
// use userspace packet filtering firewall
|
// use userspace packet filtering firewall
|
||||||
fm, err := uspfilter.Create(iface)
|
fm, err := uspfilter.Create(iface, disableServerRoutes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -33,12 +33,12 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
|||||||
// FWType is the type for the firewall type
|
// FWType is the type for the firewall type
|
||||||
type FWType int
|
type FWType int
|
||||||
|
|
||||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
|
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
||||||
// on the linux system we try to user nftables or iptables
|
// on the linux system we try to user nftables or iptables
|
||||||
// in any case, because we need to allow netbird interface traffic
|
// in any case, because we need to allow netbird interface traffic
|
||||||
// so we use AllowNetbird traffic from these firewall managers
|
// so we use AllowNetbird traffic from these firewall managers
|
||||||
// for the userspace packet filtering firewall
|
// for the userspace packet filtering firewall
|
||||||
fm, err := createNativeFirewall(iface, stateManager)
|
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes)
|
||||||
|
|
||||||
if !iface.IsUserspaceBind() {
|
if !iface.IsUserspaceBind() {
|
||||||
return fm, err
|
return fm, err
|
||||||
@ -47,10 +47,10 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewal
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||||
}
|
}
|
||||||
return createUserspaceFirewall(iface, fm)
|
return createUserspaceFirewall(iface, fm, disableServerRoutes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
|
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
|
||||||
fm, err := createFW(iface)
|
fm, err := createFW(iface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create firewall: %s", err)
|
return nil, fmt.Errorf("create firewall: %s", err)
|
||||||
@ -77,12 +77,12 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) {
|
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
||||||
var errUsp error
|
var errUsp error
|
||||||
if fm != nil {
|
if fm != nil {
|
||||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
|
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes)
|
||||||
} else {
|
} else {
|
||||||
fm, errUsp = uspfilter.Create(iface)
|
fm, errUsp = uspfilter.Create(iface, disableServerRoutes)
|
||||||
}
|
}
|
||||||
|
|
||||||
if errUsp != nil {
|
if errUsp != nil {
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
package firewall
|
package firewall
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -10,4 +12,6 @@ type IFaceMapper interface {
|
|||||||
Address() device.WGAddress
|
Address() device.WGAddress
|
||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
SetFilter(device.PacketFilter) error
|
SetFilter(device.PacketFilter) error
|
||||||
|
GetDevice() *device.FilteredDevice
|
||||||
|
GetWGDevice() *wgdevice.Device
|
||||||
}
|
}
|
||||||
|
@ -3,7 +3,7 @@ package iptables
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"slices"
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@ -20,7 +20,6 @@ const (
|
|||||||
|
|
||||||
// rules chains contains the effective ACL rules
|
// rules chains contains the effective ACL rules
|
||||||
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
||||||
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type aclEntries map[string][][]string
|
type aclEntries map[string][][]string
|
||||||
@ -84,28 +83,22 @@ func (m *aclManager) AddPeerFiltering(
|
|||||||
protocol firewall.Protocol,
|
protocol firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
direction firewall.RuleDirection,
|
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
var dPortVal, sPortVal string
|
chain := chainNameInputRules
|
||||||
if dPort != nil && dPort.Values != nil {
|
|
||||||
// TODO: we support only one port per rule in current implementation of ACLs
|
|
||||||
dPortVal = strconv.Itoa(dPort.Values[0])
|
|
||||||
}
|
|
||||||
if sPort != nil && sPort.Values != nil {
|
|
||||||
sPortVal = strconv.Itoa(sPort.Values[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
var chain string
|
ipsetName = transformIPsetName(ipsetName, sPort, dPort)
|
||||||
if direction == firewall.RuleDirectionOUT {
|
specs := filterRuleSpecs(ip, string(protocol), sPort, dPort, action, ipsetName)
|
||||||
chain = chainNameOutputRules
|
|
||||||
} else {
|
|
||||||
chain = chainNameInputRules
|
|
||||||
}
|
|
||||||
|
|
||||||
ipsetName = transformIPsetName(ipsetName, sPortVal, dPortVal)
|
mangleSpecs := slices.Clone(specs)
|
||||||
specs := filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, direction, action, ipsetName)
|
mangleSpecs = append(mangleSpecs,
|
||||||
|
"-i", m.wgIface.Name(),
|
||||||
|
"-m", "addrtype", "--dst-type", "LOCAL",
|
||||||
|
"-j", "MARK", "--set-xmark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
|
||||||
|
)
|
||||||
|
|
||||||
|
specs = append(specs, "-j", actionToStr(action))
|
||||||
if ipsetName != "" {
|
if ipsetName != "" {
|
||||||
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
|
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
|
||||||
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
||||||
@ -137,7 +130,7 @@ func (m *aclManager) AddPeerFiltering(
|
|||||||
m.ipsetStore.addIpList(ipsetName, ipList)
|
m.ipsetStore.addIpList(ipsetName, ipList)
|
||||||
}
|
}
|
||||||
|
|
||||||
ok, err := m.iptablesClient.Exists("filter", chain, specs...)
|
ok, err := m.iptablesClient.Exists(tableFilter, chain, specs...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to check rule: %w", err)
|
return nil, fmt.Errorf("failed to check rule: %w", err)
|
||||||
}
|
}
|
||||||
@ -145,13 +138,19 @@ func (m *aclManager) AddPeerFiltering(
|
|||||||
return nil, fmt.Errorf("rule already exists")
|
return nil, fmt.Errorf("rule already exists")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.iptablesClient.Append("filter", chain, specs...); err != nil {
|
if err := m.iptablesClient.Append(tableFilter, chain, specs...); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := m.iptablesClient.Append(tableMangle, chainRTPRE, mangleSpecs...); err != nil {
|
||||||
|
log.Errorf("failed to add mangle rule: %v", err)
|
||||||
|
mangleSpecs = nil
|
||||||
|
}
|
||||||
|
|
||||||
rule := &Rule{
|
rule := &Rule{
|
||||||
ruleID: uuid.New().String(),
|
ruleID: uuid.New().String(),
|
||||||
specs: specs,
|
specs: specs,
|
||||||
|
mangleSpecs: mangleSpecs,
|
||||||
ipsetName: ipsetName,
|
ipsetName: ipsetName,
|
||||||
ip: ip.String(),
|
ip: ip.String(),
|
||||||
chain: chain,
|
chain: chain,
|
||||||
@ -197,6 +196,12 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err)
|
return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if r.mangleSpecs != nil {
|
||||||
|
if err := m.iptablesClient.Delete(tableMangle, chainRTPRE, r.mangleSpecs...); err != nil {
|
||||||
|
log.Errorf("failed to delete mangle rule: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
m.updateState()
|
m.updateState()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -214,28 +219,7 @@ func (m *aclManager) Reset() error {
|
|||||||
|
|
||||||
// todo write less destructive cleanup mechanism
|
// todo write less destructive cleanup mechanism
|
||||||
func (m *aclManager) cleanChains() error {
|
func (m *aclManager) cleanChains() error {
|
||||||
ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules)
|
ok, err := m.iptablesClient.ChainExists(tableName, chainNameInputRules)
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to list chains: %s", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
rules := m.entries["OUTPUT"]
|
|
||||||
for _, rule := range rules {
|
|
||||||
err := m.iptablesClient.DeleteIfExists(tableName, "OUTPUT", rule...)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameOutputRules)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to clear and delete %s chain: %s", chainNameOutputRules, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, err = m.iptablesClient.ChainExists(tableName, chainNameInputRules)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to list chains: %s", err)
|
log.Debugf("failed to list chains: %s", err)
|
||||||
return err
|
return err
|
||||||
@ -295,12 +279,6 @@ func (m *aclManager) createDefaultChains() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// chain netbird-acl-output-rules
|
|
||||||
if err := m.iptablesClient.NewChain(tableName, chainNameOutputRules); err != nil {
|
|
||||||
log.Debugf("failed to create '%s' chain: %s", chainNameOutputRules, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for chainName, rules := range m.entries {
|
for chainName, rules := range m.entries {
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
|
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
|
||||||
@ -329,8 +307,6 @@ func (m *aclManager) createDefaultChains() error {
|
|||||||
|
|
||||||
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
|
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
|
||||||
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
||||||
|
|
||||||
// The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule.
|
|
||||||
func (m *aclManager) seedInitialEntries() {
|
func (m *aclManager) seedInitialEntries() {
|
||||||
established := getConntrackEstablished()
|
established := getConntrackEstablished()
|
||||||
|
|
||||||
@ -346,17 +322,10 @@ func (m *aclManager) seedInitialEntries() {
|
|||||||
func (m *aclManager) seedInitialOptionalEntries() {
|
func (m *aclManager) seedInitialOptionalEntries() {
|
||||||
m.optionalEntries["FORWARD"] = []entry{
|
m.optionalEntries["FORWARD"] = []entry{
|
||||||
{
|
{
|
||||||
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", chainNameInputRules},
|
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", "ACCEPT"},
|
||||||
position: 2,
|
position: 2,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m.optionalEntries["PREROUTING"] = []entry{
|
|
||||||
{
|
|
||||||
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected)},
|
|
||||||
position: 1,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *aclManager) appendToEntries(chainName string, spec []string) {
|
func (m *aclManager) appendToEntries(chainName string, spec []string) {
|
||||||
@ -390,16 +359,13 @@ func (m *aclManager) updateState() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// filterRuleSpecs returns the specs of a filtering rule
|
// filterRuleSpecs returns the specs of a filtering rule
|
||||||
func filterRuleSpecs(
|
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
|
||||||
ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string,
|
|
||||||
) (specs []string) {
|
|
||||||
matchByIP := true
|
matchByIP := true
|
||||||
// don't use IP matching if IP is ip 0.0.0.0
|
// don't use IP matching if IP is ip 0.0.0.0
|
||||||
if ip.String() == "0.0.0.0" {
|
if ip.String() == "0.0.0.0" {
|
||||||
matchByIP = false
|
matchByIP = false
|
||||||
}
|
}
|
||||||
switch direction {
|
|
||||||
case firewall.RuleDirectionIN:
|
|
||||||
if matchByIP {
|
if matchByIP {
|
||||||
if ipsetName != "" {
|
if ipsetName != "" {
|
||||||
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
|
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
|
||||||
@ -407,25 +373,12 @@ func filterRuleSpecs(
|
|||||||
specs = append(specs, "-s", ip.String())
|
specs = append(specs, "-s", ip.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case firewall.RuleDirectionOUT:
|
|
||||||
if matchByIP {
|
|
||||||
if ipsetName != "" {
|
|
||||||
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
|
|
||||||
} else {
|
|
||||||
specs = append(specs, "-d", ip.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if protocol != "all" {
|
if protocol != "all" {
|
||||||
specs = append(specs, "-p", protocol)
|
specs = append(specs, "-p", protocol)
|
||||||
}
|
}
|
||||||
if sPort != "" {
|
specs = append(specs, applyPort("--sport", sPort)...)
|
||||||
specs = append(specs, "--sport", sPort)
|
specs = append(specs, applyPort("--dport", dPort)...)
|
||||||
}
|
return specs
|
||||||
if dPort != "" {
|
|
||||||
specs = append(specs, "--dport", dPort)
|
|
||||||
}
|
|
||||||
return append(specs, "-j", actionToStr(action))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func actionToStr(action firewall.Action) string {
|
func actionToStr(action firewall.Action) string {
|
||||||
@ -435,15 +388,15 @@ func actionToStr(action firewall.Action) string {
|
|||||||
return "DROP"
|
return "DROP"
|
||||||
}
|
}
|
||||||
|
|
||||||
func transformIPsetName(ipsetName string, sPort, dPort string) string {
|
func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port) string {
|
||||||
switch {
|
switch {
|
||||||
case ipsetName == "":
|
case ipsetName == "":
|
||||||
return ""
|
return ""
|
||||||
case sPort != "" && dPort != "":
|
case sPort != nil && dPort != nil:
|
||||||
return ipsetName + "-sport-dport"
|
return ipsetName + "-sport-dport"
|
||||||
case sPort != "":
|
case sPort != nil:
|
||||||
return ipsetName + "-sport"
|
return ipsetName + "-sport"
|
||||||
case dPort != "":
|
case dPort != nil:
|
||||||
return ipsetName + "-dport"
|
return ipsetName + "-dport"
|
||||||
default:
|
default:
|
||||||
return ipsetName
|
return ipsetName
|
||||||
|
@ -100,15 +100,14 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
protocol firewall.Protocol,
|
protocol firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
direction firewall.RuleDirection,
|
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
_ string,
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
|
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, action, ipsetName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddRouteFiltering(
|
func (m *Manager) AddRouteFiltering(
|
||||||
@ -201,7 +200,6 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
"all",
|
"all",
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
firewall.RuleDirectionIN,
|
|
||||||
firewall.ActionAccept,
|
firewall.ActionAccept,
|
||||||
"",
|
"",
|
||||||
"",
|
"",
|
||||||
@ -215,6 +213,19 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
// Flush doesn't need to be implemented for this manager
|
// Flush doesn't need to be implemented for this manager
|
||||||
func (m *Manager) Flush() error { return nil }
|
func (m *Manager) Flush() error { return nil }
|
||||||
|
|
||||||
|
// SetLogLevel sets the log level for the firewall manager
|
||||||
|
func (m *Manager) SetLogLevel(log.Level) {
|
||||||
|
// not supported
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) EnableRouting() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) DisableRouting() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func getConntrackEstablished() []string {
|
func getConntrackEstablished() []string {
|
||||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||||
}
|
}
|
||||||
|
@ -68,27 +68,14 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var rule1 []fw.Rule
|
|
||||||
t.Run("add first rule", func(t *testing.T) {
|
|
||||||
ip := net.ParseIP("10.20.0.2")
|
|
||||||
port := &fw.Port{Values: []int{8080}}
|
|
||||||
rule1, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
|
||||||
|
|
||||||
for _, r := range rule1 {
|
|
||||||
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
|
|
||||||
}
|
|
||||||
|
|
||||||
})
|
|
||||||
|
|
||||||
var rule2 []fw.Rule
|
var rule2 []fw.Rule
|
||||||
t.Run("add second rule", func(t *testing.T) {
|
t.Run("add second rule", func(t *testing.T) {
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := net.ParseIP("10.20.0.3")
|
||||||
port := &fw.Port{
|
port := &fw.Port{
|
||||||
Values: []int{8043: 8046},
|
IsRange: true,
|
||||||
|
Values: []uint16{8043, 8046},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddPeerFiltering(
|
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
||||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
@ -97,15 +84,6 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("delete first rule", func(t *testing.T) {
|
|
||||||
for _, r := range rule1 {
|
|
||||||
err := manager.DeletePeerRule(r)
|
|
||||||
require.NoError(t, err, "failed to delete rule")
|
|
||||||
|
|
||||||
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("delete second rule", func(t *testing.T) {
|
t.Run("delete second rule", func(t *testing.T) {
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
err := manager.DeletePeerRule(r)
|
err := manager.DeletePeerRule(r)
|
||||||
@ -118,8 +96,8 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
t.Run("reset check", func(t *testing.T) {
|
t.Run("reset check", func(t *testing.T) {
|
||||||
// add second rule
|
// add second rule
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := net.ParseIP("10.20.0.3")
|
||||||
port := &fw.Port{Values: []int{5353}}
|
port := &fw.Port{Values: []uint16{5353}}
|
||||||
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
|
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.ActionAccept, "", "accept Fake DNS traffic")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Reset(nil)
|
err = manager.Reset(nil)
|
||||||
@ -135,9 +113,6 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestIptablesManagerIPSet(t *testing.T) {
|
func TestIptablesManagerIPSet(t *testing.T) {
|
||||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
mock := &iFaceMock{
|
mock := &iFaceMock{
|
||||||
NameFunc: func() string {
|
NameFunc: func() string {
|
||||||
return "lo"
|
return "lo"
|
||||||
@ -167,33 +142,13 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var rule1 []fw.Rule
|
|
||||||
t.Run("add first rule with set", func(t *testing.T) {
|
|
||||||
ip := net.ParseIP("10.20.0.2")
|
|
||||||
port := &fw.Port{Values: []int{8080}}
|
|
||||||
rule1, err = manager.AddPeerFiltering(
|
|
||||||
ip, "tcp", nil, port, fw.RuleDirectionOUT,
|
|
||||||
fw.ActionAccept, "default", "accept HTTP traffic",
|
|
||||||
)
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
|
||||||
|
|
||||||
for _, r := range rule1 {
|
|
||||||
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
|
|
||||||
require.Equal(t, r.(*Rule).ipsetName, "default-dport", "ipset name must be set")
|
|
||||||
require.Equal(t, r.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
var rule2 []fw.Rule
|
var rule2 []fw.Rule
|
||||||
t.Run("add second rule", func(t *testing.T) {
|
t.Run("add second rule", func(t *testing.T) {
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := net.ParseIP("10.20.0.3")
|
||||||
port := &fw.Port{
|
port := &fw.Port{
|
||||||
Values: []int{443},
|
Values: []uint16{443},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddPeerFiltering(
|
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "default", "accept HTTPS traffic from ports range")
|
||||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
|
|
||||||
"default", "accept HTTPS traffic from ports range",
|
|
||||||
)
|
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||||
@ -201,15 +156,6 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("delete first rule", func(t *testing.T) {
|
|
||||||
for _, r := range rule1 {
|
|
||||||
err := manager.DeletePeerRule(r)
|
|
||||||
require.NoError(t, err, "failed to delete rule")
|
|
||||||
|
|
||||||
require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("delete second rule", func(t *testing.T) {
|
t.Run("delete second rule", func(t *testing.T) {
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
err := manager.DeletePeerRule(r)
|
err := manager.DeletePeerRule(r)
|
||||||
@ -269,12 +215,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
ip := net.ParseIP("10.20.0.100")
|
ip := net.ParseIP("10.20.0.100")
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []int{1000 + i}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
if i%2 == 0 {
|
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
|
||||||
} else {
|
|
||||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
|
||||||
}
|
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
}
|
}
|
||||||
|
@ -135,7 +135,16 @@ func (r *router) AddRouteFiltering(
|
|||||||
}
|
}
|
||||||
|
|
||||||
rule := genRouteFilteringRuleSpec(params)
|
rule := genRouteFilteringRuleSpec(params)
|
||||||
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
|
// Insert DROP rules at the beginning, append ACCEPT rules at the end
|
||||||
|
var err error
|
||||||
|
if action == firewall.ActionDrop {
|
||||||
|
// after the established rule
|
||||||
|
err = r.iptablesClient.Insert(tableFilter, chainRTFWD, 2, rule...)
|
||||||
|
} else {
|
||||||
|
err = r.iptablesClient.Append(tableFilter, chainRTFWD, rule...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
return nil, fmt.Errorf("add route rule: %v", err)
|
return nil, fmt.Errorf("add route rule: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -590,10 +599,10 @@ func applyPort(flag string, port *firewall.Port) []string {
|
|||||||
if len(port.Values) > 1 {
|
if len(port.Values) > 1 {
|
||||||
portList := make([]string, len(port.Values))
|
portList := make([]string, len(port.Values))
|
||||||
for i, p := range port.Values {
|
for i, p := range port.Values {
|
||||||
portList[i] = strconv.Itoa(p)
|
portList[i] = strconv.Itoa(int(p))
|
||||||
}
|
}
|
||||||
return []string{"-m", "multiport", flag, strings.Join(portList, ",")}
|
return []string{"-m", "multiport", flag, strings.Join(portList, ",")}
|
||||||
}
|
}
|
||||||
|
|
||||||
return []string{flag, strconv.Itoa(port.Values[0])}
|
return []string{flag, strconv.Itoa(int(port.Values[0]))}
|
||||||
}
|
}
|
||||||
|
@ -239,7 +239,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
proto: firewall.ProtocolTCP,
|
proto: firewall.ProtocolTCP,
|
||||||
sPort: nil,
|
sPort: nil,
|
||||||
dPort: &firewall.Port{Values: []int{80}},
|
dPort: &firewall.Port{Values: []uint16{80}},
|
||||||
direction: firewall.RuleDirectionIN,
|
direction: firewall.RuleDirectionIN,
|
||||||
action: firewall.ActionAccept,
|
action: firewall.ActionAccept,
|
||||||
expectSet: false,
|
expectSet: false,
|
||||||
@ -252,7 +252,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
},
|
},
|
||||||
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
||||||
proto: firewall.ProtocolUDP,
|
proto: firewall.ProtocolUDP,
|
||||||
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
|
sPort: &firewall.Port{Values: []uint16{1024, 2048}, IsRange: true},
|
||||||
dPort: nil,
|
dPort: nil,
|
||||||
direction: firewall.RuleDirectionOUT,
|
direction: firewall.RuleDirectionOUT,
|
||||||
action: firewall.ActionDrop,
|
action: firewall.ActionDrop,
|
||||||
@ -285,7 +285,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
|
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
|
||||||
destination: netip.MustParsePrefix("192.168.0.0/16"),
|
destination: netip.MustParsePrefix("192.168.0.0/16"),
|
||||||
proto: firewall.ProtocolTCP,
|
proto: firewall.ProtocolTCP,
|
||||||
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
|
sPort: &firewall.Port{Values: []uint16{80, 443, 8080}},
|
||||||
dPort: nil,
|
dPort: nil,
|
||||||
direction: firewall.RuleDirectionOUT,
|
direction: firewall.RuleDirectionOUT,
|
||||||
action: firewall.ActionAccept,
|
action: firewall.ActionAccept,
|
||||||
@ -297,7 +297,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
proto: firewall.ProtocolUDP,
|
proto: firewall.ProtocolUDP,
|
||||||
sPort: nil,
|
sPort: nil,
|
||||||
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
|
dPort: &firewall.Port{Values: []uint16{5000, 5100}, IsRange: true},
|
||||||
direction: firewall.RuleDirectionIN,
|
direction: firewall.RuleDirectionIN,
|
||||||
action: firewall.ActionDrop,
|
action: firewall.ActionDrop,
|
||||||
expectSet: false,
|
expectSet: false,
|
||||||
@ -307,8 +307,8 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
|
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
|
||||||
destination: netip.MustParsePrefix("172.16.0.0/16"),
|
destination: netip.MustParsePrefix("172.16.0.0/16"),
|
||||||
proto: firewall.ProtocolTCP,
|
proto: firewall.ProtocolTCP,
|
||||||
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
|
sPort: &firewall.Port{Values: []uint16{1024, 65535}, IsRange: true},
|
||||||
dPort: &firewall.Port{Values: []int{22}},
|
dPort: &firewall.Port{Values: []uint16{22}},
|
||||||
direction: firewall.RuleDirectionOUT,
|
direction: firewall.RuleDirectionOUT,
|
||||||
action: firewall.ActionAccept,
|
action: firewall.ActionAccept,
|
||||||
expectSet: false,
|
expectSet: false,
|
||||||
|
@ -6,6 +6,7 @@ type Rule struct {
|
|||||||
ipsetName string
|
ipsetName string
|
||||||
|
|
||||||
specs []string
|
specs []string
|
||||||
|
mangleSpecs []string
|
||||||
ip string
|
ip string
|
||||||
chain string
|
chain string
|
||||||
}
|
}
|
||||||
|
@ -69,7 +69,6 @@ type Manager interface {
|
|||||||
proto Protocol,
|
proto Protocol,
|
||||||
sPort *Port,
|
sPort *Port,
|
||||||
dPort *Port,
|
dPort *Port,
|
||||||
direction RuleDirection,
|
|
||||||
action Action,
|
action Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
comment string,
|
||||||
@ -100,6 +99,12 @@ type Manager interface {
|
|||||||
|
|
||||||
// Flush the changes to firewall controller
|
// Flush the changes to firewall controller
|
||||||
Flush() error
|
Flush() error
|
||||||
|
|
||||||
|
SetLogLevel(log.Level)
|
||||||
|
|
||||||
|
EnableRouting() error
|
||||||
|
|
||||||
|
DisableRouting() error
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenKey(format string, pair RouterPair) string {
|
func GenKey(format string, pair RouterPair) string {
|
||||||
|
@ -30,7 +30,7 @@ type Port struct {
|
|||||||
IsRange bool
|
IsRange bool
|
||||||
|
|
||||||
// Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports
|
// Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports
|
||||||
Values []int
|
Values []uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
// String interface implementation
|
// String interface implementation
|
||||||
@ -40,7 +40,11 @@ func (p *Port) String() string {
|
|||||||
if ports != "" {
|
if ports != "" {
|
||||||
ports += ","
|
ports += ","
|
||||||
}
|
}
|
||||||
ports += strconv.Itoa(port)
|
ports += strconv.Itoa(int(port))
|
||||||
}
|
}
|
||||||
|
if p.IsRange {
|
||||||
|
ports = "range:" + ports
|
||||||
|
}
|
||||||
|
|
||||||
return ports
|
return ports
|
||||||
}
|
}
|
||||||
|
@ -2,9 +2,9 @@ package nftables
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@ -23,7 +23,6 @@ const (
|
|||||||
|
|
||||||
// rules chains contains the effective ACL rules
|
// rules chains contains the effective ACL rules
|
||||||
chainNameInputRules = "netbird-acl-input-rules"
|
chainNameInputRules = "netbird-acl-input-rules"
|
||||||
chainNameOutputRules = "netbird-acl-output-rules"
|
|
||||||
|
|
||||||
// filter chains contains the rules that jump to the rules chains
|
// filter chains contains the rules that jump to the rules chains
|
||||||
chainNameInputFilter = "netbird-acl-input-filter"
|
chainNameInputFilter = "netbird-acl-input-filter"
|
||||||
@ -47,7 +46,7 @@ type AclManager struct {
|
|||||||
|
|
||||||
workTable *nftables.Table
|
workTable *nftables.Table
|
||||||
chainInputRules *nftables.Chain
|
chainInputRules *nftables.Chain
|
||||||
chainOutputRules *nftables.Chain
|
chainPrerouting *nftables.Chain
|
||||||
|
|
||||||
ipsetStore *ipsetStore
|
ipsetStore *ipsetStore
|
||||||
rules map[string]*Rule
|
rules map[string]*Rule
|
||||||
@ -89,7 +88,6 @@ func (m *AclManager) AddPeerFiltering(
|
|||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
direction firewall.RuleDirection,
|
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
comment string,
|
||||||
@ -104,7 +102,7 @@ func (m *AclManager) AddPeerFiltering(
|
|||||||
}
|
}
|
||||||
|
|
||||||
newRules := make([]firewall.Rule, 0, 2)
|
newRules := make([]firewall.Rule, 0, 2)
|
||||||
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, direction, action, ipset, comment)
|
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset, comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -121,23 +119,32 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if r.nftSet == nil {
|
if r.nftSet == nil {
|
||||||
err := m.rConn.DelRule(r.nftRule)
|
if err := m.rConn.DelRule(r.nftRule); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to delete rule: %v", err)
|
log.Errorf("failed to delete rule: %v", err)
|
||||||
}
|
}
|
||||||
|
if r.mangleRule != nil {
|
||||||
|
if err := m.rConn.DelRule(r.mangleRule); err != nil {
|
||||||
|
log.Errorf("failed to delete mangle rule: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
delete(m.rules, r.GetRuleID())
|
delete(m.rules, r.GetRuleID())
|
||||||
return m.rConn.Flush()
|
return m.rConn.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
ips, ok := m.ipsetStore.ips(r.nftSet.Name)
|
ips, ok := m.ipsetStore.ips(r.nftSet.Name)
|
||||||
if !ok {
|
if !ok {
|
||||||
err := m.rConn.DelRule(r.nftRule)
|
if err := m.rConn.DelRule(r.nftRule); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to delete rule: %v", err)
|
log.Errorf("failed to delete rule: %v", err)
|
||||||
}
|
}
|
||||||
|
if r.mangleRule != nil {
|
||||||
|
if err := m.rConn.DelRule(r.mangleRule); err != nil {
|
||||||
|
log.Errorf("failed to delete mangle rule: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
delete(m.rules, r.GetRuleID())
|
delete(m.rules, r.GetRuleID())
|
||||||
return m.rConn.Flush()
|
return m.rConn.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := ips[r.ip.String()]; ok {
|
if _, ok := ips[r.ip.String()]; ok {
|
||||||
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: r.ip.To4()}})
|
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: r.ip.To4()}})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -156,12 +163,16 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err := m.rConn.DelRule(r.nftRule)
|
if err := m.rConn.DelRule(r.nftRule); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to delete rule: %v", err)
|
log.Errorf("failed to delete rule: %v", err)
|
||||||
}
|
}
|
||||||
err = m.rConn.Flush()
|
if r.mangleRule != nil {
|
||||||
if err != nil {
|
if err := m.rConn.DelRule(r.mangleRule); err != nil {
|
||||||
|
log.Errorf("failed to delete mangle rule: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -214,38 +225,6 @@ func (m *AclManager) createDefaultAllowRules() error {
|
|||||||
Exprs: expIn,
|
Exprs: expIn,
|
||||||
})
|
})
|
||||||
|
|
||||||
expOut := []expr.Any{
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: 16,
|
|
||||||
Len: 4,
|
|
||||||
},
|
|
||||||
// mask
|
|
||||||
&expr.Bitwise{
|
|
||||||
SourceRegister: 1,
|
|
||||||
DestRegister: 1,
|
|
||||||
Len: 4,
|
|
||||||
Mask: []byte{0, 0, 0, 0},
|
|
||||||
Xor: []byte{0, 0, 0, 0},
|
|
||||||
},
|
|
||||||
// net address
|
|
||||||
&expr.Cmp{
|
|
||||||
Register: 1,
|
|
||||||
Data: []byte{0, 0, 0, 0},
|
|
||||||
},
|
|
||||||
&expr.Verdict{
|
|
||||||
Kind: expr.VerdictAccept,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = m.rConn.InsertRule(&nftables.Rule{
|
|
||||||
Table: m.workTable,
|
|
||||||
Chain: m.chainOutputRules,
|
|
||||||
Position: 0,
|
|
||||||
Exprs: expOut,
|
|
||||||
})
|
|
||||||
|
|
||||||
if err := m.rConn.Flush(); err != nil {
|
if err := m.rConn.Flush(); err != nil {
|
||||||
return fmt.Errorf(flushError, err)
|
return fmt.Errorf(flushError, err)
|
||||||
}
|
}
|
||||||
@ -260,25 +239,33 @@ func (m *AclManager) Flush() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.refreshRuleHandles(m.chainInputRules); err != nil {
|
if err := m.refreshRuleHandles(m.chainInputRules, false); err != nil {
|
||||||
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
|
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
|
||||||
}
|
}
|
||||||
|
if err := m.refreshRuleHandles(m.chainPrerouting, true); err != nil {
|
||||||
if err := m.refreshRuleHandles(m.chainOutputRules); err != nil {
|
log.Errorf("failed to refresh rule handles prerouting chain: %v", err)
|
||||||
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, ipset *nftables.Set, comment string) (*Rule, error) {
|
func (m *AclManager) addIOFiltering(
|
||||||
ruleId := generatePeerRuleId(ip, sPort, dPort, direction, action, ipset)
|
ip net.IP,
|
||||||
|
proto firewall.Protocol,
|
||||||
|
sPort *firewall.Port,
|
||||||
|
dPort *firewall.Port,
|
||||||
|
action firewall.Action,
|
||||||
|
ipset *nftables.Set,
|
||||||
|
comment string,
|
||||||
|
) (*Rule, error) {
|
||||||
|
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
|
||||||
if r, ok := m.rules[ruleId]; ok {
|
if r, ok := m.rules[ruleId]; ok {
|
||||||
return &Rule{
|
return &Rule{
|
||||||
r.nftRule,
|
nftRule: r.nftRule,
|
||||||
r.nftSet,
|
mangleRule: r.mangleRule,
|
||||||
r.ruleID,
|
nftSet: r.nftSet,
|
||||||
ip,
|
ruleID: r.ruleID,
|
||||||
|
ip: ip,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -310,9 +297,6 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
|
|||||||
if !bytes.HasPrefix(anyIP, rawIP) {
|
if !bytes.HasPrefix(anyIP, rawIP) {
|
||||||
// source address position
|
// source address position
|
||||||
addrOffset := uint32(12)
|
addrOffset := uint32(12)
|
||||||
if direction == firewall.RuleDirectionOUT {
|
|
||||||
addrOffset += 4 // is ipv4 address length
|
|
||||||
}
|
|
||||||
|
|
||||||
expressions = append(expressions,
|
expressions = append(expressions,
|
||||||
&expr.Payload{
|
&expr.Payload{
|
||||||
@ -342,62 +326,35 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if sPort != nil && len(sPort.Values) != 0 {
|
expressions = append(expressions, applyPort(sPort, true)...)
|
||||||
expressions = append(expressions,
|
expressions = append(expressions, applyPort(dPort, false)...)
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseTransportHeader,
|
|
||||||
Offset: 0,
|
|
||||||
Len: 2,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: encodePort(*sPort),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
if dPort != nil && len(dPort.Values) != 0 {
|
mainExpressions := slices.Clone(expressions)
|
||||||
expressions = append(expressions,
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseTransportHeader,
|
|
||||||
Offset: 2,
|
|
||||||
Len: 2,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: encodePort(*dPort),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
switch action {
|
switch action {
|
||||||
case firewall.ActionAccept:
|
case firewall.ActionAccept:
|
||||||
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictAccept})
|
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictAccept})
|
||||||
case firewall.ActionDrop:
|
case firewall.ActionDrop:
|
||||||
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
||||||
}
|
}
|
||||||
|
|
||||||
userData := []byte(strings.Join([]string{ruleId, comment}, " "))
|
userData := []byte(strings.Join([]string{ruleId, comment}, " "))
|
||||||
|
|
||||||
var chain *nftables.Chain
|
chain := m.chainInputRules
|
||||||
if direction == firewall.RuleDirectionIN {
|
|
||||||
chain = m.chainInputRules
|
|
||||||
} else {
|
|
||||||
chain = m.chainOutputRules
|
|
||||||
}
|
|
||||||
nftRule := m.rConn.AddRule(&nftables.Rule{
|
nftRule := m.rConn.AddRule(&nftables.Rule{
|
||||||
Table: m.workTable,
|
Table: m.workTable,
|
||||||
Chain: chain,
|
Chain: chain,
|
||||||
Exprs: expressions,
|
Exprs: mainExpressions,
|
||||||
UserData: userData,
|
UserData: userData,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
|
return nil, fmt.Errorf(flushError, err)
|
||||||
|
}
|
||||||
|
|
||||||
rule := &Rule{
|
rule := &Rule{
|
||||||
nftRule: nftRule,
|
nftRule: nftRule,
|
||||||
|
mangleRule: m.createPreroutingRule(expressions, userData),
|
||||||
nftSet: ipset,
|
nftSet: ipset,
|
||||||
ruleID: ruleId,
|
ruleID: ruleId,
|
||||||
ip: ip,
|
ip: ip,
|
||||||
@ -406,9 +363,63 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
|
|||||||
if ipset != nil {
|
if ipset != nil {
|
||||||
m.ipsetStore.AddReferenceToIpset(ipset.Name)
|
m.ipsetStore.AddReferenceToIpset(ipset.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
return rule, nil
|
return rule, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule {
|
||||||
|
if m.chainPrerouting == nil {
|
||||||
|
log.Warn("prerouting chain is not created")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
preroutingExprs := slices.Clone(expressions)
|
||||||
|
|
||||||
|
// interface
|
||||||
|
preroutingExprs = append([]expr.Any{
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyIIFNAME,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(m.wgIface.Name()),
|
||||||
|
},
|
||||||
|
}, preroutingExprs...)
|
||||||
|
|
||||||
|
// local destination and mark
|
||||||
|
preroutingExprs = append(preroutingExprs,
|
||||||
|
&expr.Fib{
|
||||||
|
Register: 1,
|
||||||
|
ResultADDRTYPE: true,
|
||||||
|
FlagDADDR: true,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
|
||||||
|
},
|
||||||
|
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||||
|
},
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyMARK,
|
||||||
|
Register: 1,
|
||||||
|
SourceRegister: true,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return m.rConn.AddRule(&nftables.Rule{
|
||||||
|
Table: m.workTable,
|
||||||
|
Chain: m.chainPrerouting,
|
||||||
|
Exprs: preroutingExprs,
|
||||||
|
UserData: userData,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (m *AclManager) createDefaultChains() (err error) {
|
func (m *AclManager) createDefaultChains() (err error) {
|
||||||
// chainNameInputRules
|
// chainNameInputRules
|
||||||
chain := m.createChain(chainNameInputRules)
|
chain := m.createChain(chainNameInputRules)
|
||||||
@ -419,15 +430,6 @@ func (m *AclManager) createDefaultChains() (err error) {
|
|||||||
}
|
}
|
||||||
m.chainInputRules = chain
|
m.chainInputRules = chain
|
||||||
|
|
||||||
// chainNameOutputRules
|
|
||||||
chain = m.createChain(chainNameOutputRules)
|
|
||||||
err = m.rConn.Flush()
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to create chain (%s): %s", chainNameOutputRules, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
m.chainOutputRules = chain
|
|
||||||
|
|
||||||
// netbird-acl-input-filter
|
// netbird-acl-input-filter
|
||||||
// type filter hook input priority filter; policy accept;
|
// type filter hook input priority filter; policy accept;
|
||||||
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
|
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
|
||||||
@ -461,7 +463,7 @@ func (m *AclManager) createDefaultChains() (err error) {
|
|||||||
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
|
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
|
||||||
// netbird peer IP.
|
// netbird peer IP.
|
||||||
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
|
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
|
||||||
preroutingChain := m.rConn.AddChain(&nftables.Chain{
|
m.chainPrerouting = m.rConn.AddChain(&nftables.Chain{
|
||||||
Name: chainNamePrerouting,
|
Name: chainNamePrerouting,
|
||||||
Table: m.workTable,
|
Table: m.workTable,
|
||||||
Type: nftables.ChainTypeFilter,
|
Type: nftables.ChainTypeFilter,
|
||||||
@ -469,8 +471,6 @@ func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error
|
|||||||
Priority: nftables.ChainPriorityMangle,
|
Priority: nftables.ChainPriorityMangle,
|
||||||
})
|
})
|
||||||
|
|
||||||
m.addPreroutingRule(preroutingChain)
|
|
||||||
|
|
||||||
m.addFwmarkToForward(chainFwFilter)
|
m.addFwmarkToForward(chainFwFilter)
|
||||||
|
|
||||||
if err := m.rConn.Flush(); err != nil {
|
if err := m.rConn.Flush(); err != nil {
|
||||||
@ -480,43 +480,6 @@ func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) {
|
|
||||||
m.rConn.AddRule(&nftables.Rule{
|
|
||||||
Table: m.workTable,
|
|
||||||
Chain: preroutingChain,
|
|
||||||
Exprs: []expr.Any{
|
|
||||||
&expr.Meta{
|
|
||||||
Key: expr.MetaKeyIIFNAME,
|
|
||||||
Register: 1,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(m.wgIface.Name()),
|
|
||||||
},
|
|
||||||
&expr.Fib{
|
|
||||||
Register: 1,
|
|
||||||
ResultADDRTYPE: true,
|
|
||||||
FlagDADDR: true,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
|
|
||||||
},
|
|
||||||
&expr.Immediate{
|
|
||||||
Register: 1,
|
|
||||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
|
||||||
},
|
|
||||||
&expr.Meta{
|
|
||||||
Key: expr.MetaKeyMARK,
|
|
||||||
Register: 1,
|
|
||||||
SourceRegister: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
|
func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
|
||||||
m.rConn.InsertRule(&nftables.Rule{
|
m.rConn.InsertRule(&nftables.Rule{
|
||||||
Table: m.workTable,
|
Table: m.workTable,
|
||||||
@ -532,8 +495,7 @@ func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
|
|||||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||||
},
|
},
|
||||||
&expr.Verdict{
|
&expr.Verdict{
|
||||||
Kind: expr.VerdictJump,
|
Kind: expr.VerdictAccept,
|
||||||
Chain: m.chainInputRules.Name,
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@ -680,6 +642,7 @@ func (m *AclManager) flushWithBackoff() (err error) {
|
|||||||
for i := 0; ; i++ {
|
for i := 0; ; i++ {
|
||||||
err = m.rConn.Flush()
|
err = m.rConn.Flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Debugf("failed to flush nftables: %v", err)
|
||||||
if !strings.Contains(err.Error(), "busy") {
|
if !strings.Contains(err.Error(), "busy") {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -696,7 +659,7 @@ func (m *AclManager) flushWithBackoff() (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error {
|
func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) error {
|
||||||
if m.workTable == nil || chain == nil {
|
if m.workTable == nil || chain == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -713,22 +676,19 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error {
|
|||||||
split := bytes.Split(rule.UserData, []byte(" "))
|
split := bytes.Split(rule.UserData, []byte(" "))
|
||||||
r, ok := m.rules[string(split[0])]
|
r, ok := m.rules[string(split[0])]
|
||||||
if ok {
|
if ok {
|
||||||
|
if mangle {
|
||||||
|
*r.mangleRule = *rule
|
||||||
|
} else {
|
||||||
*r.nftRule = *rule
|
*r.nftRule = *rule
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func generatePeerRuleId(
|
func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
|
||||||
ip net.IP,
|
rulesetID := ":"
|
||||||
sPort *firewall.Port,
|
|
||||||
dPort *firewall.Port,
|
|
||||||
direction firewall.RuleDirection,
|
|
||||||
action firewall.Action,
|
|
||||||
ipset *nftables.Set,
|
|
||||||
) string {
|
|
||||||
rulesetID := ":" + strconv.Itoa(int(direction)) + ":"
|
|
||||||
if sPort != nil {
|
if sPort != nil {
|
||||||
rulesetID += sPort.String()
|
rulesetID += sPort.String()
|
||||||
}
|
}
|
||||||
@ -744,12 +704,6 @@ func generatePeerRuleId(
|
|||||||
return "set:" + ipset.Name + rulesetID
|
return "set:" + ipset.Name + rulesetID
|
||||||
}
|
}
|
||||||
|
|
||||||
func encodePort(port firewall.Port) []byte {
|
|
||||||
bs := make([]byte, 2)
|
|
||||||
binary.BigEndian.PutUint16(bs, uint16(port.Values[0]))
|
|
||||||
return bs
|
|
||||||
}
|
|
||||||
|
|
||||||
func ifname(n string) []byte {
|
func ifname(n string) []byte {
|
||||||
b := make([]byte, 16)
|
b := make([]byte, 16)
|
||||||
copy(b, n+"\x00")
|
copy(b, n+"\x00")
|
||||||
|
@ -117,7 +117,6 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
direction firewall.RuleDirection,
|
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
comment string,
|
||||||
@ -130,10 +129,17 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment)
|
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, action, ipsetName, comment)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
|
func (m *Manager) AddRouteFiltering(
|
||||||
|
sources []netip.Prefix,
|
||||||
|
destination netip.Prefix,
|
||||||
|
proto firewall.Protocol,
|
||||||
|
sPort *firewall.Port,
|
||||||
|
dPort *firewall.Port,
|
||||||
|
action firewall.Action,
|
||||||
|
) (firewall.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
@ -312,6 +318,19 @@ func (m *Manager) cleanupNetbirdTables() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetLogLevel sets the log level for the firewall manager
|
||||||
|
func (m *Manager) SetLogLevel(log.Level) {
|
||||||
|
// not supported
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) EnableRouting() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) DisableRouting() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Flush rule/chain/set operations from the buffer
|
// Flush rule/chain/set operations from the buffer
|
||||||
//
|
//
|
||||||
// Method also get all rules after flush and refreshes handle values in the rulesets
|
// Method also get all rules after flush and refreshes handle values in the rulesets
|
||||||
|
@ -74,16 +74,7 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
|
|
||||||
testClient := &nftables.Conn{}
|
testClient := &nftables.Conn{}
|
||||||
|
|
||||||
rule, err := manager.AddPeerFiltering(
|
rule, err := manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "", "")
|
||||||
ip,
|
|
||||||
fw.ProtocolTCP,
|
|
||||||
nil,
|
|
||||||
&fw.Port{Values: []int{53}},
|
|
||||||
fw.RuleDirectionIN,
|
|
||||||
fw.ActionDrop,
|
|
||||||
"",
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Flush()
|
err = manager.Flush()
|
||||||
@ -116,7 +107,7 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
Kind: expr.VerdictAccept,
|
Kind: expr.VerdictAccept,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
require.ElementsMatch(t, rules[0].Exprs, expectedExprs1, "expected the same expressions")
|
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
|
||||||
|
|
||||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||||
add := ipToAdd.Unmap()
|
add := ipToAdd.Unmap()
|
||||||
@ -209,12 +200,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
ip := net.ParseIP("10.20.0.100")
|
ip := net.ParseIP("10.20.0.100")
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []int{1000 + i}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
if i%2 == 0 {
|
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
|
||||||
} else {
|
|
||||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
|
||||||
}
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
if i%100 == 0 {
|
if i%100 == 0 {
|
||||||
@ -296,16 +283,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
ip := net.ParseIP("100.96.0.1")
|
ip := net.ParseIP("100.96.0.1")
|
||||||
_, err = manager.AddPeerFiltering(
|
_, err = manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "", "test rule")
|
||||||
ip,
|
|
||||||
fw.ProtocolTCP,
|
|
||||||
nil,
|
|
||||||
&fw.Port{Values: []int{80}},
|
|
||||||
fw.RuleDirectionIN,
|
|
||||||
fw.ActionAccept,
|
|
||||||
"",
|
|
||||||
"test rule",
|
|
||||||
)
|
|
||||||
require.NoError(t, err, "failed to add peer filtering rule")
|
require.NoError(t, err, "failed to add peer filtering rule")
|
||||||
|
|
||||||
_, err = manager.AddRouteFiltering(
|
_, err = manager.AddRouteFiltering(
|
||||||
@ -313,7 +291,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
|||||||
netip.MustParsePrefix("10.1.0.0/24"),
|
netip.MustParsePrefix("10.1.0.0/24"),
|
||||||
fw.ProtocolTCP,
|
fw.ProtocolTCP,
|
||||||
nil,
|
nil,
|
||||||
&fw.Port{Values: []int{443}},
|
&fw.Port{Values: []uint16{443}},
|
||||||
fw.ActionAccept,
|
fw.ActionAccept,
|
||||||
)
|
)
|
||||||
require.NoError(t, err, "failed to add route filtering rule")
|
require.NoError(t, err, "failed to add route filtering rule")
|
||||||
@ -329,3 +307,18 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
|||||||
stdout, stderr = runIptablesSave(t)
|
stdout, stderr = runIptablesSave(t)
|
||||||
verifyIptablesOutput(t, stdout, stderr)
|
verifyIptablesOutput(t, stdout, stderr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func compareExprsIgnoringCounters(t *testing.T, got, want []expr.Any) {
|
||||||
|
t.Helper()
|
||||||
|
require.Equal(t, len(got), len(want), "expression count mismatch")
|
||||||
|
|
||||||
|
for i := range got {
|
||||||
|
if _, isCounter := got[i].(*expr.Counter); isCounter {
|
||||||
|
_, wantIsCounter := want[i].(*expr.Counter)
|
||||||
|
require.True(t, wantIsCounter, "expected Counter at index %d", i)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, got[i], want[i], "expression mismatch at index %d", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -233,7 +233,13 @@ func (r *router) AddRouteFiltering(
|
|||||||
UserData: []byte(ruleKey),
|
UserData: []byte(ruleKey),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Insert DROP rules at the beginning, append ACCEPT rules at the end
|
||||||
|
if action == firewall.ActionDrop {
|
||||||
|
// TODO: Insert after the established rule
|
||||||
|
rule = r.conn.InsertRule(rule)
|
||||||
|
} else {
|
||||||
rule = r.conn.AddRule(rule)
|
rule = r.conn.AddRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
log.Tracef("Adding route rule %s", spew.Sdump(rule))
|
log.Tracef("Adding route rule %s", spew.Sdump(rule))
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
@ -956,12 +962,12 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
|
|||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
Op: expr.CmpOpGte,
|
Op: expr.CmpOpGte,
|
||||||
Register: 1,
|
Register: 1,
|
||||||
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[0])),
|
Data: binaryutil.BigEndian.PutUint16(port.Values[0]),
|
||||||
},
|
},
|
||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
Op: expr.CmpOpLte,
|
Op: expr.CmpOpLte,
|
||||||
Register: 1,
|
Register: 1,
|
||||||
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[1])),
|
Data: binaryutil.BigEndian.PutUint16(port.Values[1]),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
@ -980,7 +986,7 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
|
|||||||
exprs = append(exprs, &expr.Cmp{
|
exprs = append(exprs, &expr.Cmp{
|
||||||
Op: expr.CmpOpEq,
|
Op: expr.CmpOpEq,
|
||||||
Register: 1,
|
Register: 1,
|
||||||
Data: binaryutil.BigEndian.PutUint16(uint16(p)),
|
Data: binaryutil.BigEndian.PutUint16(p),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -222,7 +222,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
proto: firewall.ProtocolTCP,
|
proto: firewall.ProtocolTCP,
|
||||||
sPort: nil,
|
sPort: nil,
|
||||||
dPort: &firewall.Port{Values: []int{80}},
|
dPort: &firewall.Port{Values: []uint16{80}},
|
||||||
direction: firewall.RuleDirectionIN,
|
direction: firewall.RuleDirectionIN,
|
||||||
action: firewall.ActionAccept,
|
action: firewall.ActionAccept,
|
||||||
expectSet: false,
|
expectSet: false,
|
||||||
@ -235,7 +235,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
},
|
},
|
||||||
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
||||||
proto: firewall.ProtocolUDP,
|
proto: firewall.ProtocolUDP,
|
||||||
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
|
sPort: &firewall.Port{Values: []uint16{1024, 2048}, IsRange: true},
|
||||||
dPort: nil,
|
dPort: nil,
|
||||||
direction: firewall.RuleDirectionOUT,
|
direction: firewall.RuleDirectionOUT,
|
||||||
action: firewall.ActionDrop,
|
action: firewall.ActionDrop,
|
||||||
@ -268,7 +268,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
|
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
|
||||||
destination: netip.MustParsePrefix("192.168.0.0/16"),
|
destination: netip.MustParsePrefix("192.168.0.0/16"),
|
||||||
proto: firewall.ProtocolTCP,
|
proto: firewall.ProtocolTCP,
|
||||||
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
|
sPort: &firewall.Port{Values: []uint16{80, 443, 8080}},
|
||||||
dPort: nil,
|
dPort: nil,
|
||||||
direction: firewall.RuleDirectionOUT,
|
direction: firewall.RuleDirectionOUT,
|
||||||
action: firewall.ActionAccept,
|
action: firewall.ActionAccept,
|
||||||
@ -280,7 +280,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
proto: firewall.ProtocolUDP,
|
proto: firewall.ProtocolUDP,
|
||||||
sPort: nil,
|
sPort: nil,
|
||||||
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
|
dPort: &firewall.Port{Values: []uint16{5000, 5100}, IsRange: true},
|
||||||
direction: firewall.RuleDirectionIN,
|
direction: firewall.RuleDirectionIN,
|
||||||
action: firewall.ActionDrop,
|
action: firewall.ActionDrop,
|
||||||
expectSet: false,
|
expectSet: false,
|
||||||
@ -290,8 +290,8 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
|
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
|
||||||
destination: netip.MustParsePrefix("172.16.0.0/16"),
|
destination: netip.MustParsePrefix("172.16.0.0/16"),
|
||||||
proto: firewall.ProtocolTCP,
|
proto: firewall.ProtocolTCP,
|
||||||
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
|
sPort: &firewall.Port{Values: []uint16{1024, 65535}, IsRange: true},
|
||||||
dPort: &firewall.Port{Values: []int{22}},
|
dPort: &firewall.Port{Values: []uint16{22}},
|
||||||
direction: firewall.RuleDirectionOUT,
|
direction: firewall.RuleDirectionOUT,
|
||||||
action: firewall.ActionAccept,
|
action: firewall.ActionAccept,
|
||||||
expectSet: false,
|
expectSet: false,
|
||||||
|
@ -9,6 +9,7 @@ import (
|
|||||||
// Rule to handle management of rules
|
// Rule to handle management of rules
|
||||||
type Rule struct {
|
type Rule struct {
|
||||||
nftRule *nftables.Rule
|
nftRule *nftables.Rule
|
||||||
|
mangleRule *nftables.Rule
|
||||||
nftSet *nftables.Set
|
nftSet *nftables.Set
|
||||||
ruleID string
|
ruleID string
|
||||||
ip net.IP
|
ip net.IP
|
||||||
|
@ -3,6 +3,11 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
@ -17,17 +22,29 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
|||||||
|
|
||||||
if m.udpTracker != nil {
|
if m.udpTracker != nil {
|
||||||
m.udpTracker.Close()
|
m.udpTracker.Close()
|
||||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
|
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.icmpTracker != nil {
|
if m.icmpTracker != nil {
|
||||||
m.icmpTracker.Close()
|
m.icmpTracker.Close()
|
||||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
|
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.tcpTracker != nil {
|
if m.tcpTracker != nil {
|
||||||
m.tcpTracker.Close()
|
m.tcpTracker.Close()
|
||||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.forwarder != nil {
|
||||||
|
m.forwarder.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.logger != nil {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := m.logger.Stop(ctx); err != nil {
|
||||||
|
log.Errorf("failed to shutdown logger: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.nativeFirewall != nil {
|
if m.nativeFirewall != nil {
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
@ -29,17 +31,29 @@ func (m *Manager) Reset(*statemanager.Manager) error {
|
|||||||
|
|
||||||
if m.udpTracker != nil {
|
if m.udpTracker != nil {
|
||||||
m.udpTracker.Close()
|
m.udpTracker.Close()
|
||||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
|
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.icmpTracker != nil {
|
if m.icmpTracker != nil {
|
||||||
m.icmpTracker.Close()
|
m.icmpTracker.Close()
|
||||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
|
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.tcpTracker != nil {
|
if m.tcpTracker != nil {
|
||||||
m.tcpTracker.Close()
|
m.tcpTracker.Close()
|
||||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.forwarder != nil {
|
||||||
|
m.forwarder.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.logger != nil {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := m.logger.Stop(ctx); err != nil {
|
||||||
|
log.Errorf("failed to shutdown logger: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !isWindowsFirewallReachable() {
|
if !isWindowsFirewallReachable() {
|
||||||
|
16
client/firewall/uspfilter/common/iface.go
Normal file
16
client/firewall/uspfilter/common/iface.go
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IFaceMapper defines subset methods of interface required for manager
|
||||||
|
type IFaceMapper interface {
|
||||||
|
SetFilter(device.PacketFilter) error
|
||||||
|
Address() iface.WGAddress
|
||||||
|
GetWGDevice() *wgdevice.Device
|
||||||
|
GetDevice() *device.FilteredDevice
|
||||||
|
}
|
@ -15,7 +15,6 @@ type BaseConnTrack struct {
|
|||||||
SourcePort uint16
|
SourcePort uint16
|
||||||
DestPort uint16
|
DestPort uint16
|
||||||
lastSeen atomic.Int64 // Unix nano for atomic access
|
lastSeen atomic.Int64 // Unix nano for atomic access
|
||||||
established atomic.Bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// these small methods will be inlined by the compiler
|
// these small methods will be inlined by the compiler
|
||||||
@ -25,16 +24,6 @@ func (b *BaseConnTrack) UpdateLastSeen() {
|
|||||||
b.lastSeen.Store(time.Now().UnixNano())
|
b.lastSeen.Store(time.Now().UnixNano())
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsEstablished safely checks if connection is established
|
|
||||||
func (b *BaseConnTrack) IsEstablished() bool {
|
|
||||||
return b.established.Load()
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetEstablished safely sets the established state
|
|
||||||
func (b *BaseConnTrack) SetEstablished(state bool) {
|
|
||||||
b.established.Store(state)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetLastSeen safely gets the last seen timestamp
|
// GetLastSeen safely gets the last seen timestamp
|
||||||
func (b *BaseConnTrack) GetLastSeen() time.Time {
|
func (b *BaseConnTrack) GetLastSeen() time.Time {
|
||||||
return time.Unix(0, b.lastSeen.Load())
|
return time.Unix(0, b.lastSeen.Load())
|
||||||
|
@ -3,8 +3,14 @@ package conntrack
|
|||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||||
|
|
||||||
func BenchmarkIPOperations(b *testing.B) {
|
func BenchmarkIPOperations(b *testing.B) {
|
||||||
b.Run("MakeIPAddr", func(b *testing.B) {
|
b.Run("MakeIPAddr", func(b *testing.B) {
|
||||||
ip := net.ParseIP("192.168.1.1")
|
ip := net.ParseIP("192.168.1.1")
|
||||||
@ -34,37 +40,11 @@ func BenchmarkIPOperations(b *testing.B) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
}
|
}
|
||||||
func BenchmarkAtomicOperations(b *testing.B) {
|
|
||||||
conn := &BaseConnTrack{}
|
|
||||||
b.Run("UpdateLastSeen", func(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
conn.UpdateLastSeen()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("IsEstablished", func(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
_ = conn.IsEstablished()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("SetEstablished", func(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
conn.SetEstablished(i%2 == 0)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("GetLastSeen", func(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
_ = conn.GetLastSeen()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Memory pressure tests
|
// Memory pressure tests
|
||||||
func BenchmarkMemoryPressure(b *testing.B) {
|
func BenchmarkMemoryPressure(b *testing.B) {
|
||||||
b.Run("TCPHighLoad", func(b *testing.B) {
|
b.Run("TCPHighLoad", func(b *testing.B) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
// Generate different IPs
|
// Generate different IPs
|
||||||
@ -89,7 +69,7 @@ func BenchmarkMemoryPressure(b *testing.B) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
b.Run("UDPHighLoad", func(b *testing.B) {
|
b.Run("UDPHighLoad", func(b *testing.B) {
|
||||||
tracker := NewUDPTracker(DefaultUDPTimeout)
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
// Generate different IPs
|
// Generate different IPs
|
||||||
|
@ -6,6 +6,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
|
|
||||||
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -33,6 +35,7 @@ type ICMPConnTrack struct {
|
|||||||
|
|
||||||
// ICMPTracker manages ICMP connection states
|
// ICMPTracker manages ICMP connection states
|
||||||
type ICMPTracker struct {
|
type ICMPTracker struct {
|
||||||
|
logger *nblog.Logger
|
||||||
connections map[ICMPConnKey]*ICMPConnTrack
|
connections map[ICMPConnKey]*ICMPConnTrack
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
cleanupTicker *time.Ticker
|
cleanupTicker *time.Ticker
|
||||||
@ -42,12 +45,13 @@ type ICMPTracker struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewICMPTracker creates a new ICMP connection tracker
|
// NewICMPTracker creates a new ICMP connection tracker
|
||||||
func NewICMPTracker(timeout time.Duration) *ICMPTracker {
|
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker {
|
||||||
if timeout == 0 {
|
if timeout == 0 {
|
||||||
timeout = DefaultICMPTimeout
|
timeout = DefaultICMPTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
tracker := &ICMPTracker{
|
tracker := &ICMPTracker{
|
||||||
|
logger: logger,
|
||||||
connections: make(map[ICMPConnKey]*ICMPConnTrack),
|
connections: make(map[ICMPConnKey]*ICMPConnTrack),
|
||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
|
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
|
||||||
@ -62,7 +66,6 @@ func NewICMPTracker(timeout time.Duration) *ICMPTracker {
|
|||||||
// TrackOutbound records an outbound ICMP Echo Request
|
// TrackOutbound records an outbound ICMP Echo Request
|
||||||
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
|
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
|
||||||
key := makeICMPKey(srcIP, dstIP, id, seq)
|
key := makeICMPKey(srcIP, dstIP, id, seq)
|
||||||
now := time.Now().UnixNano()
|
|
||||||
|
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
@ -80,24 +83,19 @@ func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq u
|
|||||||
ID: id,
|
ID: id,
|
||||||
Sequence: seq,
|
Sequence: seq,
|
||||||
}
|
}
|
||||||
conn.lastSeen.Store(now)
|
conn.UpdateLastSeen()
|
||||||
conn.established.Store(true)
|
|
||||||
t.connections[key] = conn
|
t.connections[key] = conn
|
||||||
|
|
||||||
|
t.logger.Trace("New ICMP connection %v", key)
|
||||||
}
|
}
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
|
|
||||||
conn.lastSeen.Store(now)
|
conn.UpdateLastSeen()
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
||||||
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
|
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
|
||||||
switch icmpType {
|
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
|
||||||
case uint8(layers.ICMPv4TypeDestinationUnreachable),
|
|
||||||
uint8(layers.ICMPv4TypeTimeExceeded):
|
|
||||||
return true
|
|
||||||
case uint8(layers.ICMPv4TypeEchoReply):
|
|
||||||
// continue processing
|
|
||||||
default:
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -115,8 +113,7 @@ func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return conn.IsEstablished() &&
|
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
||||||
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
|
||||||
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
||||||
conn.ID == id &&
|
conn.ID == id &&
|
||||||
conn.Sequence == seq
|
conn.Sequence == seq
|
||||||
@ -141,6 +138,8 @@ func (t *ICMPTracker) cleanup() {
|
|||||||
t.ipPool.Put(conn.SourceIP)
|
t.ipPool.Put(conn.SourceIP)
|
||||||
t.ipPool.Put(conn.DestIP)
|
t.ipPool.Put(conn.DestIP)
|
||||||
delete(t.connections, key)
|
delete(t.connections, key)
|
||||||
|
|
||||||
|
t.logger.Debug("Removed ICMP connection %v (timeout)", key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -7,7 +7,7 @@ import (
|
|||||||
|
|
||||||
func BenchmarkICMPTracker(b *testing.B) {
|
func BenchmarkICMPTracker(b *testing.B) {
|
||||||
b.Run("TrackOutbound", func(b *testing.B) {
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
tracker := NewICMPTracker(DefaultICMPTimeout)
|
tracker := NewICMPTracker(DefaultICMPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
@ -20,7 +20,7 @@ func BenchmarkICMPTracker(b *testing.B) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
b.Run("IsValidInbound", func(b *testing.B) {
|
b.Run("IsValidInbound", func(b *testing.B) {
|
||||||
tracker := NewICMPTracker(DefaultICMPTimeout)
|
tracker := NewICMPTracker(DefaultICMPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
|
@ -5,7 +5,10 @@ package conntrack
|
|||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -62,11 +65,23 @@ type TCPConnKey struct {
|
|||||||
type TCPConnTrack struct {
|
type TCPConnTrack struct {
|
||||||
BaseConnTrack
|
BaseConnTrack
|
||||||
State TCPState
|
State TCPState
|
||||||
|
established atomic.Bool
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsEstablished safely checks if connection is established
|
||||||
|
func (t *TCPConnTrack) IsEstablished() bool {
|
||||||
|
return t.established.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetEstablished safely sets the established state
|
||||||
|
func (t *TCPConnTrack) SetEstablished(state bool) {
|
||||||
|
t.established.Store(state)
|
||||||
|
}
|
||||||
|
|
||||||
// TCPTracker manages TCP connection states
|
// TCPTracker manages TCP connection states
|
||||||
type TCPTracker struct {
|
type TCPTracker struct {
|
||||||
|
logger *nblog.Logger
|
||||||
connections map[ConnKey]*TCPConnTrack
|
connections map[ConnKey]*TCPConnTrack
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
cleanupTicker *time.Ticker
|
cleanupTicker *time.Ticker
|
||||||
@ -76,8 +91,9 @@ type TCPTracker struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewTCPTracker creates a new TCP connection tracker
|
// NewTCPTracker creates a new TCP connection tracker
|
||||||
func NewTCPTracker(timeout time.Duration) *TCPTracker {
|
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker {
|
||||||
tracker := &TCPTracker{
|
tracker := &TCPTracker{
|
||||||
|
logger: logger,
|
||||||
connections: make(map[ConnKey]*TCPConnTrack),
|
connections: make(map[ConnKey]*TCPConnTrack),
|
||||||
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
@ -93,7 +109,6 @@ func NewTCPTracker(timeout time.Duration) *TCPTracker {
|
|||||||
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
|
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
|
||||||
// Create key before lock
|
// Create key before lock
|
||||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||||
now := time.Now().UnixNano()
|
|
||||||
|
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
@ -113,9 +128,11 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
|
|||||||
},
|
},
|
||||||
State: TCPStateNew,
|
State: TCPStateNew,
|
||||||
}
|
}
|
||||||
conn.lastSeen.Store(now)
|
conn.UpdateLastSeen()
|
||||||
conn.established.Store(false)
|
conn.established.Store(false)
|
||||||
t.connections[key] = conn
|
t.connections[key] = conn
|
||||||
|
|
||||||
|
t.logger.Trace("New TCP connection: %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort)
|
||||||
}
|
}
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
|
|
||||||
@ -123,7 +140,7 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
|
|||||||
conn.Lock()
|
conn.Lock()
|
||||||
t.updateState(conn, flags, true)
|
t.updateState(conn, flags, true)
|
||||||
conn.Unlock()
|
conn.Unlock()
|
||||||
conn.lastSeen.Store(now)
|
conn.UpdateLastSeen()
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
||||||
@ -171,6 +188,9 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
|
|||||||
if flags&TCPRst != 0 {
|
if flags&TCPRst != 0 {
|
||||||
conn.State = TCPStateClosed
|
conn.State = TCPStateClosed
|
||||||
conn.SetEstablished(false)
|
conn.SetEstablished(false)
|
||||||
|
|
||||||
|
t.logger.Trace("TCP connection reset: %s:%d -> %s:%d",
|
||||||
|
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -227,6 +247,9 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
|
|||||||
if flags&TCPAck != 0 {
|
if flags&TCPAck != 0 {
|
||||||
conn.State = TCPStateTimeWait
|
conn.State = TCPStateTimeWait
|
||||||
// Keep established = false from previous state
|
// Keep established = false from previous state
|
||||||
|
|
||||||
|
t.logger.Trace("TCP connection closed (simultaneous) - %s:%d -> %s:%d",
|
||||||
|
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateCloseWait:
|
case TCPStateCloseWait:
|
||||||
@ -237,11 +260,17 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
|
|||||||
case TCPStateLastAck:
|
case TCPStateLastAck:
|
||||||
if flags&TCPAck != 0 {
|
if flags&TCPAck != 0 {
|
||||||
conn.State = TCPStateClosed
|
conn.State = TCPStateClosed
|
||||||
|
|
||||||
|
t.logger.Trace("TCP connection gracefully closed: %s:%d -> %s:%d",
|
||||||
|
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateTimeWait:
|
case TCPStateTimeWait:
|
||||||
// Stay in TIME-WAIT for 2MSL before transitioning to closed
|
// Stay in TIME-WAIT for 2MSL before transitioning to closed
|
||||||
// This is handled by the cleanup routine
|
// This is handled by the cleanup routine
|
||||||
|
|
||||||
|
t.logger.Trace("TCP connection completed - %s:%d -> %s:%d",
|
||||||
|
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -318,6 +347,8 @@ func (t *TCPTracker) cleanup() {
|
|||||||
t.ipPool.Put(conn.SourceIP)
|
t.ipPool.Put(conn.SourceIP)
|
||||||
t.ipPool.Put(conn.DestIP)
|
t.ipPool.Put(conn.DestIP)
|
||||||
delete(t.connections, key)
|
delete(t.connections, key)
|
||||||
|
|
||||||
|
t.logger.Trace("Cleaned up TCP connection: %s:%d -> %s:%d", conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -9,7 +9,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestTCPStateMachine(t *testing.T) {
|
func TestTCPStateMachine(t *testing.T) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("100.64.0.1")
|
srcIP := net.ParseIP("100.64.0.1")
|
||||||
@ -154,7 +154,7 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
tracker = NewTCPTracker(DefaultTCPTimeout)
|
tracker = NewTCPTracker(DefaultTCPTimeout, logger)
|
||||||
tt.test(t)
|
tt.test(t)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -162,7 +162,7 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRSTHandling(t *testing.T) {
|
func TestRSTHandling(t *testing.T) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("100.64.0.1")
|
srcIP := net.ParseIP("100.64.0.1")
|
||||||
@ -233,7 +233,7 @@ func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP,
|
|||||||
|
|
||||||
func BenchmarkTCPTracker(b *testing.B) {
|
func BenchmarkTCPTracker(b *testing.B) {
|
||||||
b.Run("TrackOutbound", func(b *testing.B) {
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
@ -246,7 +246,7 @@ func BenchmarkTCPTracker(b *testing.B) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
b.Run("IsValidInbound", func(b *testing.B) {
|
b.Run("IsValidInbound", func(b *testing.B) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
@ -264,7 +264,7 @@ func BenchmarkTCPTracker(b *testing.B) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
b.Run("ConcurrentAccess", func(b *testing.B) {
|
b.Run("ConcurrentAccess", func(b *testing.B) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
@ -287,7 +287,7 @@ func BenchmarkTCPTracker(b *testing.B) {
|
|||||||
// Benchmark connection cleanup
|
// Benchmark connection cleanup
|
||||||
func BenchmarkCleanup(b *testing.B) {
|
func BenchmarkCleanup(b *testing.B) {
|
||||||
b.Run("TCPCleanup", func(b *testing.B) {
|
b.Run("TCPCleanup", func(b *testing.B) {
|
||||||
tracker := NewTCPTracker(100 * time.Millisecond) // Short timeout for testing
|
tracker := NewTCPTracker(100*time.Millisecond, logger) // Short timeout for testing
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
// Pre-populate with expired connections
|
// Pre-populate with expired connections
|
||||||
|
@ -4,6 +4,8 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -20,6 +22,7 @@ type UDPConnTrack struct {
|
|||||||
|
|
||||||
// UDPTracker manages UDP connection states
|
// UDPTracker manages UDP connection states
|
||||||
type UDPTracker struct {
|
type UDPTracker struct {
|
||||||
|
logger *nblog.Logger
|
||||||
connections map[ConnKey]*UDPConnTrack
|
connections map[ConnKey]*UDPConnTrack
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
cleanupTicker *time.Ticker
|
cleanupTicker *time.Ticker
|
||||||
@ -29,12 +32,13 @@ type UDPTracker struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewUDPTracker creates a new UDP connection tracker
|
// NewUDPTracker creates a new UDP connection tracker
|
||||||
func NewUDPTracker(timeout time.Duration) *UDPTracker {
|
func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
|
||||||
if timeout == 0 {
|
if timeout == 0 {
|
||||||
timeout = DefaultUDPTimeout
|
timeout = DefaultUDPTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
tracker := &UDPTracker{
|
tracker := &UDPTracker{
|
||||||
|
logger: logger,
|
||||||
connections: make(map[ConnKey]*UDPConnTrack),
|
connections: make(map[ConnKey]*UDPConnTrack),
|
||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
cleanupTicker: time.NewTicker(UDPCleanupInterval),
|
cleanupTicker: time.NewTicker(UDPCleanupInterval),
|
||||||
@ -49,7 +53,6 @@ func NewUDPTracker(timeout time.Duration) *UDPTracker {
|
|||||||
// TrackOutbound records an outbound UDP connection
|
// TrackOutbound records an outbound UDP connection
|
||||||
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
|
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
|
||||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||||
now := time.Now().UnixNano()
|
|
||||||
|
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
@ -67,13 +70,14 @@ func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
|
|||||||
DestPort: dstPort,
|
DestPort: dstPort,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
conn.lastSeen.Store(now)
|
conn.UpdateLastSeen()
|
||||||
conn.established.Store(true)
|
|
||||||
t.connections[key] = conn
|
t.connections[key] = conn
|
||||||
|
|
||||||
|
t.logger.Trace("New UDP connection: %v", conn)
|
||||||
}
|
}
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
|
|
||||||
conn.lastSeen.Store(now)
|
conn.UpdateLastSeen()
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsValidInbound checks if an inbound packet matches a tracked connection
|
// IsValidInbound checks if an inbound packet matches a tracked connection
|
||||||
@ -92,8 +96,7 @@ func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return conn.IsEstablished() &&
|
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
||||||
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
|
||||||
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
||||||
conn.DestPort == srcPort &&
|
conn.DestPort == srcPort &&
|
||||||
conn.SourcePort == dstPort
|
conn.SourcePort == dstPort
|
||||||
@ -120,6 +123,8 @@ func (t *UDPTracker) cleanup() {
|
|||||||
t.ipPool.Put(conn.SourceIP)
|
t.ipPool.Put(conn.SourceIP)
|
||||||
t.ipPool.Put(conn.DestIP)
|
t.ipPool.Put(conn.DestIP)
|
||||||
delete(t.connections, key)
|
delete(t.connections, key)
|
||||||
|
|
||||||
|
t.logger.Trace("Removed UDP connection %v (timeout)", conn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -29,7 +29,7 @@ func TestNewUDPTracker(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
tracker := NewUDPTracker(tt.timeout)
|
tracker := NewUDPTracker(tt.timeout, logger)
|
||||||
assert.NotNil(t, tracker)
|
assert.NotNil(t, tracker)
|
||||||
assert.Equal(t, tt.wantTimeout, tracker.timeout)
|
assert.Equal(t, tt.wantTimeout, tracker.timeout)
|
||||||
assert.NotNil(t, tracker.connections)
|
assert.NotNil(t, tracker.connections)
|
||||||
@ -40,7 +40,7 @@ func TestNewUDPTracker(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUDPTracker_TrackOutbound(t *testing.T) {
|
func TestUDPTracker_TrackOutbound(t *testing.T) {
|
||||||
tracker := NewUDPTracker(DefaultUDPTimeout)
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.2")
|
srcIP := net.ParseIP("192.168.1.2")
|
||||||
@ -58,12 +58,11 @@ func TestUDPTracker_TrackOutbound(t *testing.T) {
|
|||||||
assert.True(t, conn.DestIP.Equal(dstIP))
|
assert.True(t, conn.DestIP.Equal(dstIP))
|
||||||
assert.Equal(t, srcPort, conn.SourcePort)
|
assert.Equal(t, srcPort, conn.SourcePort)
|
||||||
assert.Equal(t, dstPort, conn.DestPort)
|
assert.Equal(t, dstPort, conn.DestPort)
|
||||||
assert.True(t, conn.IsEstablished())
|
|
||||||
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
|
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUDPTracker_IsValidInbound(t *testing.T) {
|
func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||||
tracker := NewUDPTracker(1 * time.Second)
|
tracker := NewUDPTracker(1*time.Second, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.2")
|
srcIP := net.ParseIP("192.168.1.2")
|
||||||
@ -162,6 +161,7 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
|||||||
cleanupTicker: time.NewTicker(cleanupInterval),
|
cleanupTicker: time.NewTicker(cleanupInterval),
|
||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
ipPool: NewPreallocatedIPs(),
|
ipPool: NewPreallocatedIPs(),
|
||||||
|
logger: logger,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start cleanup routine
|
// Start cleanup routine
|
||||||
@ -211,7 +211,7 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
|||||||
|
|
||||||
func BenchmarkUDPTracker(b *testing.B) {
|
func BenchmarkUDPTracker(b *testing.B) {
|
||||||
b.Run("TrackOutbound", func(b *testing.B) {
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
tracker := NewUDPTracker(DefaultUDPTimeout)
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
@ -224,7 +224,7 @@ func BenchmarkUDPTracker(b *testing.B) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
b.Run("IsValidInbound", func(b *testing.B) {
|
b.Run("IsValidInbound", func(b *testing.B) {
|
||||||
tracker := NewUDPTracker(DefaultUDPTimeout)
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
|
81
client/firewall/uspfilter/forwarder/endpoint.go
Normal file
81
client/firewall/uspfilter/forwarder/endpoint.go
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
package forwarder
|
||||||
|
|
||||||
|
import (
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
|
|
||||||
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// endpoint implements stack.LinkEndpoint and handles integration with the wireguard device
|
||||||
|
type endpoint struct {
|
||||||
|
logger *nblog.Logger
|
||||||
|
dispatcher stack.NetworkDispatcher
|
||||||
|
device *wgdevice.Device
|
||||||
|
mtu uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
|
||||||
|
e.dispatcher = dispatcher
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) IsAttached() bool {
|
||||||
|
return e.dispatcher != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) MTU() uint32 {
|
||||||
|
return e.mtu
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
|
||||||
|
return stack.CapabilityNone
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) MaxHeaderLength() uint16 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) LinkAddress() tcpip.LinkAddress {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
|
||||||
|
var written int
|
||||||
|
for _, pkt := range pkts.AsSlice() {
|
||||||
|
netHeader := header.IPv4(pkt.NetworkHeader().View().AsSlice())
|
||||||
|
|
||||||
|
data := stack.PayloadSince(pkt.NetworkHeader())
|
||||||
|
if data == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the packet through WireGuard
|
||||||
|
address := netHeader.DestinationAddress()
|
||||||
|
err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice())
|
||||||
|
if err != nil {
|
||||||
|
e.logger.Error("CreateOutboundPacket: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
written++
|
||||||
|
}
|
||||||
|
|
||||||
|
return written, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) Wait() {
|
||||||
|
// not required
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
|
||||||
|
return header.ARPHardwareNone
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) AddHeader(*stack.PacketBuffer) {
|
||||||
|
// not required
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
|
||||||
|
return true
|
||||||
|
}
|
166
client/firewall/uspfilter/forwarder/forwarder.go
Normal file
166
client/firewall/uspfilter/forwarder/forwarder.go
Normal file
@ -0,0 +1,166 @@
|
|||||||
|
package forwarder
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"gvisor.dev/gvisor/pkg/buffer"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultReceiveWindow = 32768
|
||||||
|
defaultMaxInFlight = 1024
|
||||||
|
iosReceiveWindow = 16384
|
||||||
|
iosMaxInFlight = 256
|
||||||
|
)
|
||||||
|
|
||||||
|
type Forwarder struct {
|
||||||
|
logger *nblog.Logger
|
||||||
|
stack *stack.Stack
|
||||||
|
endpoint *endpoint
|
||||||
|
udpForwarder *udpForwarder
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
ip net.IP
|
||||||
|
netstack bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwarder, error) {
|
||||||
|
s := stack.New(stack.Options{
|
||||||
|
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||||
|
TransportProtocols: []stack.TransportProtocolFactory{
|
||||||
|
tcp.NewProtocol,
|
||||||
|
udp.NewProtocol,
|
||||||
|
icmp.NewProtocol4,
|
||||||
|
},
|
||||||
|
HandleLocal: false,
|
||||||
|
})
|
||||||
|
|
||||||
|
mtu, err := iface.GetDevice().MTU()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get MTU: %w", err)
|
||||||
|
}
|
||||||
|
nicID := tcpip.NICID(1)
|
||||||
|
endpoint := &endpoint{
|
||||||
|
logger: logger,
|
||||||
|
device: iface.GetWGDevice(),
|
||||||
|
mtu: uint32(mtu),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.CreateNIC(nicID, endpoint); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create NIC: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ones, _ := iface.Address().Network.Mask.Size()
|
||||||
|
protoAddr := tcpip.ProtocolAddress{
|
||||||
|
Protocol: ipv4.ProtocolNumber,
|
||||||
|
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||||
|
Address: tcpip.AddrFromSlice(iface.Address().IP.To4()),
|
||||||
|
PrefixLen: ones,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to add protocol address: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultSubnet, err := tcpip.NewSubnet(
|
||||||
|
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
|
||||||
|
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("creating default subnet: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.SetPromiscuousMode(nicID, true); err != nil {
|
||||||
|
return nil, fmt.Errorf("set promiscuous mode: %s", err)
|
||||||
|
}
|
||||||
|
if err := s.SetSpoofing(nicID, true); err != nil {
|
||||||
|
return nil, fmt.Errorf("set spoofing: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.SetRouteTable([]tcpip.Route{
|
||||||
|
{
|
||||||
|
Destination: defaultSubnet,
|
||||||
|
NIC: nicID,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
f := &Forwarder{
|
||||||
|
logger: logger,
|
||||||
|
stack: s,
|
||||||
|
endpoint: endpoint,
|
||||||
|
udpForwarder: newUDPForwarder(mtu, logger),
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
netstack: netstack,
|
||||||
|
ip: iface.Address().IP,
|
||||||
|
}
|
||||||
|
|
||||||
|
receiveWindow := defaultReceiveWindow
|
||||||
|
maxInFlight := defaultMaxInFlight
|
||||||
|
if runtime.GOOS == "ios" {
|
||||||
|
receiveWindow = iosReceiveWindow
|
||||||
|
maxInFlight = iosMaxInFlight
|
||||||
|
}
|
||||||
|
|
||||||
|
tcpForwarder := tcp.NewForwarder(s, receiveWindow, maxInFlight, f.handleTCP)
|
||||||
|
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
|
||||||
|
|
||||||
|
udpForwarder := udp.NewForwarder(s, f.handleUDP)
|
||||||
|
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
|
||||||
|
|
||||||
|
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP)
|
||||||
|
|
||||||
|
log.Debugf("forwarder: Initialization complete with NIC %d", nicID)
|
||||||
|
return f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
|
||||||
|
if len(payload) < header.IPv4MinimumSize {
|
||||||
|
return fmt.Errorf("packet too small: %d bytes", len(payload))
|
||||||
|
}
|
||||||
|
|
||||||
|
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||||
|
Payload: buffer.MakeWithData(payload),
|
||||||
|
})
|
||||||
|
defer pkt.DecRef()
|
||||||
|
|
||||||
|
if f.endpoint.dispatcher != nil {
|
||||||
|
f.endpoint.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop gracefully shuts down the forwarder
|
||||||
|
func (f *Forwarder) Stop() {
|
||||||
|
f.cancel()
|
||||||
|
|
||||||
|
if f.udpForwarder != nil {
|
||||||
|
f.udpForwarder.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
f.stack.Close()
|
||||||
|
f.stack.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
|
||||||
|
if f.netstack && f.ip.Equal(addr.AsSlice()) {
|
||||||
|
return net.IPv4(127, 0, 0, 1)
|
||||||
|
}
|
||||||
|
return addr.AsSlice()
|
||||||
|
}
|
109
client/firewall/uspfilter/forwarder/icmp.go
Normal file
109
client/firewall/uspfilter/forwarder/icmp.go
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
package forwarder
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
|
)
|
||||||
|
|
||||||
|
// handleICMP handles ICMP packets from the network stack
|
||||||
|
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
|
||||||
|
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
lc := net.ListenConfig{}
|
||||||
|
// TODO: support non-root
|
||||||
|
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||||
|
if err != nil {
|
||||||
|
f.logger.Error("Failed to create ICMP socket for %v: %v", id, err)
|
||||||
|
|
||||||
|
// This will make netstack reply on behalf of the original destination, that's ok for now
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
f.logger.Debug("Failed to close ICMP socket: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||||
|
dst := &net.IPAddr{IP: dstIP}
|
||||||
|
|
||||||
|
// Get the complete ICMP message (header + data)
|
||||||
|
fullPacket := stack.PayloadSince(pkt.TransportHeader())
|
||||||
|
payload := fullPacket.AsSlice()
|
||||||
|
|
||||||
|
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
|
||||||
|
|
||||||
|
// For Echo Requests, send and handle response
|
||||||
|
switch icmpHdr.Type() {
|
||||||
|
case header.ICMPv4Echo:
|
||||||
|
return f.handleEchoResponse(icmpHdr, payload, dst, conn, id)
|
||||||
|
case header.ICMPv4EchoReply:
|
||||||
|
// dont process our own replies
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
// For other ICMP types (Time Exceeded, Destination Unreachable, etc)
|
||||||
|
_, err = conn.WriteTo(payload, dst)
|
||||||
|
if err != nil {
|
||||||
|
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
|
||||||
|
id, icmpHdr.Type(), icmpHdr.Code())
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, dst *net.IPAddr, conn net.PacketConn, id stack.TransportEndpointID) bool {
|
||||||
|
if _, err := conn.WriteTo(payload, dst); err != nil {
|
||||||
|
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
|
||||||
|
id, icmpHdr.Type(), icmpHdr.Code())
|
||||||
|
|
||||||
|
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||||
|
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
response := make([]byte, f.endpoint.mtu)
|
||||||
|
n, _, err := conn.ReadFrom(response)
|
||||||
|
if err != nil {
|
||||||
|
if !isTimeout(err) {
|
||||||
|
f.logger.Error("Failed to read ICMP response: %v", err)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
ipHdr := make([]byte, header.IPv4MinimumSize)
|
||||||
|
ip := header.IPv4(ipHdr)
|
||||||
|
ip.Encode(&header.IPv4Fields{
|
||||||
|
TotalLength: uint16(header.IPv4MinimumSize + n),
|
||||||
|
TTL: 64,
|
||||||
|
Protocol: uint8(header.ICMPv4ProtocolNumber),
|
||||||
|
SrcAddr: id.LocalAddress,
|
||||||
|
DstAddr: id.RemoteAddress,
|
||||||
|
})
|
||||||
|
ip.SetChecksum(^ip.CalculateChecksum())
|
||||||
|
|
||||||
|
fullPacket := make([]byte, 0, len(ipHdr)+n)
|
||||||
|
fullPacket = append(fullPacket, ipHdr...)
|
||||||
|
fullPacket = append(fullPacket, response[:n]...)
|
||||||
|
|
||||||
|
if err := f.InjectIncomingPacket(fullPacket); err != nil {
|
||||||
|
f.logger.Error("Failed to inject ICMP response: %v", err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
f.logger.Trace("Forwarded ICMP echo reply for %v", id)
|
||||||
|
return true
|
||||||
|
}
|
90
client/firewall/uspfilter/forwarder/tcp.go
Normal file
90
client/firewall/uspfilter/forwarder/tcp.go
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
package forwarder
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||||
|
"gvisor.dev/gvisor/pkg/waiter"
|
||||||
|
)
|
||||||
|
|
||||||
|
// handleTCP is called by the TCP forwarder for new connections.
|
||||||
|
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||||
|
id := r.ID()
|
||||||
|
|
||||||
|
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||||
|
|
||||||
|
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
||||||
|
if err != nil {
|
||||||
|
r.Complete(true)
|
||||||
|
f.logger.Trace("forwarder: dial error for %v: %v", id, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create wait queue for blocking syscalls
|
||||||
|
wq := waiter.Queue{}
|
||||||
|
|
||||||
|
ep, epErr := r.CreateEndpoint(&wq)
|
||||||
|
if epErr != nil {
|
||||||
|
f.logger.Error("forwarder: failed to create TCP endpoint: %v", epErr)
|
||||||
|
if err := outConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: outConn close error: %v", err)
|
||||||
|
}
|
||||||
|
r.Complete(true)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Complete the handshake
|
||||||
|
r.Complete(false)
|
||||||
|
|
||||||
|
inConn := gonet.NewTCPConn(&wq, ep)
|
||||||
|
|
||||||
|
f.logger.Trace("forwarder: established TCP connection %v", id)
|
||||||
|
|
||||||
|
go f.proxyTCP(id, inConn, outConn, ep)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint) {
|
||||||
|
defer func() {
|
||||||
|
if err := inConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: inConn close error: %v", err)
|
||||||
|
}
|
||||||
|
if err := outConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: outConn close error: %v", err)
|
||||||
|
}
|
||||||
|
ep.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Create context for managing the proxy goroutines
|
||||||
|
ctx, cancel := context.WithCancel(f.ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
errChan := make(chan error, 2)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_, err := io.Copy(outConn, inConn)
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_, err := io.Copy(inConn, outConn)
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", id)
|
||||||
|
return
|
||||||
|
case err := <-errChan:
|
||||||
|
if err != nil && !isClosedError(err) {
|
||||||
|
f.logger.Error("proxyTCP: copy error: %v", err)
|
||||||
|
}
|
||||||
|
f.logger.Trace("forwarder: tearing down TCP connection %v", id)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
284
client/firewall/uspfilter/forwarder/udp.go
Normal file
284
client/firewall/uspfilter/forwarder/udp.go
Normal file
@ -0,0 +1,284 @@
|
|||||||
|
package forwarder
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||||
|
"gvisor.dev/gvisor/pkg/waiter"
|
||||||
|
|
||||||
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
udpTimeout = 30 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
type udpPacketConn struct {
|
||||||
|
conn *gonet.UDPConn
|
||||||
|
outConn net.Conn
|
||||||
|
lastSeen atomic.Int64
|
||||||
|
cancel context.CancelFunc
|
||||||
|
ep tcpip.Endpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
type udpForwarder struct {
|
||||||
|
sync.RWMutex
|
||||||
|
logger *nblog.Logger
|
||||||
|
conns map[stack.TransportEndpointID]*udpPacketConn
|
||||||
|
bufPool sync.Pool
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
type idleConn struct {
|
||||||
|
id stack.TransportEndpointID
|
||||||
|
conn *udpPacketConn
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUDPForwarder(mtu int, logger *nblog.Logger) *udpForwarder {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
f := &udpForwarder{
|
||||||
|
logger: logger,
|
||||||
|
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
bufPool: sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
b := make([]byte, mtu)
|
||||||
|
return &b
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
go f.cleanup()
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the UDP forwarder and all active connections
|
||||||
|
func (f *udpForwarder) Stop() {
|
||||||
|
f.cancel()
|
||||||
|
|
||||||
|
f.Lock()
|
||||||
|
defer f.Unlock()
|
||||||
|
|
||||||
|
for id, conn := range f.conns {
|
||||||
|
conn.cancel()
|
||||||
|
if err := conn.conn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP conn close error for %v: %v", id, err)
|
||||||
|
}
|
||||||
|
if err := conn.outConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.ep.Close()
|
||||||
|
delete(f.conns, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanup periodically removes idle UDP connections
|
||||||
|
func (f *udpForwarder) cleanup() {
|
||||||
|
ticker := time.NewTicker(time.Minute)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-f.ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
var idleConns []idleConn
|
||||||
|
|
||||||
|
f.RLock()
|
||||||
|
for id, conn := range f.conns {
|
||||||
|
if conn.getIdleDuration() > udpTimeout {
|
||||||
|
idleConns = append(idleConns, idleConn{id, conn})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
f.RUnlock()
|
||||||
|
|
||||||
|
for _, idle := range idleConns {
|
||||||
|
idle.conn.cancel()
|
||||||
|
if err := idle.conn.conn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP conn close error for %v: %v", idle.id, err)
|
||||||
|
}
|
||||||
|
if err := idle.conn.outConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", idle.id, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
idle.conn.ep.Close()
|
||||||
|
|
||||||
|
f.Lock()
|
||||||
|
delete(f.conns, idle.id)
|
||||||
|
f.Unlock()
|
||||||
|
|
||||||
|
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", idle.id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleUDP is called by the UDP forwarder for new packets
|
||||||
|
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||||
|
if f.ctx.Err() != nil {
|
||||||
|
f.logger.Trace("forwarder: context done, dropping UDP packet")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
id := r.ID()
|
||||||
|
|
||||||
|
f.udpForwarder.RLock()
|
||||||
|
_, exists := f.udpForwarder.conns[id]
|
||||||
|
f.udpForwarder.RUnlock()
|
||||||
|
if exists {
|
||||||
|
f.logger.Trace("forwarder: existing UDP connection for %v", id)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||||
|
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
||||||
|
if err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)
|
||||||
|
// TODO: Send ICMP error message
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create wait queue for blocking syscalls
|
||||||
|
wq := waiter.Queue{}
|
||||||
|
ep, epErr := r.CreateEndpoint(&wq)
|
||||||
|
if epErr != nil {
|
||||||
|
f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr)
|
||||||
|
if err := outConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
inConn := gonet.NewUDPConn(f.stack, &wq, ep)
|
||||||
|
connCtx, connCancel := context.WithCancel(f.ctx)
|
||||||
|
|
||||||
|
pConn := &udpPacketConn{
|
||||||
|
conn: inConn,
|
||||||
|
outConn: outConn,
|
||||||
|
cancel: connCancel,
|
||||||
|
ep: ep,
|
||||||
|
}
|
||||||
|
pConn.updateLastSeen()
|
||||||
|
|
||||||
|
f.udpForwarder.Lock()
|
||||||
|
// Double-check no connection was created while we were setting up
|
||||||
|
if _, exists := f.udpForwarder.conns[id]; exists {
|
||||||
|
f.udpForwarder.Unlock()
|
||||||
|
pConn.cancel()
|
||||||
|
if err := inConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
|
||||||
|
}
|
||||||
|
if err := outConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f.udpForwarder.conns[id] = pConn
|
||||||
|
f.udpForwarder.Unlock()
|
||||||
|
|
||||||
|
f.logger.Trace("forwarder: established UDP connection to %v", id)
|
||||||
|
go f.proxyUDP(connCtx, pConn, id, ep)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||||
|
defer func() {
|
||||||
|
pConn.cancel()
|
||||||
|
if err := pConn.conn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
|
||||||
|
}
|
||||||
|
if err := pConn.outConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ep.Close()
|
||||||
|
|
||||||
|
f.udpForwarder.Lock()
|
||||||
|
delete(f.udpForwarder.conns, id)
|
||||||
|
f.udpForwarder.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
errChan := make(chan error, 2)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
errChan <- pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
errChan <- pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", id)
|
||||||
|
return
|
||||||
|
case err := <-errChan:
|
||||||
|
if err != nil && !isClosedError(err) {
|
||||||
|
f.logger.Error("proxyUDP: copy error: %v", err)
|
||||||
|
}
|
||||||
|
f.logger.Trace("forwarder: tearing down UDP connection %v", id)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpPacketConn) updateLastSeen() {
|
||||||
|
c.lastSeen.Store(time.Now().UnixNano())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpPacketConn) getIdleDuration() time.Duration {
|
||||||
|
lastSeen := time.Unix(0, c.lastSeen.Load())
|
||||||
|
return time.Since(lastSeen)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) error {
|
||||||
|
bufp := bufPool.Get().(*[]byte)
|
||||||
|
defer bufPool.Put(bufp)
|
||||||
|
buffer := *bufp
|
||||||
|
|
||||||
|
for {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := src.SetDeadline(time.Now().Add(udpTimeout)); err != nil {
|
||||||
|
return fmt.Errorf("set read deadline: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := src.Read(buffer)
|
||||||
|
if err != nil {
|
||||||
|
if isTimeout(err) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return fmt.Errorf("read from %s: %w", direction, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = dst.Write(buffer[:n])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("write to %s: %w", direction, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.updateLastSeen()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isClosedError(err error) bool {
|
||||||
|
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isTimeout(err error) bool {
|
||||||
|
var netErr net.Error
|
||||||
|
if errors.As(err, &netErr) {
|
||||||
|
return netErr.Timeout()
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
134
client/firewall/uspfilter/localip.go
Normal file
134
client/firewall/uspfilter/localip.go
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
type localIPManager struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
|
||||||
|
// Use bitmap for IPv4 (32 bits * 2^16 = 256KB memory)
|
||||||
|
ipv4Bitmap [1 << 16]uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func newLocalIPManager() *localIPManager {
|
||||||
|
return &localIPManager{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) setBitmapBit(ip net.IP) {
|
||||||
|
ipv4 := ip.To4()
|
||||||
|
if ipv4 == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||||
|
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||||
|
m.ipv4Bitmap[high] |= 1 << (low % 32)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) checkBitmapBit(ip net.IP) bool {
|
||||||
|
ipv4 := ip.To4()
|
||||||
|
if ipv4 == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||||
|
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||||
|
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) processIP(ip net.IP, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
|
||||||
|
if ipv4 := ip.To4(); ipv4 != nil {
|
||||||
|
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||||
|
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||||
|
if int(high) >= len(*newIPv4Bitmap) {
|
||||||
|
return fmt.Errorf("invalid IPv4 address: %s", ip)
|
||||||
|
}
|
||||||
|
ipStr := ip.String()
|
||||||
|
if _, exists := ipv4Set[ipStr]; !exists {
|
||||||
|
ipv4Set[ipStr] = struct{}{}
|
||||||
|
*ipv4Addresses = append(*ipv4Addresses, ipStr)
|
||||||
|
newIPv4Bitmap[high] |= 1 << (low % 32)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||||
|
addrs, err := iface.Addrs()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, addr := range addrs {
|
||||||
|
var ip net.IP
|
||||||
|
switch v := addr.(type) {
|
||||||
|
case *net.IPNet:
|
||||||
|
ip = v.IP
|
||||||
|
case *net.IPAddr:
|
||||||
|
ip = v.IP
|
||||||
|
default:
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.processIP(ip, newIPv4Bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||||
|
log.Debugf("process IP failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = fmt.Errorf("panic: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var newIPv4Bitmap [1 << 16]uint32
|
||||||
|
ipv4Set := make(map[string]struct{})
|
||||||
|
var ipv4Addresses []string
|
||||||
|
|
||||||
|
// 127.0.0.0/8
|
||||||
|
high := uint16(127) << 8
|
||||||
|
for i := uint16(0); i < 256; i++ {
|
||||||
|
newIPv4Bitmap[high|i] = 0xffffffff
|
||||||
|
}
|
||||||
|
|
||||||
|
if iface != nil {
|
||||||
|
if err := m.processIP(iface.Address().IP, &newIPv4Bitmap, ipv4Set, &ipv4Addresses); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
interfaces, err := net.Interfaces()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to get interfaces: %v", err)
|
||||||
|
} else {
|
||||||
|
for _, intf := range interfaces {
|
||||||
|
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
m.ipv4Bitmap = newIPv4Bitmap
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
log.Debugf("Local IPv4 addresses: %v", ipv4Addresses)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) IsLocalIP(ip net.IP) bool {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
if ipv4 := ip.To4(); ipv4 != nil {
|
||||||
|
return m.checkBitmapBit(ipv4)
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
270
client/firewall/uspfilter/localip_test.go
Normal file
270
client/firewall/uspfilter/localip_test.go
Normal file
@ -0,0 +1,270 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLocalIPManager(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setupAddr iface.WGAddress
|
||||||
|
testIP net.IP
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Localhost range",
|
||||||
|
setupAddr: iface.WGAddress{
|
||||||
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
|
Mask: net.CIDRMask(24, 32),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
testIP: net.ParseIP("127.0.0.2"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Localhost standard address",
|
||||||
|
setupAddr: iface.WGAddress{
|
||||||
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
|
Mask: net.CIDRMask(24, 32),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
testIP: net.ParseIP("127.0.0.1"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Localhost range edge",
|
||||||
|
setupAddr: iface.WGAddress{
|
||||||
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
|
Mask: net.CIDRMask(24, 32),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
testIP: net.ParseIP("127.255.255.255"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Local IP matches",
|
||||||
|
setupAddr: iface.WGAddress{
|
||||||
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
|
Mask: net.CIDRMask(24, 32),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
testIP: net.ParseIP("192.168.1.1"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Local IP doesn't match",
|
||||||
|
setupAddr: iface.WGAddress{
|
||||||
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
|
Mask: net.CIDRMask(24, 32),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
testIP: net.ParseIP("192.168.1.2"),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 address",
|
||||||
|
setupAddr: iface.WGAddress{
|
||||||
|
IP: net.ParseIP("fe80::1"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("fe80::"),
|
||||||
|
Mask: net.CIDRMask(64, 128),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
testIP: net.ParseIP("fe80::1"),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
manager := newLocalIPManager()
|
||||||
|
|
||||||
|
mock := &IFaceMock{
|
||||||
|
AddressFunc: func() iface.WGAddress {
|
||||||
|
return tt.setupAddr
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := manager.UpdateLocalIPs(mock)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
result := manager.IsLocalIP(tt.testIP)
|
||||||
|
require.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLocalIPManager_AllInterfaces(t *testing.T) {
|
||||||
|
manager := newLocalIPManager()
|
||||||
|
mock := &IFaceMock{}
|
||||||
|
|
||||||
|
// Get actual local interfaces
|
||||||
|
interfaces, err := net.Interfaces()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var tests []struct {
|
||||||
|
ip string
|
||||||
|
expected bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add all local interface IPs to test cases
|
||||||
|
for _, iface := range interfaces {
|
||||||
|
addrs, err := iface.Addrs()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
for _, addr := range addrs {
|
||||||
|
var ip net.IP
|
||||||
|
switch v := addr.(type) {
|
||||||
|
case *net.IPNet:
|
||||||
|
ip = v.IP
|
||||||
|
case *net.IPAddr:
|
||||||
|
ip = v.IP
|
||||||
|
default:
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if ip4 := ip.To4(); ip4 != nil {
|
||||||
|
tests = append(tests, struct {
|
||||||
|
ip string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
ip: ip4.String(),
|
||||||
|
expected: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add some external IPs as negative test cases
|
||||||
|
externalIPs := []string{
|
||||||
|
"8.8.8.8",
|
||||||
|
"1.1.1.1",
|
||||||
|
"208.67.222.222",
|
||||||
|
}
|
||||||
|
for _, ip := range externalIPs {
|
||||||
|
tests = append(tests, struct {
|
||||||
|
ip string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
ip: ip,
|
||||||
|
expected: false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NotEmpty(t, tests, "No test cases generated")
|
||||||
|
|
||||||
|
err = manager.UpdateLocalIPs(mock)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Logf("Testing %d IPs", len(tests))
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.ip, func(t *testing.T) {
|
||||||
|
result := manager.IsLocalIP(net.ParseIP(tt.ip))
|
||||||
|
require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MapImplementation is a version using map[string]struct{}
|
||||||
|
type MapImplementation struct {
|
||||||
|
localIPs map[string]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkIPChecks(b *testing.B) {
|
||||||
|
interfaces := make([]net.IP, 16)
|
||||||
|
for i := range interfaces {
|
||||||
|
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup bitmap version
|
||||||
|
bitmapManager := &localIPManager{
|
||||||
|
ipv4Bitmap: [1 << 16]uint32{},
|
||||||
|
}
|
||||||
|
for _, ip := range interfaces[:8] { // Add half of IPs
|
||||||
|
bitmapManager.setBitmapBit(ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup map version
|
||||||
|
mapManager := &MapImplementation{
|
||||||
|
localIPs: make(map[string]struct{}),
|
||||||
|
}
|
||||||
|
for _, ip := range interfaces[:8] {
|
||||||
|
mapManager.localIPs[ip.String()] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
b.Run("Bitmap_Hit", func(b *testing.B) {
|
||||||
|
ip := interfaces[4]
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
bitmapManager.checkBitmapBit(ip)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("Bitmap_Miss", func(b *testing.B) {
|
||||||
|
ip := interfaces[12]
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
bitmapManager.checkBitmapBit(ip)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("Map_Hit", func(b *testing.B) {
|
||||||
|
ip := interfaces[4]
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
// nolint:gosimple
|
||||||
|
_, _ = mapManager.localIPs[ip.String()]
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("Map_Miss", func(b *testing.B) {
|
||||||
|
ip := interfaces[12]
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
// nolint:gosimple
|
||||||
|
_, _ = mapManager.localIPs[ip.String()]
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkWGPosition(b *testing.B) {
|
||||||
|
wgIP := net.ParseIP("10.10.0.1")
|
||||||
|
|
||||||
|
// Create two managers - one checks WG IP first, other checks it last
|
||||||
|
b.Run("WG_First", func(b *testing.B) {
|
||||||
|
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
|
||||||
|
bm.setBitmapBit(wgIP)
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
bm.checkBitmapBit(wgIP)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("WG_Last", func(b *testing.B) {
|
||||||
|
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
|
||||||
|
// Fill with other IPs first
|
||||||
|
for i := 0; i < 15; i++ {
|
||||||
|
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))
|
||||||
|
}
|
||||||
|
bm.setBitmapBit(wgIP) // Add WG IP last
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
bm.checkBitmapBit(wgIP)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
196
client/firewall/uspfilter/log/log.go
Normal file
196
client/firewall/uspfilter/log/log.go
Normal file
@ -0,0 +1,196 @@
|
|||||||
|
// Package logger provides a high-performance, non-blocking logger for userspace networking
|
||||||
|
package log
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
maxBatchSize = 1024 * 16 // 16KB max batch size
|
||||||
|
maxMessageSize = 1024 * 2 // 2KB per message
|
||||||
|
bufferSize = 1024 * 256 // 256KB ring buffer
|
||||||
|
defaultFlushInterval = 2 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// Level represents log severity
|
||||||
|
type Level uint32
|
||||||
|
|
||||||
|
const (
|
||||||
|
LevelPanic Level = iota
|
||||||
|
LevelFatal
|
||||||
|
LevelError
|
||||||
|
LevelWarn
|
||||||
|
LevelInfo
|
||||||
|
LevelDebug
|
||||||
|
LevelTrace
|
||||||
|
)
|
||||||
|
|
||||||
|
var levelStrings = map[Level]string{
|
||||||
|
LevelPanic: "PANC",
|
||||||
|
LevelFatal: "FATL",
|
||||||
|
LevelError: "ERRO",
|
||||||
|
LevelWarn: "WARN",
|
||||||
|
LevelInfo: "INFO",
|
||||||
|
LevelDebug: "DEBG",
|
||||||
|
LevelTrace: "TRAC",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Logger is a high-performance, non-blocking logger
|
||||||
|
type Logger struct {
|
||||||
|
output io.Writer
|
||||||
|
level atomic.Uint32
|
||||||
|
buffer *ringBuffer
|
||||||
|
shutdown chan struct{}
|
||||||
|
closeOnce sync.Once
|
||||||
|
wg sync.WaitGroup
|
||||||
|
|
||||||
|
// Reusable buffer pool for formatting messages
|
||||||
|
bufPool sync.Pool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
||||||
|
l := &Logger{
|
||||||
|
output: logrusLogger.Out,
|
||||||
|
buffer: newRingBuffer(bufferSize),
|
||||||
|
shutdown: make(chan struct{}),
|
||||||
|
bufPool: sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
// Pre-allocate buffer for message formatting
|
||||||
|
b := make([]byte, 0, maxMessageSize)
|
||||||
|
return &b
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
logrusLevel := logrusLogger.GetLevel()
|
||||||
|
l.level.Store(uint32(logrusLevel))
|
||||||
|
level := levelStrings[Level(logrusLevel)]
|
||||||
|
log.Debugf("New uspfilter logger created with loglevel %v", level)
|
||||||
|
|
||||||
|
l.wg.Add(1)
|
||||||
|
go l.worker()
|
||||||
|
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) SetLevel(level Level) {
|
||||||
|
l.level.Store(uint32(level))
|
||||||
|
|
||||||
|
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...interface{}) {
|
||||||
|
*buf = (*buf)[:0]
|
||||||
|
|
||||||
|
// Timestamp
|
||||||
|
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
|
||||||
|
*buf = append(*buf, ' ')
|
||||||
|
|
||||||
|
// Level
|
||||||
|
*buf = append(*buf, levelStrings[level]...)
|
||||||
|
*buf = append(*buf, ' ')
|
||||||
|
|
||||||
|
// Message
|
||||||
|
if len(args) > 0 {
|
||||||
|
*buf = append(*buf, fmt.Sprintf(format, args...)...)
|
||||||
|
} else {
|
||||||
|
*buf = append(*buf, format...)
|
||||||
|
}
|
||||||
|
|
||||||
|
*buf = append(*buf, '\n')
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) log(level Level, format string, args ...interface{}) {
|
||||||
|
bufp := l.bufPool.Get().(*[]byte)
|
||||||
|
l.formatMessage(bufp, level, format, args...)
|
||||||
|
|
||||||
|
if len(*bufp) > maxMessageSize {
|
||||||
|
*bufp = (*bufp)[:maxMessageSize]
|
||||||
|
}
|
||||||
|
_, _ = l.buffer.Write(*bufp)
|
||||||
|
|
||||||
|
l.bufPool.Put(bufp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) Error(format string, args ...interface{}) {
|
||||||
|
if l.level.Load() >= uint32(LevelError) {
|
||||||
|
l.log(LevelError, format, args...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) Warn(format string, args ...interface{}) {
|
||||||
|
if l.level.Load() >= uint32(LevelWarn) {
|
||||||
|
l.log(LevelWarn, format, args...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) Info(format string, args ...interface{}) {
|
||||||
|
if l.level.Load() >= uint32(LevelInfo) {
|
||||||
|
l.log(LevelInfo, format, args...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) Debug(format string, args ...interface{}) {
|
||||||
|
if l.level.Load() >= uint32(LevelDebug) {
|
||||||
|
l.log(LevelDebug, format, args...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) Trace(format string, args ...interface{}) {
|
||||||
|
if l.level.Load() >= uint32(LevelTrace) {
|
||||||
|
l.log(LevelTrace, format, args...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// worker periodically flushes the buffer
|
||||||
|
func (l *Logger) worker() {
|
||||||
|
defer l.wg.Done()
|
||||||
|
|
||||||
|
ticker := time.NewTicker(defaultFlushInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
buf := make([]byte, 0, maxBatchSize)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-l.shutdown:
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
// Read accumulated messages
|
||||||
|
n, _ := l.buffer.Read(buf[:cap(buf)])
|
||||||
|
if n == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write batch
|
||||||
|
_, _ = l.output.Write(buf[:n])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop gracefully shuts down the logger
|
||||||
|
func (l *Logger) Stop(ctx context.Context) error {
|
||||||
|
done := make(chan struct{})
|
||||||
|
|
||||||
|
l.closeOnce.Do(func() {
|
||||||
|
close(l.shutdown)
|
||||||
|
})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
l.wg.Wait()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-done:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
85
client/firewall/uspfilter/log/ringbuffer.go
Normal file
85
client/firewall/uspfilter/log/ringbuffer.go
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
package log
|
||||||
|
|
||||||
|
import "sync"
|
||||||
|
|
||||||
|
// ringBuffer is a simple ring buffer implementation
|
||||||
|
type ringBuffer struct {
|
||||||
|
buf []byte
|
||||||
|
size int
|
||||||
|
r, w int64 // Read and write positions
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRingBuffer(size int) *ringBuffer {
|
||||||
|
return &ringBuffer{
|
||||||
|
buf: make([]byte, size),
|
||||||
|
size: size,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ringBuffer) Write(p []byte) (n int, err error) {
|
||||||
|
if len(p) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
if len(p) > r.size {
|
||||||
|
p = p[:r.size]
|
||||||
|
}
|
||||||
|
|
||||||
|
n = len(p)
|
||||||
|
|
||||||
|
// Write data, handling wrap-around
|
||||||
|
pos := int(r.w % int64(r.size))
|
||||||
|
writeLen := min(len(p), r.size-pos)
|
||||||
|
copy(r.buf[pos:], p[:writeLen])
|
||||||
|
|
||||||
|
// If we have more data and need to wrap around
|
||||||
|
if writeLen < len(p) {
|
||||||
|
copy(r.buf, p[writeLen:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update write position
|
||||||
|
r.w += int64(n)
|
||||||
|
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ringBuffer) Read(p []byte) (n int, err error) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
if r.w == r.r {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate available data accounting for wraparound
|
||||||
|
available := int(r.w - r.r)
|
||||||
|
if available < 0 {
|
||||||
|
available += r.size
|
||||||
|
}
|
||||||
|
available = min(available, r.size)
|
||||||
|
|
||||||
|
// Limit read to buffer size
|
||||||
|
toRead := min(available, len(p))
|
||||||
|
if toRead == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read data, handling wrap-around
|
||||||
|
pos := int(r.r % int64(r.size))
|
||||||
|
readLen := min(toRead, r.size-pos)
|
||||||
|
n = copy(p, r.buf[pos:pos+readLen])
|
||||||
|
|
||||||
|
// If we need more data and need to wrap around
|
||||||
|
if readLen < toRead {
|
||||||
|
n += copy(p[readLen:toRead], r.buf[:toRead-readLen])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update read position
|
||||||
|
r.r += int64(n)
|
||||||
|
|
||||||
|
return n, nil
|
||||||
|
}
|
@ -2,22 +2,22 @@ package uspfilter
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Rule to handle management of rules
|
// PeerRule to handle management of rules
|
||||||
type Rule struct {
|
type PeerRule struct {
|
||||||
id string
|
id string
|
||||||
ip net.IP
|
ip net.IP
|
||||||
ipLayer gopacket.LayerType
|
ipLayer gopacket.LayerType
|
||||||
matchByIP bool
|
matchByIP bool
|
||||||
protoLayer gopacket.LayerType
|
protoLayer gopacket.LayerType
|
||||||
direction firewall.RuleDirection
|
sPort *firewall.Port
|
||||||
sPort uint16
|
dPort *firewall.Port
|
||||||
dPort uint16
|
|
||||||
drop bool
|
drop bool
|
||||||
comment string
|
comment string
|
||||||
|
|
||||||
@ -25,6 +25,21 @@ type Rule struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
// GetRuleID returns the rule id
|
||||||
func (r *Rule) GetRuleID() string {
|
func (r *PeerRule) GetRuleID() string {
|
||||||
|
return r.id
|
||||||
|
}
|
||||||
|
|
||||||
|
type RouteRule struct {
|
||||||
|
id string
|
||||||
|
sources []netip.Prefix
|
||||||
|
destination netip.Prefix
|
||||||
|
proto firewall.Protocol
|
||||||
|
srcPort *firewall.Port
|
||||||
|
dstPort *firewall.Port
|
||||||
|
action firewall.Action
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRuleID returns the rule id
|
||||||
|
func (r *RouteRule) GetRuleID() string {
|
||||||
return r.id
|
return r.id
|
||||||
}
|
}
|
||||||
|
390
client/firewall/uspfilter/tracer.go
Normal file
390
client/firewall/uspfilter/tracer.go
Normal file
@ -0,0 +1,390 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/gopacket"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
|
||||||
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PacketStage int
|
||||||
|
|
||||||
|
const (
|
||||||
|
StageReceived PacketStage = iota
|
||||||
|
StageConntrack
|
||||||
|
StagePeerACL
|
||||||
|
StageRouting
|
||||||
|
StageRouteACL
|
||||||
|
StageForwarding
|
||||||
|
StageCompleted
|
||||||
|
)
|
||||||
|
|
||||||
|
const msgProcessingCompleted = "Processing completed"
|
||||||
|
|
||||||
|
func (s PacketStage) String() string {
|
||||||
|
return map[PacketStage]string{
|
||||||
|
StageReceived: "Received",
|
||||||
|
StageConntrack: "Connection Tracking",
|
||||||
|
StagePeerACL: "Peer ACL",
|
||||||
|
StageRouting: "Routing",
|
||||||
|
StageRouteACL: "Route ACL",
|
||||||
|
StageForwarding: "Forwarding",
|
||||||
|
StageCompleted: "Completed",
|
||||||
|
}[s]
|
||||||
|
}
|
||||||
|
|
||||||
|
type ForwarderAction struct {
|
||||||
|
Action string
|
||||||
|
RemoteAddr string
|
||||||
|
Error error
|
||||||
|
}
|
||||||
|
|
||||||
|
type TraceResult struct {
|
||||||
|
Timestamp time.Time
|
||||||
|
Stage PacketStage
|
||||||
|
Message string
|
||||||
|
Allowed bool
|
||||||
|
ForwarderAction *ForwarderAction
|
||||||
|
}
|
||||||
|
|
||||||
|
type PacketTrace struct {
|
||||||
|
SourceIP net.IP
|
||||||
|
DestinationIP net.IP
|
||||||
|
Protocol string
|
||||||
|
SourcePort uint16
|
||||||
|
DestinationPort uint16
|
||||||
|
Direction fw.RuleDirection
|
||||||
|
Results []TraceResult
|
||||||
|
}
|
||||||
|
|
||||||
|
type TCPState struct {
|
||||||
|
SYN bool
|
||||||
|
ACK bool
|
||||||
|
FIN bool
|
||||||
|
RST bool
|
||||||
|
PSH bool
|
||||||
|
URG bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type PacketBuilder struct {
|
||||||
|
SrcIP net.IP
|
||||||
|
DstIP net.IP
|
||||||
|
Protocol fw.Protocol
|
||||||
|
SrcPort uint16
|
||||||
|
DstPort uint16
|
||||||
|
ICMPType uint8
|
||||||
|
ICMPCode uint8
|
||||||
|
Direction fw.RuleDirection
|
||||||
|
PayloadSize int
|
||||||
|
TCPState *TCPState
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *PacketTrace) AddResult(stage PacketStage, message string, allowed bool) {
|
||||||
|
t.Results = append(t.Results, TraceResult{
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
Stage: stage,
|
||||||
|
Message: message,
|
||||||
|
Allowed: allowed,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *PacketTrace) AddResultWithForwarder(stage PacketStage, message string, allowed bool, action *ForwarderAction) {
|
||||||
|
t.Results = append(t.Results, TraceResult{
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
Stage: stage,
|
||||||
|
Message: message,
|
||||||
|
Allowed: allowed,
|
||||||
|
ForwarderAction: action,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PacketBuilder) Build() ([]byte, error) {
|
||||||
|
ip := p.buildIPLayer()
|
||||||
|
pktLayers := []gopacket.SerializableLayer{ip}
|
||||||
|
|
||||||
|
transportLayer, err := p.buildTransportLayer(ip)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
pktLayers = append(pktLayers, transportLayer...)
|
||||||
|
|
||||||
|
if p.PayloadSize > 0 {
|
||||||
|
payload := make([]byte, p.PayloadSize)
|
||||||
|
pktLayers = append(pktLayers, gopacket.Payload(payload))
|
||||||
|
}
|
||||||
|
|
||||||
|
return serializePacket(pktLayers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
|
||||||
|
return &layers.IPv4{
|
||||||
|
Version: 4,
|
||||||
|
TTL: 64,
|
||||||
|
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
|
||||||
|
SrcIP: p.SrcIP,
|
||||||
|
DstIP: p.DstIP,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PacketBuilder) buildTransportLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
||||||
|
switch p.Protocol {
|
||||||
|
case "tcp":
|
||||||
|
return p.buildTCPLayer(ip)
|
||||||
|
case "udp":
|
||||||
|
return p.buildUDPLayer(ip)
|
||||||
|
case "icmp":
|
||||||
|
return p.buildICMPLayer()
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported protocol: %s", p.Protocol)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PacketBuilder) buildTCPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
||||||
|
tcp := &layers.TCP{
|
||||||
|
SrcPort: layers.TCPPort(p.SrcPort),
|
||||||
|
DstPort: layers.TCPPort(p.DstPort),
|
||||||
|
Window: 65535,
|
||||||
|
SYN: p.TCPState != nil && p.TCPState.SYN,
|
||||||
|
ACK: p.TCPState != nil && p.TCPState.ACK,
|
||||||
|
FIN: p.TCPState != nil && p.TCPState.FIN,
|
||||||
|
RST: p.TCPState != nil && p.TCPState.RST,
|
||||||
|
PSH: p.TCPState != nil && p.TCPState.PSH,
|
||||||
|
URG: p.TCPState != nil && p.TCPState.URG,
|
||||||
|
}
|
||||||
|
if err := tcp.SetNetworkLayerForChecksum(ip); err != nil {
|
||||||
|
return nil, fmt.Errorf("set network layer for TCP checksum: %w", err)
|
||||||
|
}
|
||||||
|
return []gopacket.SerializableLayer{tcp}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PacketBuilder) buildUDPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
||||||
|
udp := &layers.UDP{
|
||||||
|
SrcPort: layers.UDPPort(p.SrcPort),
|
||||||
|
DstPort: layers.UDPPort(p.DstPort),
|
||||||
|
}
|
||||||
|
if err := udp.SetNetworkLayerForChecksum(ip); err != nil {
|
||||||
|
return nil, fmt.Errorf("set network layer for UDP checksum: %w", err)
|
||||||
|
}
|
||||||
|
return []gopacket.SerializableLayer{udp}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PacketBuilder) buildICMPLayer() ([]gopacket.SerializableLayer, error) {
|
||||||
|
icmp := &layers.ICMPv4{
|
||||||
|
TypeCode: layers.CreateICMPv4TypeCode(p.ICMPType, p.ICMPCode),
|
||||||
|
}
|
||||||
|
if p.ICMPType == layers.ICMPv4TypeEchoRequest || p.ICMPType == layers.ICMPv4TypeEchoReply {
|
||||||
|
icmp.Id = uint16(1)
|
||||||
|
icmp.Seq = uint16(1)
|
||||||
|
}
|
||||||
|
return []gopacket.SerializableLayer{icmp}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func serializePacket(layers []gopacket.SerializableLayer) ([]byte, error) {
|
||||||
|
buf := gopacket.NewSerializeBuffer()
|
||||||
|
opts := gopacket.SerializeOptions{
|
||||||
|
ComputeChecksums: true,
|
||||||
|
FixLengths: true,
|
||||||
|
}
|
||||||
|
if err := gopacket.SerializeLayers(buf, opts, layers...); err != nil {
|
||||||
|
return nil, fmt.Errorf("serialize packet: %w", err)
|
||||||
|
}
|
||||||
|
return buf.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getIPProtocolNumber(protocol fw.Protocol) int {
|
||||||
|
switch protocol {
|
||||||
|
case fw.ProtocolTCP:
|
||||||
|
return int(layers.IPProtocolTCP)
|
||||||
|
case fw.ProtocolUDP:
|
||||||
|
return int(layers.IPProtocolUDP)
|
||||||
|
case fw.ProtocolICMP:
|
||||||
|
return int(layers.IPProtocolICMPv4)
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) TracePacketFromBuilder(builder *PacketBuilder) (*PacketTrace, error) {
|
||||||
|
packetData, err := builder.Build()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("build packet: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.TracePacket(packetData, builder.Direction), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *PacketTrace {
|
||||||
|
|
||||||
|
d := m.decoders.Get().(*decoder)
|
||||||
|
defer m.decoders.Put(d)
|
||||||
|
|
||||||
|
trace := &PacketTrace{Direction: direction}
|
||||||
|
|
||||||
|
// Initial packet decoding
|
||||||
|
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||||
|
trace.AddResult(StageReceived, fmt.Sprintf("Failed to decode packet: %v", err), false)
|
||||||
|
return trace
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract base packet info
|
||||||
|
srcIP, dstIP := m.extractIPs(d)
|
||||||
|
trace.SourceIP = srcIP
|
||||||
|
trace.DestinationIP = dstIP
|
||||||
|
|
||||||
|
// Determine protocol and ports
|
||||||
|
switch d.decoded[1] {
|
||||||
|
case layers.LayerTypeTCP:
|
||||||
|
trace.Protocol = "TCP"
|
||||||
|
trace.SourcePort = uint16(d.tcp.SrcPort)
|
||||||
|
trace.DestinationPort = uint16(d.tcp.DstPort)
|
||||||
|
case layers.LayerTypeUDP:
|
||||||
|
trace.Protocol = "UDP"
|
||||||
|
trace.SourcePort = uint16(d.udp.SrcPort)
|
||||||
|
trace.DestinationPort = uint16(d.udp.DstPort)
|
||||||
|
case layers.LayerTypeICMPv4:
|
||||||
|
trace.Protocol = "ICMP"
|
||||||
|
}
|
||||||
|
|
||||||
|
trace.AddResult(StageReceived, fmt.Sprintf("Received %s packet: %s:%d -> %s:%d",
|
||||||
|
trace.Protocol, srcIP, trace.SourcePort, dstIP, trace.DestinationPort), true)
|
||||||
|
|
||||||
|
if direction == fw.RuleDirectionOUT {
|
||||||
|
return m.traceOutbound(packetData, trace)
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.traceInbound(packetData, trace, d, srcIP, dstIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP net.IP, dstIP net.IP) *PacketTrace {
|
||||||
|
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
|
||||||
|
return trace
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
|
||||||
|
return trace
|
||||||
|
}
|
||||||
|
|
||||||
|
if !m.handleRouting(trace) {
|
||||||
|
return trace
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.nativeRouter {
|
||||||
|
return m.handleNativeRouter(trace)
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.handleRouteACLs(trace, d, srcIP, dstIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) bool {
|
||||||
|
allowed := m.isValidTrackedConnection(d, srcIP, dstIP)
|
||||||
|
msg := "No existing connection found"
|
||||||
|
if allowed {
|
||||||
|
msg = m.buildConntrackStateMessage(d)
|
||||||
|
trace.AddResult(StageConntrack, msg, true)
|
||||||
|
trace.AddResult(StageCompleted, "Packet allowed by connection tracking", true)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
trace.AddResult(StageConntrack, msg, false)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) buildConntrackStateMessage(d *decoder) string {
|
||||||
|
msg := "Matched existing connection state"
|
||||||
|
switch d.decoded[1] {
|
||||||
|
case layers.LayerTypeTCP:
|
||||||
|
flags := getTCPFlags(&d.tcp)
|
||||||
|
msg += fmt.Sprintf(" (TCP Flags: SYN=%v ACK=%v RST=%v FIN=%v)",
|
||||||
|
flags&conntrack.TCPSyn != 0,
|
||||||
|
flags&conntrack.TCPAck != 0,
|
||||||
|
flags&conntrack.TCPRst != 0,
|
||||||
|
flags&conntrack.TCPFin != 0)
|
||||||
|
case layers.LayerTypeICMPv4:
|
||||||
|
msg += fmt.Sprintf(" (ICMP ID=%d, Seq=%d)", d.icmp4.Id, d.icmp4.Seq)
|
||||||
|
}
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP net.IP) bool {
|
||||||
|
if !m.localForwarding {
|
||||||
|
trace.AddResult(StageRouting, "Local forwarding disabled", false)
|
||||||
|
trace.AddResult(StageCompleted, "Packet dropped - local forwarding disabled", false)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
trace.AddResult(StageRouting, "Packet destined for local delivery", true)
|
||||||
|
blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
|
||||||
|
|
||||||
|
msg := "Allowed by peer ACL rules"
|
||||||
|
if blocked {
|
||||||
|
msg = "Blocked by peer ACL rules"
|
||||||
|
}
|
||||||
|
trace.AddResult(StagePeerACL, msg, !blocked)
|
||||||
|
|
||||||
|
if m.netstack {
|
||||||
|
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", !blocked)
|
||||||
|
}
|
||||||
|
|
||||||
|
trace.AddResult(StageCompleted, msgProcessingCompleted, !blocked)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) handleRouting(trace *PacketTrace) bool {
|
||||||
|
if !m.routingEnabled {
|
||||||
|
trace.AddResult(StageRouting, "Routing disabled", false)
|
||||||
|
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
trace.AddResult(StageRouting, "Routing enabled, checking ACLs", true)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
|
||||||
|
trace.AddResult(StageRouteACL, "Using native router, skipping ACL checks", true)
|
||||||
|
trace.AddResult(StageForwarding, "Forwarding via native router", true)
|
||||||
|
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
|
||||||
|
return trace
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) *PacketTrace {
|
||||||
|
proto := getProtocolFromPacket(d)
|
||||||
|
srcPort, dstPort := getPortsFromPacket(d)
|
||||||
|
allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
||||||
|
|
||||||
|
msg := "Allowed by route ACLs"
|
||||||
|
if !allowed {
|
||||||
|
msg = "Blocked by route ACLs"
|
||||||
|
}
|
||||||
|
trace.AddResult(StageRouteACL, msg, allowed)
|
||||||
|
|
||||||
|
if allowed && m.forwarder != nil {
|
||||||
|
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
|
||||||
|
}
|
||||||
|
|
||||||
|
trace.AddResult(StageCompleted, msgProcessingCompleted, allowed)
|
||||||
|
return trace
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr string, allowed bool) {
|
||||||
|
fwdAction := &ForwarderAction{
|
||||||
|
Action: action,
|
||||||
|
RemoteAddr: remoteAddr,
|
||||||
|
}
|
||||||
|
trace.AddResultWithForwarder(StageForwarding,
|
||||||
|
fmt.Sprintf("Forwarding to %s", fwdAction.Action), allowed, fwdAction)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
|
||||||
|
// will create or update the connection state
|
||||||
|
dropped := m.processOutgoingHooks(packetData)
|
||||||
|
if dropped {
|
||||||
|
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
|
||||||
|
} else {
|
||||||
|
trace.AddResult(StageCompleted, "Packet allowed (outgoing)", true)
|
||||||
|
}
|
||||||
|
return trace
|
||||||
|
}
|
@ -1,11 +1,14 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
@ -14,44 +17,83 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
const layerTypeAll = 0
|
const layerTypeAll = 0
|
||||||
|
|
||||||
const EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
|
const (
|
||||||
|
// EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed.
|
||||||
|
EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
|
||||||
|
|
||||||
var (
|
// EnvDisableUserspaceRouting disables userspace routing, to-be-routed packets will be dropped.
|
||||||
errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall")
|
EnvDisableUserspaceRouting = "NB_DISABLE_USERSPACE_ROUTING"
|
||||||
|
|
||||||
|
// EnvForceUserspaceRouter forces userspace routing even if native routing is available.
|
||||||
|
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
|
||||||
|
|
||||||
|
// EnvEnableNetstackLocalForwarding enables forwarding of local traffic to the native stack when running netstack
|
||||||
|
// Leaving this on by default introduces a security risk as sockets on listening on localhost only will be accessible
|
||||||
|
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
|
||||||
)
|
)
|
||||||
|
|
||||||
// IFaceMapper defines subset methods of interface required for manager
|
|
||||||
type IFaceMapper interface {
|
|
||||||
SetFilter(device.PacketFilter) error
|
|
||||||
Address() iface.WGAddress
|
|
||||||
}
|
|
||||||
|
|
||||||
// RuleSet is a set of rules grouped by a string key
|
// RuleSet is a set of rules grouped by a string key
|
||||||
type RuleSet map[string]Rule
|
type RuleSet map[string]PeerRule
|
||||||
|
|
||||||
|
type RouteRules []RouteRule
|
||||||
|
|
||||||
|
func (r RouteRules) Sort() {
|
||||||
|
slices.SortStableFunc(r, func(a, b RouteRule) int {
|
||||||
|
// Deny rules come first
|
||||||
|
if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
if a.action != firewall.ActionDrop && b.action == firewall.ActionDrop {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return strings.Compare(a.id, b.id)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Manager userspace firewall manager
|
// Manager userspace firewall manager
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
|
// outgoingRules is used for hooks only
|
||||||
outgoingRules map[string]RuleSet
|
outgoingRules map[string]RuleSet
|
||||||
|
// incomingRules is used for filtering and hooks
|
||||||
incomingRules map[string]RuleSet
|
incomingRules map[string]RuleSet
|
||||||
|
routeRules RouteRules
|
||||||
wgNetwork *net.IPNet
|
wgNetwork *net.IPNet
|
||||||
decoders sync.Pool
|
decoders sync.Pool
|
||||||
wgIface IFaceMapper
|
wgIface common.IFaceMapper
|
||||||
nativeFirewall firewall.Manager
|
nativeFirewall firewall.Manager
|
||||||
|
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
|
|
||||||
|
// indicates whether server routes are disabled
|
||||||
|
disableServerRoutes bool
|
||||||
|
// indicates whether we forward packets not destined for ourselves
|
||||||
|
routingEnabled bool
|
||||||
|
// indicates whether we leave forwarding and filtering to the native firewall
|
||||||
|
nativeRouter bool
|
||||||
|
// indicates whether we track outbound connections
|
||||||
stateful bool
|
stateful bool
|
||||||
|
// indicates whether wireguards runs in netstack mode
|
||||||
|
netstack bool
|
||||||
|
// indicates whether we forward local traffic to the native stack
|
||||||
|
localForwarding bool
|
||||||
|
|
||||||
|
localipmanager *localIPManager
|
||||||
|
|
||||||
udpTracker *conntrack.UDPTracker
|
udpTracker *conntrack.UDPTracker
|
||||||
icmpTracker *conntrack.ICMPTracker
|
icmpTracker *conntrack.ICMPTracker
|
||||||
tcpTracker *conntrack.TCPTracker
|
tcpTracker *conntrack.TCPTracker
|
||||||
|
forwarder *forwarder.Forwarder
|
||||||
|
logger *nblog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
// decoder for packages
|
// decoder for packages
|
||||||
@ -68,22 +110,44 @@ type decoder struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create userspace firewall manager constructor
|
// Create userspace firewall manager constructor
|
||||||
func Create(iface IFaceMapper) (*Manager, error) {
|
func Create(iface common.IFaceMapper, disableServerRoutes bool) (*Manager, error) {
|
||||||
return create(iface)
|
return create(iface, nil, disableServerRoutes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager) (*Manager, error) {
|
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
|
||||||
mgr, err := create(iface)
|
if nativeFirewall == nil {
|
||||||
|
return nil, errors.New("native firewall is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr, err := create(iface, nativeFirewall, disableServerRoutes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
mgr.nativeFirewall = nativeFirewall
|
|
||||||
return mgr, nil
|
return mgr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func create(iface IFaceMapper) (*Manager, error) {
|
func parseCreateEnv() (bool, bool) {
|
||||||
disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack))
|
var disableConntrack, enableLocalForwarding bool
|
||||||
|
var err error
|
||||||
|
if val := os.Getenv(EnvDisableConntrack); val != "" {
|
||||||
|
disableConntrack, err = strconv.ParseBool(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse %s: %v", EnvDisableConntrack, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if val := os.Getenv(EnvEnableNetstackLocalForwarding); val != "" {
|
||||||
|
enableLocalForwarding, err = strconv.ParseBool(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return disableConntrack, enableLocalForwarding
|
||||||
|
}
|
||||||
|
|
||||||
|
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
|
||||||
|
disableConntrack, enableLocalForwarding := parseCreateEnv()
|
||||||
|
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
decoders: sync.Pool{
|
decoders: sync.Pool{
|
||||||
@ -99,53 +163,183 @@ func create(iface IFaceMapper) (*Manager, error) {
|
|||||||
return d
|
return d
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
nativeFirewall: nativeFirewall,
|
||||||
outgoingRules: make(map[string]RuleSet),
|
outgoingRules: make(map[string]RuleSet),
|
||||||
incomingRules: make(map[string]RuleSet),
|
incomingRules: make(map[string]RuleSet),
|
||||||
wgIface: iface,
|
wgIface: iface,
|
||||||
|
localipmanager: newLocalIPManager(),
|
||||||
|
disableServerRoutes: disableServerRoutes,
|
||||||
|
routingEnabled: false,
|
||||||
stateful: !disableConntrack,
|
stateful: !disableConntrack,
|
||||||
|
logger: nblog.NewFromLogrus(log.StandardLogger()),
|
||||||
|
netstack: netstack.IsEnabled(),
|
||||||
|
localForwarding: enableLocalForwarding,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
||||||
|
return nil, fmt.Errorf("update local IPs: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only initialize trackers if stateful mode is enabled
|
|
||||||
if disableConntrack {
|
if disableConntrack {
|
||||||
log.Info("conntrack is disabled")
|
log.Info("conntrack is disabled")
|
||||||
} else {
|
} else {
|
||||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
|
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
|
||||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
|
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
|
||||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
// netstack needs the forwarder for local traffic
|
||||||
|
if m.netstack && m.localForwarding {
|
||||||
|
if err := m.initForwarder(); err != nil {
|
||||||
|
log.Errorf("failed to initialize forwarder: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.blockInvalidRouted(iface); err != nil {
|
||||||
|
log.Errorf("failed to block invalid routed traffic: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := iface.SetFilter(m); err != nil {
|
if err := iface.SetFilter(m); err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("set filter: %w", err)
|
||||||
}
|
}
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
|
||||||
|
if m.forwarder == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse wireguard network: %w", err)
|
||||||
|
}
|
||||||
|
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
||||||
|
|
||||||
|
if _, err := m.AddRouteFiltering(
|
||||||
|
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
|
||||||
|
wgPrefix,
|
||||||
|
firewall.ProtocolALL,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
firewall.ActionDrop,
|
||||||
|
); err != nil {
|
||||||
|
return fmt.Errorf("block wg nte : %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Block networks that we're a client of
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) determineRouting() error {
|
||||||
|
var disableUspRouting, forceUserspaceRouter bool
|
||||||
|
var err error
|
||||||
|
if val := os.Getenv(EnvDisableUserspaceRouting); val != "" {
|
||||||
|
disableUspRouting, err = strconv.ParseBool(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse %s: %v", EnvDisableUserspaceRouting, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if val := os.Getenv(EnvForceUserspaceRouter); val != "" {
|
||||||
|
forceUserspaceRouter, err = strconv.ParseBool(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse %s: %v", EnvForceUserspaceRouter, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case disableUspRouting:
|
||||||
|
m.routingEnabled = false
|
||||||
|
m.nativeRouter = false
|
||||||
|
log.Info("userspace routing is disabled")
|
||||||
|
|
||||||
|
case m.disableServerRoutes:
|
||||||
|
// if server routes are disabled we will let packets pass to the native stack
|
||||||
|
m.routingEnabled = true
|
||||||
|
m.nativeRouter = true
|
||||||
|
|
||||||
|
log.Info("server routes are disabled")
|
||||||
|
|
||||||
|
case forceUserspaceRouter:
|
||||||
|
m.routingEnabled = true
|
||||||
|
m.nativeRouter = false
|
||||||
|
|
||||||
|
log.Info("userspace routing is forced")
|
||||||
|
|
||||||
|
case !m.netstack && m.nativeFirewall != nil && m.nativeFirewall.IsServerRouteSupported():
|
||||||
|
// if the OS supports routing natively, then we don't need to filter/route ourselves
|
||||||
|
// netstack mode won't support native routing as there is no interface
|
||||||
|
|
||||||
|
m.routingEnabled = true
|
||||||
|
m.nativeRouter = true
|
||||||
|
|
||||||
|
log.Info("native routing is enabled")
|
||||||
|
|
||||||
|
default:
|
||||||
|
m.routingEnabled = true
|
||||||
|
m.nativeRouter = false
|
||||||
|
|
||||||
|
log.Info("userspace routing enabled by default")
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.routingEnabled && !m.nativeRouter {
|
||||||
|
return m.initForwarder()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// initForwarder initializes the forwarder, it disables routing on errors
|
||||||
|
func (m *Manager) initForwarder() error {
|
||||||
|
if m.forwarder != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only supported in userspace mode as we need to inject packets back into wireguard directly
|
||||||
|
intf := m.wgIface.GetWGDevice()
|
||||||
|
if intf == nil {
|
||||||
|
m.routingEnabled = false
|
||||||
|
return errors.New("forwarding not supported")
|
||||||
|
}
|
||||||
|
|
||||||
|
forwarder, err := forwarder.New(m.wgIface, m.logger, m.netstack)
|
||||||
|
if err != nil {
|
||||||
|
m.routingEnabled = false
|
||||||
|
return fmt.Errorf("create forwarder: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.forwarder = forwarder
|
||||||
|
|
||||||
|
log.Debug("forwarder initialized")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) Init(*statemanager.Manager) error {
|
func (m *Manager) Init(*statemanager.Manager) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) IsServerRouteSupported() bool {
|
func (m *Manager) IsServerRouteSupported() bool {
|
||||||
if m.nativeFirewall == nil {
|
|
||||||
return false
|
|
||||||
} else {
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||||
if m.nativeFirewall == nil {
|
if m.nativeRouter && m.nativeFirewall != nil {
|
||||||
return errRouteNotSupported
|
|
||||||
}
|
|
||||||
return m.nativeFirewall.AddNatRule(pair)
|
return m.nativeFirewall.AddNatRule(pair)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// userspace routed packets are always SNATed to the inbound direction
|
||||||
|
// TODO: implement outbound SNAT
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// RemoveNatRule removes a routing firewall rule
|
// RemoveNatRule removes a routing firewall rule
|
||||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
if m.nativeFirewall == nil {
|
if m.nativeRouter && m.nativeFirewall != nil {
|
||||||
return errRouteNotSupported
|
|
||||||
}
|
|
||||||
return m.nativeFirewall.RemoveNatRule(pair)
|
return m.nativeFirewall.RemoveNatRule(pair)
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// AddPeerFiltering rule to the firewall
|
// AddPeerFiltering rule to the firewall
|
||||||
//
|
//
|
||||||
@ -156,17 +350,15 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
direction firewall.RuleDirection,
|
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
_ string,
|
||||||
comment string,
|
comment string,
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
r := Rule{
|
r := PeerRule{
|
||||||
id: uuid.New().String(),
|
id: uuid.New().String(),
|
||||||
ip: ip,
|
ip: ip,
|
||||||
ipLayer: layers.LayerTypeIPv6,
|
ipLayer: layers.LayerTypeIPv6,
|
||||||
matchByIP: true,
|
matchByIP: true,
|
||||||
direction: direction,
|
|
||||||
drop: action == firewall.ActionDrop,
|
drop: action == firewall.ActionDrop,
|
||||||
comment: comment,
|
comment: comment,
|
||||||
}
|
}
|
||||||
@ -179,13 +371,8 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
r.matchByIP = false
|
r.matchByIP = false
|
||||||
}
|
}
|
||||||
|
|
||||||
if sPort != nil && len(sPort.Values) == 1 {
|
r.sPort = sPort
|
||||||
r.sPort = uint16(sPort.Values[0])
|
r.dPort = dPort
|
||||||
}
|
|
||||||
|
|
||||||
if dPort != nil && len(dPort.Values) == 1 {
|
|
||||||
r.dPort = uint16(dPort.Values[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
switch proto {
|
switch proto {
|
||||||
case firewall.ProtocolTCP:
|
case firewall.ProtocolTCP:
|
||||||
@ -202,58 +389,80 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
if direction == firewall.RuleDirectionIN {
|
|
||||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||||
m.incomingRules[r.ip.String()] = make(RuleSet)
|
m.incomingRules[r.ip.String()] = make(RuleSet)
|
||||||
}
|
}
|
||||||
m.incomingRules[r.ip.String()][r.id] = r
|
m.incomingRules[r.ip.String()][r.id] = r
|
||||||
} else {
|
|
||||||
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
|
||||||
m.outgoingRules[r.ip.String()] = make(RuleSet)
|
|
||||||
}
|
|
||||||
m.outgoingRules[r.ip.String()][r.id] = r
|
|
||||||
}
|
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
return []firewall.Rule{&r}, nil
|
return []firewall.Rule{&r}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
|
func (m *Manager) AddRouteFiltering(
|
||||||
if m.nativeFirewall == nil {
|
sources []netip.Prefix,
|
||||||
return nil, errRouteNotSupported
|
destination netip.Prefix,
|
||||||
}
|
proto firewall.Protocol,
|
||||||
|
sPort *firewall.Port,
|
||||||
|
dPort *firewall.Port,
|
||||||
|
action firewall.Action,
|
||||||
|
) (firewall.Rule, error) {
|
||||||
|
if m.nativeRouter && m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
m.mutex.Lock()
|
||||||
if m.nativeFirewall == nil {
|
defer m.mutex.Unlock()
|
||||||
return errRouteNotSupported
|
|
||||||
|
ruleID := uuid.New().String()
|
||||||
|
rule := RouteRule{
|
||||||
|
id: ruleID,
|
||||||
|
sources: sources,
|
||||||
|
destination: destination,
|
||||||
|
proto: proto,
|
||||||
|
srcPort: sPort,
|
||||||
|
dstPort: dPort,
|
||||||
|
action: action,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.routeRules = append(m.routeRules, rule)
|
||||||
|
m.routeRules.Sort()
|
||||||
|
|
||||||
|
return &rule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||||
|
if m.nativeRouter && m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
ruleID := rule.GetRuleID()
|
||||||
|
idx := slices.IndexFunc(m.routeRules, func(r RouteRule) bool {
|
||||||
|
return r.id == ruleID
|
||||||
|
})
|
||||||
|
if idx < 0 {
|
||||||
|
return fmt.Errorf("route rule not found: %s", ruleID)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// DeletePeerRule from the firewall by rule definition
|
// DeletePeerRule from the firewall by rule definition
|
||||||
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
r, ok := rule.(*Rule)
|
r, ok := rule.(*PeerRule)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.direction == firewall.RuleDirectionIN {
|
if _, ok := m.incomingRules[r.ip.String()][r.id]; !ok {
|
||||||
_, ok := m.incomingRules[r.ip.String()][r.id]
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||||
}
|
}
|
||||||
delete(m.incomingRules[r.ip.String()], r.id)
|
delete(m.incomingRules[r.ip.String()], r.id)
|
||||||
} else {
|
|
||||||
_, ok := m.outgoingRules[r.ip.String()][r.id]
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
|
||||||
}
|
|
||||||
delete(m.outgoingRules[r.ip.String()], r.id)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -276,10 +485,14 @@ func (m *Manager) DropOutgoing(packetData []byte) bool {
|
|||||||
|
|
||||||
// DropIncoming filter incoming packets
|
// DropIncoming filter incoming packets
|
||||||
func (m *Manager) DropIncoming(packetData []byte) bool {
|
func (m *Manager) DropIncoming(packetData []byte) bool {
|
||||||
return m.dropFilter(packetData, m.incomingRules)
|
return m.dropFilter(packetData)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateLocalIPs updates the list of local IPs
|
||||||
|
func (m *Manager) UpdateLocalIPs() error {
|
||||||
|
return m.localipmanager.UpdateLocalIPs(m.wgIface)
|
||||||
}
|
}
|
||||||
|
|
||||||
// processOutgoingHooks processes UDP hooks for outgoing packets and tracks TCP/UDP/ICMP
|
|
||||||
func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
defer m.mutex.RUnlock()
|
defer m.mutex.RUnlock()
|
||||||
@ -300,18 +513,11 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Always process UDP hooks
|
// Track all protocols if stateful mode is enabled
|
||||||
if d.decoded[1] == layers.LayerTypeUDP {
|
|
||||||
// Track UDP state only if enabled
|
|
||||||
if m.stateful {
|
|
||||||
m.trackUDPOutbound(d, srcIP, dstIP)
|
|
||||||
}
|
|
||||||
return m.checkUDPHooks(d, dstIP, packetData)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Track other protocols only if stateful mode is enabled
|
|
||||||
if m.stateful {
|
if m.stateful {
|
||||||
switch d.decoded[1] {
|
switch d.decoded[1] {
|
||||||
|
case layers.LayerTypeUDP:
|
||||||
|
m.trackUDPOutbound(d, srcIP, dstIP)
|
||||||
case layers.LayerTypeTCP:
|
case layers.LayerTypeTCP:
|
||||||
m.trackTCPOutbound(d, srcIP, dstIP)
|
m.trackTCPOutbound(d, srcIP, dstIP)
|
||||||
case layers.LayerTypeICMPv4:
|
case layers.LayerTypeICMPv4:
|
||||||
@ -319,6 +525,11 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Process UDP hooks even if stateful mode is disabled
|
||||||
|
if d.decoded[1] == layers.LayerTypeUDP {
|
||||||
|
return m.checkUDPHooks(d, dstIP, packetData)
|
||||||
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -380,7 +591,7 @@ func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) boo
|
|||||||
for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} {
|
for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} {
|
||||||
if rules, exists := m.outgoingRules[ipKey]; exists {
|
if rules, exists := m.outgoingRules[ipKey]; exists {
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if rule.udpHook != nil && (rule.dPort == 0 || rule.dPort == uint16(d.udp.DstPort)) {
|
if rule.udpHook != nil && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
||||||
return rule.udpHook(packetData)
|
return rule.udpHook(packetData)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -400,10 +611,9 @@ func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// dropFilter implements filtering logic for incoming packets
|
// dropFilter implements filtering logic for incoming packets.
|
||||||
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
|
// If it returns true, the packet should be dropped.
|
||||||
// TODO: Disable router if --disable-server-router is set
|
func (m *Manager) dropFilter(packetData []byte) bool {
|
||||||
|
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
defer m.mutex.RUnlock()
|
defer m.mutex.RUnlock()
|
||||||
|
|
||||||
@ -416,39 +626,129 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
|
|||||||
|
|
||||||
srcIP, dstIP := m.extractIPs(d)
|
srcIP, dstIP := m.extractIPs(d)
|
||||||
if srcIP == nil {
|
if srcIP == nil {
|
||||||
log.Errorf("unknown layer: %v", d.decoded[0])
|
m.logger.Error("Unknown network layer: %v", d.decoded[0])
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if !m.isWireguardTraffic(srcIP, dstIP) {
|
// For all inbound traffic, first check if it matches a tracked connection.
|
||||||
return false
|
// This must happen before any other filtering because the packets are statefully tracked.
|
||||||
}
|
|
||||||
|
|
||||||
// Check connection state only if enabled
|
|
||||||
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) {
|
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.applyRules(srcIP, packetData, rules, d)
|
if m.localipmanager.IsLocalIP(dstIP) {
|
||||||
|
return m.handleLocalTraffic(d, srcIP, dstIP, packetData)
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleLocalTraffic handles local traffic.
|
||||||
|
// If it returns true, the packet should be dropped.
|
||||||
|
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
|
||||||
|
if m.peerACLsBlock(srcIP, packetData, m.incomingRules, d) {
|
||||||
|
m.logger.Trace("Dropping local packet (ACL denied): src=%s dst=%s",
|
||||||
|
srcIP, dstIP)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// if running in netstack mode we need to pass this to the forwarder
|
||||||
|
if m.netstack {
|
||||||
|
return m.handleNetstackLocalTraffic(packetData)
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
|
||||||
|
if !m.localForwarding {
|
||||||
|
// pass to virtual tcp/ip stack to be picked up by listeners
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.forwarder == nil {
|
||||||
|
m.logger.Trace("Dropping local packet (forwarder not initialized)")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
|
||||||
|
m.logger.Error("Failed to inject local packet: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// don't process this packet further
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleRoutedTraffic handles routed traffic.
|
||||||
|
// If it returns true, the packet should be dropped.
|
||||||
|
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
|
||||||
|
// Drop if routing is disabled
|
||||||
|
if !m.routingEnabled {
|
||||||
|
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
|
||||||
|
srcIP, dstIP)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pass to native stack if native router is enabled or forced
|
||||||
|
if m.nativeRouter {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
proto := getProtocolFromPacket(d)
|
||||||
|
srcPort, dstPort := getPortsFromPacket(d)
|
||||||
|
|
||||||
|
if !m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) {
|
||||||
|
m.logger.Trace("Dropping routed packet (ACL denied): src=%s:%d dst=%s:%d proto=%v",
|
||||||
|
srcIP, srcPort, dstIP, dstPort, proto)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Let forwarder handle the packet if it passed route ACLs
|
||||||
|
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
|
||||||
|
m.logger.Error("Failed to inject incoming packet: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forwarded packets shouldn't reach the native stack, hence they won't be visible in a packet capture
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func getProtocolFromPacket(d *decoder) firewall.Protocol {
|
||||||
|
switch d.decoded[1] {
|
||||||
|
case layers.LayerTypeTCP:
|
||||||
|
return firewall.ProtocolTCP
|
||||||
|
case layers.LayerTypeUDP:
|
||||||
|
return firewall.ProtocolUDP
|
||||||
|
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||||
|
return firewall.ProtocolICMP
|
||||||
|
default:
|
||||||
|
return firewall.ProtocolALL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
|
||||||
|
switch d.decoded[1] {
|
||||||
|
case layers.LayerTypeTCP:
|
||||||
|
return uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort)
|
||||||
|
case layers.LayerTypeUDP:
|
||||||
|
return uint16(d.udp.SrcPort), uint16(d.udp.DstPort)
|
||||||
|
default:
|
||||||
|
return 0, 0
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
|
func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
|
||||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||||
log.Tracef("couldn't decode layer, err: %s", err)
|
m.logger.Trace("couldn't decode packet, err: %s", err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(d.decoded) < 2 {
|
if len(d.decoded) < 2 {
|
||||||
log.Tracef("not enough levels in network packet")
|
m.logger.Trace("packet doesn't have network and transport layers")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) isWireguardTraffic(srcIP, dstIP net.IP) bool {
|
|
||||||
return m.wgNetwork.Contains(srcIP) && m.wgNetwork.Contains(dstIP)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool {
|
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool {
|
||||||
switch d.decoded[1] {
|
switch d.decoded[1] {
|
||||||
case layers.LayerTypeTCP:
|
case layers.LayerTypeTCP:
|
||||||
@ -483,7 +783,22 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) applyRules(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool {
|
// isSpecialICMP returns true if the packet is a special ICMP packet that should be allowed
|
||||||
|
func (m *Manager) isSpecialICMP(d *decoder) bool {
|
||||||
|
if d.decoded[1] != layers.LayerTypeICMPv4 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
icmpType := d.icmp4.TypeCode.Type()
|
||||||
|
return icmpType == layers.ICMPv4TypeDestinationUnreachable ||
|
||||||
|
icmpType == layers.ICMPv4TypeTimeExceeded
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) peerACLsBlock(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool {
|
||||||
|
if m.isSpecialICMP(d) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
if filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok {
|
if filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok {
|
||||||
return filter
|
return filter
|
||||||
}
|
}
|
||||||
@ -500,7 +815,24 @@ func (m *Manager) applyRules(srcIP net.IP, packetData []byte, rules map[string]R
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decoder) (bool, bool) {
|
func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
|
||||||
|
if rulePort == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if rulePort.IsRange {
|
||||||
|
return packetPort >= rulePort.Values[0] && packetPort <= rulePort.Values[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range rulePort.Values {
|
||||||
|
if p == packetPort {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *decoder) (bool, bool) {
|
||||||
payloadLayer := d.decoded[1]
|
payloadLayer := d.decoded[1]
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if rule.matchByIP && !ip.Equal(rule.ip) {
|
if rule.matchByIP && !ip.Equal(rule.ip) {
|
||||||
@ -517,13 +849,7 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
|
|||||||
|
|
||||||
switch payloadLayer {
|
switch payloadLayer {
|
||||||
case layers.LayerTypeTCP:
|
case layers.LayerTypeTCP:
|
||||||
if rule.sPort == 0 && rule.dPort == 0 {
|
if portsMatch(rule.sPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dPort, uint16(d.tcp.DstPort)) {
|
||||||
return rule.drop, true
|
|
||||||
}
|
|
||||||
if rule.sPort != 0 && rule.sPort == uint16(d.tcp.SrcPort) {
|
|
||||||
return rule.drop, true
|
|
||||||
}
|
|
||||||
if rule.dPort != 0 && rule.dPort == uint16(d.tcp.DstPort) {
|
|
||||||
return rule.drop, true
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
case layers.LayerTypeUDP:
|
case layers.LayerTypeUDP:
|
||||||
@ -533,13 +859,7 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
|
|||||||
return rule.udpHook(packetData), true
|
return rule.udpHook(packetData), true
|
||||||
}
|
}
|
||||||
|
|
||||||
if rule.sPort == 0 && rule.dPort == 0 {
|
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
||||||
return rule.drop, true
|
|
||||||
}
|
|
||||||
if rule.sPort != 0 && rule.sPort == uint16(d.udp.SrcPort) {
|
|
||||||
return rule.drop, true
|
|
||||||
}
|
|
||||||
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
|
|
||||||
return rule.drop, true
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||||
@ -549,6 +869,51 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
|
|||||||
return false, false
|
return false, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// routeACLsPass returns treu if the packet is allowed by the route ACLs
|
||||||
|
func (m *Manager) routeACLsPass(srcIP, dstIP net.IP, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
||||||
|
m.mutex.RLock()
|
||||||
|
defer m.mutex.RUnlock()
|
||||||
|
|
||||||
|
srcAddr := netip.AddrFrom4([4]byte(srcIP.To4()))
|
||||||
|
dstAddr := netip.AddrFrom4([4]byte(dstIP.To4()))
|
||||||
|
|
||||||
|
for _, rule := range m.routeRules {
|
||||||
|
if m.ruleMatches(rule, srcAddr, dstAddr, proto, srcPort, dstPort) {
|
||||||
|
return rule.action == firewall.ActionAccept
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
||||||
|
if !rule.destination.Contains(dstAddr) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
sourceMatched := false
|
||||||
|
for _, src := range rule.sources {
|
||||||
|
if src.Contains(srcAddr) {
|
||||||
|
sourceMatched = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !sourceMatched {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if rule.proto != firewall.ProtocolALL && rule.proto != proto {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if proto == firewall.ProtocolTCP || proto == firewall.ProtocolUDP {
|
||||||
|
if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// SetNetwork of the wireguard interface to which filtering applied
|
// SetNetwork of the wireguard interface to which filtering applied
|
||||||
func (m *Manager) SetNetwork(network *net.IPNet) {
|
func (m *Manager) SetNetwork(network *net.IPNet) {
|
||||||
m.wgNetwork = network
|
m.wgNetwork = network
|
||||||
@ -560,13 +925,12 @@ func (m *Manager) SetNetwork(network *net.IPNet) {
|
|||||||
func (m *Manager) AddUDPPacketHook(
|
func (m *Manager) AddUDPPacketHook(
|
||||||
in bool, ip net.IP, dPort uint16, hook func([]byte) bool,
|
in bool, ip net.IP, dPort uint16, hook func([]byte) bool,
|
||||||
) string {
|
) string {
|
||||||
r := Rule{
|
r := PeerRule{
|
||||||
id: uuid.New().String(),
|
id: uuid.New().String(),
|
||||||
ip: ip,
|
ip: ip,
|
||||||
protoLayer: layers.LayerTypeUDP,
|
protoLayer: layers.LayerTypeUDP,
|
||||||
dPort: dPort,
|
dPort: &firewall.Port{Values: []uint16{dPort}},
|
||||||
ipLayer: layers.LayerTypeIPv6,
|
ipLayer: layers.LayerTypeIPv6,
|
||||||
direction: firewall.RuleDirectionOUT,
|
|
||||||
comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort),
|
comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort),
|
||||||
udpHook: hook,
|
udpHook: hook,
|
||||||
}
|
}
|
||||||
@ -577,14 +941,13 @@ func (m *Manager) AddUDPPacketHook(
|
|||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
if in {
|
if in {
|
||||||
r.direction = firewall.RuleDirectionIN
|
|
||||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||||
m.incomingRules[r.ip.String()] = make(map[string]Rule)
|
m.incomingRules[r.ip.String()] = make(map[string]PeerRule)
|
||||||
}
|
}
|
||||||
m.incomingRules[r.ip.String()][r.id] = r
|
m.incomingRules[r.ip.String()][r.id] = r
|
||||||
} else {
|
} else {
|
||||||
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
||||||
m.outgoingRules[r.ip.String()] = make(map[string]Rule)
|
m.outgoingRules[r.ip.String()] = make(map[string]PeerRule)
|
||||||
}
|
}
|
||||||
m.outgoingRules[r.ip.String()][r.id] = r
|
m.outgoingRules[r.ip.String()][r.id] = r
|
||||||
}
|
}
|
||||||
@ -596,21 +959,62 @@ func (m *Manager) AddUDPPacketHook(
|
|||||||
|
|
||||||
// RemovePacketHook removes packet hook by given ID
|
// RemovePacketHook removes packet hook by given ID
|
||||||
func (m *Manager) RemovePacketHook(hookID string) error {
|
func (m *Manager) RemovePacketHook(hookID string) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
for _, arr := range m.incomingRules {
|
for _, arr := range m.incomingRules {
|
||||||
for _, r := range arr {
|
for _, r := range arr {
|
||||||
if r.id == hookID {
|
if r.id == hookID {
|
||||||
rule := r
|
delete(arr, r.id)
|
||||||
return m.DeletePeerRule(&rule)
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, arr := range m.outgoingRules {
|
for _, arr := range m.outgoingRules {
|
||||||
for _, r := range arr {
|
for _, r := range arr {
|
||||||
if r.id == hookID {
|
if r.id == hookID {
|
||||||
rule := r
|
delete(arr, r.id)
|
||||||
return m.DeletePeerRule(&rule)
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return fmt.Errorf("hook with given id not found")
|
return fmt.Errorf("hook with given id not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetLogLevel sets the log level for the firewall manager
|
||||||
|
func (m *Manager) SetLogLevel(level log.Level) {
|
||||||
|
if m.logger != nil {
|
||||||
|
m.logger.SetLevel(nblog.Level(level))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) EnableRouting() error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.determineRouting()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) DisableRouting() error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
if m.forwarder == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m.routingEnabled = false
|
||||||
|
m.nativeRouter = false
|
||||||
|
|
||||||
|
// don't stop forwarder if in use by netstack
|
||||||
|
if m.netstack && m.localForwarding {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m.forwarder.Stop()
|
||||||
|
m.forwarder = nil
|
||||||
|
|
||||||
|
log.Debug("forwarder stopped")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -1,9 +1,12 @@
|
|||||||
|
//go:build uspbench
|
||||||
|
|
||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
@ -91,7 +94,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
// Single rule allowing all traffic
|
// Single rule allowing all traffic
|
||||||
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil,
|
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil,
|
||||||
fw.RuleDirectionIN, fw.ActionAccept, "", "allow all")
|
fw.ActionAccept, "", "allow all")
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
},
|
},
|
||||||
desc: "Baseline: Single 'allow all' rule without connection tracking",
|
desc: "Baseline: Single 'allow all' rule without connection tracking",
|
||||||
@ -112,9 +115,9 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
|
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
|
||||||
ip := generateRandomIPs(1)[0]
|
ip := generateRandomIPs(1)[0]
|
||||||
_, err := m.AddPeerFiltering(ip, fw.ProtocolTCP,
|
_, err := m.AddPeerFiltering(ip, fw.ProtocolTCP,
|
||||||
&fw.Port{Values: []int{1024 + i}},
|
&fw.Port{Values: []uint16{uint16(1024 + i)}},
|
||||||
&fw.Port{Values: []int{80}},
|
&fw.Port{Values: []uint16{80}},
|
||||||
fw.RuleDirectionIN, fw.ActionAccept, "", "explicit return")
|
fw.ActionAccept, "", "explicit return")
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@ -126,7 +129,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
// Add some basic rules but rely on state for established connections
|
// Add some basic rules but rely on state for established connections
|
||||||
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil,
|
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil,
|
||||||
fw.RuleDirectionIN, fw.ActionDrop, "", "default drop")
|
fw.ActionDrop, "", "default drop")
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
},
|
},
|
||||||
desc: "Connection tracking with established connections",
|
desc: "Connection tracking with established connections",
|
||||||
@ -155,7 +158,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
// Create manager and basic setup
|
// Create manager and basic setup
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
})
|
}, false)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Reset(nil))
|
||||||
})
|
})
|
||||||
@ -185,7 +188,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
// Measure inbound packet processing
|
// Measure inbound packet processing
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(inbound, manager.incomingRules)
|
manager.dropFilter(inbound)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -200,7 +203,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
|||||||
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
})
|
}, false)
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Reset(nil))
|
||||||
})
|
})
|
||||||
@ -228,7 +231,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
|||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(testIn, manager.incomingRules)
|
manager.dropFilter(testIn)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -248,7 +251,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
|||||||
b.Run(sc.name, func(b *testing.B) {
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
})
|
}, false)
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Reset(nil))
|
||||||
})
|
})
|
||||||
@ -269,7 +272,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
|||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(inbound, manager.incomingRules)
|
manager.dropFilter(inbound)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -447,7 +450,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
b.Run(sc.name, func(b *testing.B) {
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
})
|
}, false)
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Reset(nil))
|
||||||
})
|
})
|
||||||
@ -472,7 +475,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
manager.processOutgoingHooks(syn)
|
manager.processOutgoingHooks(syn)
|
||||||
// SYN-ACK
|
// SYN-ACK
|
||||||
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
manager.dropFilter(synack, manager.incomingRules)
|
manager.dropFilter(synack)
|
||||||
// ACK
|
// ACK
|
||||||
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
||||||
manager.processOutgoingHooks(ack)
|
manager.processOutgoingHooks(ack)
|
||||||
@ -481,7 +484,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(inbound, manager.incomingRules)
|
manager.dropFilter(inbound)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -574,7 +577,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
})
|
}, false)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Reset(nil))
|
||||||
})
|
})
|
||||||
@ -588,9 +591,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
if sc.rules {
|
if sc.rules {
|
||||||
// Single rule to allow all return traffic from port 80
|
// Single rule to allow all return traffic from port 80
|
||||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||||
&fw.Port{Values: []int{80}},
|
&fw.Port{Values: []uint16{80}},
|
||||||
nil,
|
nil,
|
||||||
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic")
|
fw.ActionAccept, "", "return traffic")
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -618,7 +621,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
// SYN-ACK
|
// SYN-ACK
|
||||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
manager.dropFilter(synack, manager.incomingRules)
|
manager.dropFilter(synack)
|
||||||
|
|
||||||
// ACK
|
// ACK
|
||||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
@ -646,7 +649,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
// First outbound data
|
// First outbound data
|
||||||
manager.processOutgoingHooks(outPackets[connIdx])
|
manager.processOutgoingHooks(outPackets[connIdx])
|
||||||
// Then inbound response - this is what we're actually measuring
|
// Then inbound response - this is what we're actually measuring
|
||||||
manager.dropFilter(inPackets[connIdx], manager.incomingRules)
|
manager.dropFilter(inPackets[connIdx])
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -665,7 +668,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
})
|
}, false)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Reset(nil))
|
||||||
})
|
})
|
||||||
@ -679,9 +682,9 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
if sc.rules {
|
if sc.rules {
|
||||||
// Single rule to allow all return traffic from port 80
|
// Single rule to allow all return traffic from port 80
|
||||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||||
&fw.Port{Values: []int{80}},
|
&fw.Port{Values: []uint16{80}},
|
||||||
nil,
|
nil,
|
||||||
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic")
|
fw.ActionAccept, "", "return traffic")
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -754,17 +757,17 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
// Connection establishment
|
// Connection establishment
|
||||||
manager.processOutgoingHooks(p.syn)
|
manager.processOutgoingHooks(p.syn)
|
||||||
manager.dropFilter(p.synAck, manager.incomingRules)
|
manager.dropFilter(p.synAck)
|
||||||
manager.processOutgoingHooks(p.ack)
|
manager.processOutgoingHooks(p.ack)
|
||||||
|
|
||||||
// Data transfer
|
// Data transfer
|
||||||
manager.processOutgoingHooks(p.request)
|
manager.processOutgoingHooks(p.request)
|
||||||
manager.dropFilter(p.response, manager.incomingRules)
|
manager.dropFilter(p.response)
|
||||||
|
|
||||||
// Connection teardown
|
// Connection teardown
|
||||||
manager.processOutgoingHooks(p.finClient)
|
manager.processOutgoingHooks(p.finClient)
|
||||||
manager.dropFilter(p.ackServer, manager.incomingRules)
|
manager.dropFilter(p.ackServer)
|
||||||
manager.dropFilter(p.finServer, manager.incomingRules)
|
manager.dropFilter(p.finServer)
|
||||||
manager.processOutgoingHooks(p.ackClient)
|
manager.processOutgoingHooks(p.ackClient)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@ -784,7 +787,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
})
|
}, false)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Reset(nil))
|
||||||
})
|
})
|
||||||
@ -797,9 +800,9 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
// Setup initial state based on scenario
|
// Setup initial state based on scenario
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||||
&fw.Port{Values: []int{80}},
|
&fw.Port{Values: []uint16{80}},
|
||||||
nil,
|
nil,
|
||||||
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic")
|
fw.ActionAccept, "", "return traffic")
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -825,7 +828,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
manager.dropFilter(synack, manager.incomingRules)
|
manager.dropFilter(synack)
|
||||||
|
|
||||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||||
@ -852,7 +855,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
// Simulate bidirectional traffic
|
// Simulate bidirectional traffic
|
||||||
manager.processOutgoingHooks(outPackets[connIdx])
|
manager.processOutgoingHooks(outPackets[connIdx])
|
||||||
manager.dropFilter(inPackets[connIdx], manager.incomingRules)
|
manager.dropFilter(inPackets[connIdx])
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
@ -872,7 +875,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
})
|
}, false)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Reset(nil))
|
||||||
})
|
})
|
||||||
@ -884,9 +887,9 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||||
&fw.Port{Values: []int{80}},
|
&fw.Port{Values: []uint16{80}},
|
||||||
nil,
|
nil,
|
||||||
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic")
|
fw.ActionAccept, "", "return traffic")
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -949,15 +952,15 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
// Full connection lifecycle
|
// Full connection lifecycle
|
||||||
manager.processOutgoingHooks(p.syn)
|
manager.processOutgoingHooks(p.syn)
|
||||||
manager.dropFilter(p.synAck, manager.incomingRules)
|
manager.dropFilter(p.synAck)
|
||||||
manager.processOutgoingHooks(p.ack)
|
manager.processOutgoingHooks(p.ack)
|
||||||
|
|
||||||
manager.processOutgoingHooks(p.request)
|
manager.processOutgoingHooks(p.request)
|
||||||
manager.dropFilter(p.response, manager.incomingRules)
|
manager.dropFilter(p.response)
|
||||||
|
|
||||||
manager.processOutgoingHooks(p.finClient)
|
manager.processOutgoingHooks(p.finClient)
|
||||||
manager.dropFilter(p.ackServer, manager.incomingRules)
|
manager.dropFilter(p.ackServer)
|
||||||
manager.dropFilter(p.finServer, manager.incomingRules)
|
manager.dropFilter(p.finServer)
|
||||||
manager.processOutgoingHooks(p.ackClient)
|
manager.processOutgoingHooks(p.ackClient)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@ -996,3 +999,72 @@ func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstP
|
|||||||
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test")))
|
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test")))
|
||||||
return buf.Bytes()
|
return buf.Bytes()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func BenchmarkRouteACLs(b *testing.B) {
|
||||||
|
manager := setupRoutedManager(b, "10.10.0.100/16")
|
||||||
|
|
||||||
|
// Add several route rules to simulate real-world scenario
|
||||||
|
rules := []struct {
|
||||||
|
sources []netip.Prefix
|
||||||
|
dest netip.Prefix
|
||||||
|
proto fw.Protocol
|
||||||
|
port *fw.Port
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
|
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
port: &fw.Port{Values: []uint16{80, 443}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sources: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("172.16.0.0/12"),
|
||||||
|
netip.MustParsePrefix("10.0.0.0/8"),
|
||||||
|
},
|
||||||
|
dest: netip.MustParsePrefix("0.0.0.0/0"),
|
||||||
|
proto: fw.ProtocolICMP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
|
dest: netip.MustParsePrefix("192.168.0.0/16"),
|
||||||
|
proto: fw.ProtocolUDP,
|
||||||
|
port: &fw.Port{Values: []uint16{53}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, r := range rules {
|
||||||
|
_, err := manager.AddRouteFiltering(
|
||||||
|
r.sources,
|
||||||
|
r.dest,
|
||||||
|
r.proto,
|
||||||
|
nil,
|
||||||
|
r.port,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test cases that exercise different matching scenarios
|
||||||
|
cases := []struct {
|
||||||
|
srcIP string
|
||||||
|
dstIP string
|
||||||
|
proto fw.Protocol
|
||||||
|
dstPort uint16
|
||||||
|
}{
|
||||||
|
{"100.10.0.1", "192.168.1.100", fw.ProtocolTCP, 443}, // Match first rule
|
||||||
|
{"172.16.0.1", "8.8.8.8", fw.ProtocolICMP, 0}, // Match second rule
|
||||||
|
{"1.1.1.1", "192.168.1.53", fw.ProtocolUDP, 53}, // Match third rule
|
||||||
|
{"192.168.1.1", "10.0.0.1", fw.ProtocolTCP, 8080}, // No match
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
for _, tc := range cases {
|
||||||
|
srcIP := net.ParseIP(tc.srcIP)
|
||||||
|
dstIP := net.ParseIP(tc.dstIP)
|
||||||
|
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
1015
client/firewall/uspfilter/uspfilter_filter_test.go
Normal file
1015
client/firewall/uspfilter/uspfilter_filter_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@ -9,17 +9,38 @@ import (
|
|||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||||
|
|
||||||
type IFaceMock struct {
|
type IFaceMock struct {
|
||||||
SetFilterFunc func(device.PacketFilter) error
|
SetFilterFunc func(device.PacketFilter) error
|
||||||
AddressFunc func() iface.WGAddress
|
AddressFunc func() iface.WGAddress
|
||||||
|
GetWGDeviceFunc func() *wgdevice.Device
|
||||||
|
GetDeviceFunc func() *device.FilteredDevice
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *IFaceMock) GetWGDevice() *wgdevice.Device {
|
||||||
|
if i.GetWGDeviceFunc == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return i.GetWGDeviceFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *IFaceMock) GetDevice() *device.FilteredDevice {
|
||||||
|
if i.GetDeviceFunc == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return i.GetDeviceFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
|
func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
|
||||||
@ -41,7 +62,7 @@ func TestManagerCreate(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock)
|
m, err := Create(ifaceMock, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@ -61,7 +82,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock)
|
m, err := Create(ifaceMock, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@ -69,12 +90,11 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
|||||||
|
|
||||||
ip := net.ParseIP("192.168.1.1")
|
ip := net.ParseIP("192.168.1.1")
|
||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []int{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
direction := fw.RuleDirectionOUT
|
|
||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule"
|
comment := "Test rule"
|
||||||
|
|
||||||
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
rule, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@ -96,7 +116,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock)
|
m, err := Create(ifaceMock, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@ -104,38 +124,16 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
|
|
||||||
ip := net.ParseIP("192.168.1.1")
|
ip := net.ParseIP("192.168.1.1")
|
||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []int{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
direction := fw.RuleDirectionOUT
|
|
||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule"
|
comment := "Test rule 2"
|
||||||
|
|
||||||
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ip = net.ParseIP("192.168.1.1")
|
|
||||||
proto = fw.ProtocolTCP
|
|
||||||
port = &fw.Port{Values: []int{80}}
|
|
||||||
direction = fw.RuleDirectionIN
|
|
||||||
action = fw.ActionDrop
|
|
||||||
comment = "Test rule 2"
|
|
||||||
|
|
||||||
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, r := range rule {
|
|
||||||
err = m.DeletePeerRule(r)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to delete rule: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok {
|
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok {
|
||||||
t.Errorf("rule2 is not in the incomingRules")
|
t.Errorf("rule2 is not in the incomingRules")
|
||||||
@ -189,12 +187,12 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
})
|
}, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||||
|
|
||||||
var addedRule Rule
|
var addedRule PeerRule
|
||||||
if tt.in {
|
if tt.in {
|
||||||
if len(manager.incomingRules[tt.ip.String()]) != 1 {
|
if len(manager.incomingRules[tt.ip.String()]) != 1 {
|
||||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
||||||
@ -217,18 +215,14 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
|
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if tt.dPort != addedRule.dPort {
|
if tt.dPort != addedRule.dPort.Values[0] {
|
||||||
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort)
|
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort.Values[0])
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if layers.LayerTypeUDP != addedRule.protoLayer {
|
if layers.LayerTypeUDP != addedRule.protoLayer {
|
||||||
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
|
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if tt.expDir != addedRule.direction {
|
|
||||||
t.Errorf("expected direction %d, got %d", tt.expDir, addedRule.direction)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if addedRule.udpHook == nil {
|
if addedRule.udpHook == nil {
|
||||||
t.Errorf("expected udpHook to be set")
|
t.Errorf("expected udpHook to be set")
|
||||||
return
|
return
|
||||||
@ -242,7 +236,7 @@ func TestManagerReset(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock)
|
m, err := Create(ifaceMock, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@ -250,12 +244,11 @@ func TestManagerReset(t *testing.T) {
|
|||||||
|
|
||||||
ip := net.ParseIP("192.168.1.1")
|
ip := net.ParseIP("192.168.1.1")
|
||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []int{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
direction := fw.RuleDirectionOUT
|
|
||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule"
|
comment := "Test rule"
|
||||||
|
|
||||||
_, err = m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
_, err = m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@ -275,9 +268,18 @@ func TestManagerReset(t *testing.T) {
|
|||||||
func TestNotMatchByIP(t *testing.T) {
|
func TestNotMatchByIP(t *testing.T) {
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
AddressFunc: func() iface.WGAddress {
|
||||||
|
return iface.WGAddress{
|
||||||
|
IP: net.ParseIP("100.10.0.100"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.10.0.0"),
|
||||||
|
Mask: net.CIDRMask(16, 32),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock)
|
m, err := Create(ifaceMock, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@ -289,11 +291,10 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
|
|
||||||
ip := net.ParseIP("0.0.0.0")
|
ip := net.ParseIP("0.0.0.0")
|
||||||
proto := fw.ProtocolUDP
|
proto := fw.ProtocolUDP
|
||||||
direction := fw.RuleDirectionOUT
|
|
||||||
action := fw.ActionAccept
|
action := fw.ActionAccept
|
||||||
comment := "Test rule"
|
comment := "Test rule"
|
||||||
|
|
||||||
_, err = m.AddPeerFiltering(ip, proto, nil, nil, direction, action, "", comment)
|
_, err = m.AddPeerFiltering(ip, proto, nil, nil, action, "", comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@ -327,7 +328,7 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.dropFilter(buf.Bytes(), m.outgoingRules) {
|
if m.dropFilter(buf.Bytes()) {
|
||||||
t.Errorf("expected packet to be accepted")
|
t.Errorf("expected packet to be accepted")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -346,7 +347,7 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// creating manager instance
|
// creating manager instance
|
||||||
manager, err := Create(iface)
|
manager, err := Create(iface, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create Manager: %s", err)
|
t.Fatalf("Failed to create Manager: %s", err)
|
||||||
}
|
}
|
||||||
@ -392,7 +393,7 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
func TestProcessOutgoingHooks(t *testing.T) {
|
func TestProcessOutgoingHooks(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
})
|
}, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
manager.wgNetwork = &net.IPNet{
|
||||||
@ -400,7 +401,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
Mask: net.CIDRMask(16, 32),
|
Mask: net.CIDRMask(16, 32),
|
||||||
}
|
}
|
||||||
manager.udpTracker.Close()
|
manager.udpTracker.Close()
|
||||||
manager.udpTracker = conntrack.NewUDPTracker(100 * time.Millisecond)
|
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, manager.Reset(nil))
|
require.NoError(t, manager.Reset(nil))
|
||||||
}()
|
}()
|
||||||
@ -478,7 +479,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
manager, err := Create(ifaceMock)
|
manager, err := Create(ifaceMock, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
@ -492,12 +493,8 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
ip := net.ParseIP("10.20.0.100")
|
ip := net.ParseIP("10.20.0.100")
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []int{1000 + i}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
if i%2 == 0 {
|
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
|
||||||
} else {
|
|
||||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
|
||||||
}
|
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
}
|
}
|
||||||
@ -509,7 +506,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
})
|
}, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
manager.wgNetwork = &net.IPNet{
|
||||||
@ -518,7 +515,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
manager.udpTracker.Close() // Close the existing tracker
|
manager.udpTracker.Close() // Close the existing tracker
|
||||||
manager.udpTracker = conntrack.NewUDPTracker(200 * time.Millisecond)
|
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger)
|
||||||
manager.decoders = sync.Pool{
|
manager.decoders = sync.Pool{
|
||||||
New: func() any {
|
New: func() any {
|
||||||
d := &decoder{
|
d := &decoder{
|
||||||
@ -639,7 +636,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
for _, cp := range checkPoints {
|
for _, cp := range checkPoints {
|
||||||
time.Sleep(cp.sleep)
|
time.Sleep(cp.sleep)
|
||||||
|
|
||||||
drop = manager.dropFilter(inboundBuf.Bytes(), manager.incomingRules)
|
drop = manager.dropFilter(inboundBuf.Bytes())
|
||||||
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
||||||
|
|
||||||
// If the connection should still be valid, verify it exists
|
// If the connection should still be valid, verify it exists
|
||||||
@ -710,7 +707,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify the invalid packet is dropped
|
// Verify the invalid packet is dropped
|
||||||
drop = manager.dropFilter(testBuf.Bytes(), manager.incomingRules)
|
drop = manager.dropFilter(testBuf.Bytes())
|
||||||
require.True(t, drop, tc.description)
|
require.True(t, drop, tc.description)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@ -152,46 +153,7 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
|||||||
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
|
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
|
||||||
}
|
}
|
||||||
|
|
||||||
var localAddrsForUnspecified []net.Addr
|
mux := &UDPMuxDefault{
|
||||||
if addr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
|
|
||||||
params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", params.UDPConn.LocalAddr())
|
|
||||||
} else if ok && addr.IP.IsUnspecified() {
|
|
||||||
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
|
|
||||||
// it will break the applications that are already using unspecified UDP connection
|
|
||||||
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
|
|
||||||
params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
|
|
||||||
var networks []ice.NetworkType
|
|
||||||
switch {
|
|
||||||
|
|
||||||
case addr.IP.To16() != nil:
|
|
||||||
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
|
|
||||||
|
|
||||||
case addr.IP.To4() != nil:
|
|
||||||
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
|
|
||||||
|
|
||||||
default:
|
|
||||||
params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr())
|
|
||||||
}
|
|
||||||
if len(networks) > 0 {
|
|
||||||
if params.Net == nil {
|
|
||||||
var err error
|
|
||||||
if params.Net, err = stdnet.NewNet(); err != nil {
|
|
||||||
params.Logger.Errorf("failed to get create network: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ips, err := localInterfaces(params.Net, params.InterfaceFilter, nil, networks, true)
|
|
||||||
if err == nil {
|
|
||||||
for _, ip := range ips {
|
|
||||||
localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
params.Logger.Errorf("failed to get local interfaces for unspecified addr: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &UDPMuxDefault{
|
|
||||||
addressMap: map[string][]*udpMuxedConn{},
|
addressMap: map[string][]*udpMuxedConn{},
|
||||||
params: params,
|
params: params,
|
||||||
connsIPv4: make(map[string]*udpMuxedConn),
|
connsIPv4: make(map[string]*udpMuxedConn),
|
||||||
@ -203,8 +165,55 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
|||||||
return newBufferHolder(receiveMTU + maxAddrSize)
|
return newBufferHolder(receiveMTU + maxAddrSize)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
localAddrsForUnspecified: localAddrsForUnspecified,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mux.updateLocalAddresses()
|
||||||
|
return mux
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *UDPMuxDefault) updateLocalAddresses() {
|
||||||
|
var localAddrsForUnspecified []net.Addr
|
||||||
|
if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
|
||||||
|
m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr())
|
||||||
|
} else if ok && addr.IP.IsUnspecified() {
|
||||||
|
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
|
||||||
|
// it will break the applications that are already using unspecified UDP connection
|
||||||
|
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
|
||||||
|
m.params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
|
||||||
|
var networks []ice.NetworkType
|
||||||
|
switch {
|
||||||
|
|
||||||
|
case addr.IP.To16() != nil:
|
||||||
|
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
|
||||||
|
|
||||||
|
case addr.IP.To4() != nil:
|
||||||
|
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
|
||||||
|
|
||||||
|
default:
|
||||||
|
m.params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", m.params.UDPConn.LocalAddr())
|
||||||
|
}
|
||||||
|
if len(networks) > 0 {
|
||||||
|
if m.params.Net == nil {
|
||||||
|
var err error
|
||||||
|
if m.params.Net, err = stdnet.NewNet(); err != nil {
|
||||||
|
m.params.Logger.Errorf("failed to get create network: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ips, err := localInterfaces(m.params.Net, m.params.InterfaceFilter, nil, networks, true)
|
||||||
|
if err == nil {
|
||||||
|
for _, ip := range ips {
|
||||||
|
localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
m.params.Logger.Errorf("failed to get local interfaces for unspecified addr: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
m.localAddrsForUnspecified = localAddrsForUnspecified
|
||||||
|
m.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// LocalAddr returns the listening address of this UDPMuxDefault
|
// LocalAddr returns the listening address of this UDPMuxDefault
|
||||||
@ -214,8 +223,12 @@ func (m *UDPMuxDefault) LocalAddr() net.Addr {
|
|||||||
|
|
||||||
// GetListenAddresses returns the list of addresses that this mux is listening on
|
// GetListenAddresses returns the list of addresses that this mux is listening on
|
||||||
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
|
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
|
||||||
|
m.updateLocalAddresses()
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
if len(m.localAddrsForUnspecified) > 0 {
|
if len(m.localAddrsForUnspecified) > 0 {
|
||||||
return m.localAddrsForUnspecified
|
return slices.Clone(m.localAddrsForUnspecified)
|
||||||
}
|
}
|
||||||
|
|
||||||
return []net.Addr{m.LocalAddr()}
|
return []net.Addr{m.LocalAddr()}
|
||||||
@ -225,7 +238,10 @@ func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
|
|||||||
// creates the connection if an existing one can't be found
|
// creates the connection if an existing one can't be found
|
||||||
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
|
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
|
||||||
// don't check addr for mux using unspecified address
|
// don't check addr for mux using unspecified address
|
||||||
if len(m.localAddrsForUnspecified) == 0 && m.params.UDPConn.LocalAddr().String() != addr.String() {
|
m.mu.Lock()
|
||||||
|
lenLocalAddrs := len(m.localAddrsForUnspecified)
|
||||||
|
m.mu.Unlock()
|
||||||
|
if lenLocalAddrs == 0 && m.params.UDPConn.LocalAddr().String() != addr.String() {
|
||||||
return nil, fmt.Errorf("invalid address %s", addr.String())
|
return nil, fmt.Errorf("invalid address %s", addr.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,5 +2,5 @@
|
|||||||
|
|
||||||
package configurer
|
package configurer
|
||||||
|
|
||||||
// WgInterfaceDefault is a default interface name of Wiretrustee
|
// WgInterfaceDefault is a default interface name of Netbird
|
||||||
const WgInterfaceDefault = "wt0"
|
const WgInterfaceDefault = "wt0"
|
||||||
|
@ -2,5 +2,5 @@
|
|||||||
|
|
||||||
package configurer
|
package configurer
|
||||||
|
|
||||||
// WgInterfaceDefault is a default interface name of Wiretrustee
|
// WgInterfaceDefault is a default interface name of Netbird
|
||||||
const WgInterfaceDefault = "utun100"
|
const WgInterfaceDefault = "utun100"
|
||||||
|
@ -362,7 +362,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getFwmark() int {
|
func getFwmark() int {
|
||||||
if runtime.GOOS == "linux" && !nbnet.CustomRoutingDisabled() {
|
if nbnet.AdvancedRouting() {
|
||||||
return nbnet.NetbirdFwmark
|
return nbnet.NetbirdFwmark
|
||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
|
@ -3,6 +3,10 @@
|
|||||||
package iface
|
package iface
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
)
|
)
|
||||||
@ -15,4 +19,6 @@ type WGTunDevice interface {
|
|||||||
DeviceName() string
|
DeviceName() string
|
||||||
Close() error
|
Close() error
|
||||||
FilteredDevice() *device.FilteredDevice
|
FilteredDevice() *device.FilteredDevice
|
||||||
|
Device() *wgdevice.Device
|
||||||
|
GetNet() *netstack.Net
|
||||||
}
|
}
|
||||||
|
@ -9,6 +9,7 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
@ -63,7 +64,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
|
|||||||
t.filteredDevice = newDeviceFilter(tunDevice)
|
t.filteredDevice = newDeviceFilter(tunDevice)
|
||||||
|
|
||||||
log.Debugf("attaching to interface %v", name)
|
log.Debugf("attaching to interface %v", name)
|
||||||
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
|
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "))
|
||||||
// without this property mobile devices can discover remote endpoints if the configured one was wrong.
|
// without this property mobile devices can discover remote endpoints if the configured one was wrong.
|
||||||
// this helps with support for the older NetBird clients that had a hardcoded direct mode
|
// this helps with support for the older NetBird clients that had a hardcoded direct mode
|
||||||
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
||||||
@ -130,6 +131,10 @@ func (t *WGTunDevice) FilteredDevice() *FilteredDevice {
|
|||||||
return t.filteredDevice
|
return t.filteredDevice
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *WGTunDevice) GetNet() *netstack.Net {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func routesToString(routes []string) string {
|
func routesToString(routes []string) string {
|
||||||
return strings.Join(routes, ";")
|
return strings.Join(routes, ";")
|
||||||
}
|
}
|
||||||
|
@ -9,6 +9,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
@ -117,6 +118,11 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
|
|||||||
return t.filteredDevice
|
return t.filteredDevice
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Device returns the wireguard device
|
||||||
|
func (t *TunDevice) Device() *device.Device {
|
||||||
|
return t.device
|
||||||
|
}
|
||||||
|
|
||||||
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
|
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
|
||||||
func (t *TunDevice) assignAddr() error {
|
func (t *TunDevice) assignAddr() error {
|
||||||
cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String())
|
cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String())
|
||||||
@ -138,3 +144,7 @@ func (t *TunDevice) assignAddr() error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *TunDevice) GetNet() *netstack.Net {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -10,6 +10,7 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
@ -64,7 +65,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
|||||||
|
|
||||||
t.filteredDevice = newDeviceFilter(tunDevice)
|
t.filteredDevice = newDeviceFilter(tunDevice)
|
||||||
log.Debug("Attaching to interface")
|
log.Debug("Attaching to interface")
|
||||||
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
|
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "))
|
||||||
// without this property mobile devices can discover remote endpoints if the configured one was wrong.
|
// without this property mobile devices can discover remote endpoints if the configured one was wrong.
|
||||||
// this helps with support for the older NetBird clients that had a hardcoded direct mode
|
// this helps with support for the older NetBird clients that had a hardcoded direct mode
|
||||||
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
||||||
@ -131,3 +132,7 @@ func (t *TunDevice) UpdateAddr(addr WGAddress) error {
|
|||||||
func (t *TunDevice) FilteredDevice() *FilteredDevice {
|
func (t *TunDevice) FilteredDevice() *FilteredDevice {
|
||||||
return t.filteredDevice
|
return t.filteredDevice
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *TunDevice) GetNet() *netstack.Net {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -9,6 +9,8 @@ import (
|
|||||||
|
|
||||||
"github.com/pion/transport/v3"
|
"github.com/pion/transport/v3"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/device"
|
||||||
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
@ -33,8 +35,6 @@ type TunKernelDevice struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
|
func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
|
||||||
checkUser()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
return &TunKernelDevice{
|
return &TunKernelDevice{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
@ -153,6 +153,11 @@ func (t *TunKernelDevice) DeviceName() string {
|
|||||||
return t.name
|
return t.name
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Device returns the wireguard device, not applicable for kernel devices
|
||||||
|
func (t *TunKernelDevice) Device() *device.Device {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
|
func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -161,3 +166,7 @@ func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
|
|||||||
func (t *TunKernelDevice) assignAddr() error {
|
func (t *TunKernelDevice) assignAddr() error {
|
||||||
return t.link.assignAddr(t.address)
|
return t.link.assignAddr(t.address)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *TunKernelDevice) GetNet() *netstack.Net {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -8,10 +8,12 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TunNetstackDevice struct {
|
type TunNetstackDevice struct {
|
||||||
@ -25,9 +27,11 @@ type TunNetstackDevice struct {
|
|||||||
|
|
||||||
device *device.Device
|
device *device.Device
|
||||||
filteredDevice *FilteredDevice
|
filteredDevice *FilteredDevice
|
||||||
nsTun *netstack.NetStackTun
|
nsTun *nbnetstack.NetStackTun
|
||||||
udpMux *bind.UniversalUDPMuxDefault
|
udpMux *bind.UniversalUDPMuxDefault
|
||||||
configurer WGConfigurer
|
configurer WGConfigurer
|
||||||
|
|
||||||
|
net *netstack.Net
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
|
func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
|
||||||
@ -43,13 +47,19 @@ func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, m
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
|
func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
|
||||||
log.Info("create netstack tun interface")
|
log.Info("create nbnetstack tun interface")
|
||||||
t.nsTun = netstack.NewNetStackTun(t.listenAddress, t.address.IP.String(), t.mtu)
|
|
||||||
tunIface, err := t.nsTun.Create()
|
// TODO: get from service listener runtime IP
|
||||||
|
dnsAddr := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
|
||||||
|
log.Debugf("netstack using address: %s", t.address.IP)
|
||||||
|
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu)
|
||||||
|
log.Debugf("netstack using dns address: %s", dnsAddr)
|
||||||
|
tunIface, net, err := t.nsTun.Create()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error creating tun device: %s", err)
|
return nil, fmt.Errorf("error creating tun device: %s", err)
|
||||||
}
|
}
|
||||||
t.filteredDevice = newDeviceFilter(tunIface)
|
t.filteredDevice = newDeviceFilter(tunIface)
|
||||||
|
t.net = net
|
||||||
|
|
||||||
t.device = device.NewDevice(
|
t.device = device.NewDevice(
|
||||||
t.filteredDevice,
|
t.filteredDevice,
|
||||||
@ -117,3 +127,12 @@ func (t *TunNetstackDevice) DeviceName() string {
|
|||||||
func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice {
|
func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice {
|
||||||
return t.filteredDevice
|
return t.filteredDevice
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Device returns the wireguard device
|
||||||
|
func (t *TunNetstackDevice) Device() *device.Device {
|
||||||
|
return t.device
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TunNetstackDevice) GetNet() *netstack.Net {
|
||||||
|
return t.net
|
||||||
|
}
|
||||||
|
@ -4,12 +4,11 @@ package device
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
|
||||||
"runtime"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
@ -32,8 +31,6 @@ type USPDevice struct {
|
|||||||
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
|
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
|
||||||
log.Infof("using userspace bind mode")
|
log.Infof("using userspace bind mode")
|
||||||
|
|
||||||
checkUser()
|
|
||||||
|
|
||||||
return &USPDevice{
|
return &USPDevice{
|
||||||
name: name,
|
name: name,
|
||||||
address: address,
|
address: address,
|
||||||
@ -128,6 +125,11 @@ func (t *USPDevice) FilteredDevice() *FilteredDevice {
|
|||||||
return t.filteredDevice
|
return t.filteredDevice
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Device returns the wireguard device
|
||||||
|
func (t *USPDevice) Device() *device.Device {
|
||||||
|
return t.device
|
||||||
|
}
|
||||||
|
|
||||||
// assignAddr Adds IP address to the tunnel interface
|
// assignAddr Adds IP address to the tunnel interface
|
||||||
func (t *USPDevice) assignAddr() error {
|
func (t *USPDevice) assignAddr() error {
|
||||||
link := newWGLink(t.name)
|
link := newWGLink(t.name)
|
||||||
@ -135,11 +137,6 @@ func (t *USPDevice) assignAddr() error {
|
|||||||
return link.assignAddr(t.address)
|
return link.assignAddr(t.address)
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkUser() {
|
func (t *USPDevice) GetNet() *netstack.Net {
|
||||||
if runtime.GOOS == "freebsd" {
|
return nil
|
||||||
euid := os.Geteuid()
|
|
||||||
if euid != 0 {
|
|
||||||
log.Warn("newTunUSPDevice: on netbird must run as root to be able to assign address to the tun interface with ifconfig")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
@ -150,6 +151,11 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
|
|||||||
return t.filteredDevice
|
return t.filteredDevice
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Device returns the wireguard device
|
||||||
|
func (t *TunDevice) Device() *device.Device {
|
||||||
|
return t.device
|
||||||
|
}
|
||||||
|
|
||||||
func (t *TunDevice) GetInterfaceGUIDString() (string, error) {
|
func (t *TunDevice) GetInterfaceGUIDString() (string, error) {
|
||||||
if t.nativeTunDevice == nil {
|
if t.nativeTunDevice == nil {
|
||||||
return "", fmt.Errorf("interface has not been initialized yet")
|
return "", fmt.Errorf("interface has not been initialized yet")
|
||||||
@ -169,3 +175,7 @@ func (t *TunDevice) assignAddr() error {
|
|||||||
log.Debugf("adding address %s to interface: %s", t.address.IP, t.name)
|
log.Debugf("adding address %s to interface: %s", t.address.IP, t.name)
|
||||||
return luid.SetIPAddresses([]netip.Prefix{netip.MustParsePrefix(t.address.String())})
|
return luid.SetIPAddresses([]netip.Prefix{netip.MustParsePrefix(t.address.String())})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *TunDevice) GetNet() *netstack.Net {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -1,6 +1,10 @@
|
|||||||
package iface
|
package iface
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
)
|
)
|
||||||
@ -13,4 +17,6 @@ type WGTunDevice interface {
|
|||||||
DeviceName() string
|
DeviceName() string
|
||||||
Close() error
|
Close() error
|
||||||
FilteredDevice() *device.FilteredDevice
|
FilteredDevice() *device.FilteredDevice
|
||||||
|
Device() *wgdevice.Device
|
||||||
|
GetNet() *netstack.Net
|
||||||
}
|
}
|
||||||
|
@ -203,6 +203,11 @@ func (l *Link) setAddr(ip, netmask string) error {
|
|||||||
return fmt.Errorf("set interface addr: %w", err)
|
return fmt.Errorf("set interface addr: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cmd = exec.Command("ifconfig", l.name, "inet6", "fe80::/64")
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
log.Debugf("adding address command '%v' failed with output: %s", cmd.String(), out)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -9,8 +9,11 @@ import (
|
|||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
"github.com/pion/transport/v3"
|
"github.com/pion/transport/v3"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/errors"
|
"github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
@ -203,6 +206,11 @@ func (w *WGIface) GetDevice() *device.FilteredDevice {
|
|||||||
return w.tun.FilteredDevice()
|
return w.tun.FilteredDevice()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetWGDevice returns the WireGuard device
|
||||||
|
func (w *WGIface) GetWGDevice() *wgdevice.Device {
|
||||||
|
return w.tun.Device()
|
||||||
|
}
|
||||||
|
|
||||||
// GetStats returns the last handshake time, rx and tx bytes for the given peer
|
// GetStats returns the last handshake time, rx and tx bytes for the given peer
|
||||||
func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
||||||
return w.configurer.GetStats(peerKey)
|
return w.configurer.GetStats(peerKey)
|
||||||
@ -234,3 +242,11 @@ func (w *WGIface) waitUntilRemoved() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetNet returns the netstack.Net for the netstack device
|
||||||
|
func (w *WGIface) GetNet() *netstack.Net {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
|
return w.tun.GetNet()
|
||||||
|
}
|
||||||
|
@ -1,112 +0,0 @@
|
|||||||
package iface
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
|
||||||
)
|
|
||||||
|
|
||||||
type MockWGIface struct {
|
|
||||||
CreateFunc func() error
|
|
||||||
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
|
|
||||||
IsUserspaceBindFunc func() bool
|
|
||||||
NameFunc func() string
|
|
||||||
AddressFunc func() device.WGAddress
|
|
||||||
ToInterfaceFunc func() *net.Interface
|
|
||||||
UpFunc func() (*bind.UniversalUDPMuxDefault, error)
|
|
||||||
UpdateAddrFunc func(newAddr string) error
|
|
||||||
UpdatePeerFunc func(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
|
||||||
RemovePeerFunc func(peerKey string) error
|
|
||||||
AddAllowedIPFunc func(peerKey string, allowedIP string) error
|
|
||||||
RemoveAllowedIPFunc func(peerKey string, allowedIP string) error
|
|
||||||
CloseFunc func() error
|
|
||||||
SetFilterFunc func(filter device.PacketFilter) error
|
|
||||||
GetFilterFunc func() device.PacketFilter
|
|
||||||
GetDeviceFunc func() *device.FilteredDevice
|
|
||||||
GetStatsFunc func(peerKey string) (configurer.WGStats, error)
|
|
||||||
GetInterfaceGUIDStringFunc func() (string, error)
|
|
||||||
GetProxyFunc func() wgproxy.Proxy
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockWGIface) GetInterfaceGUIDString() (string, error) {
|
|
||||||
return m.GetInterfaceGUIDStringFunc()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockWGIface) Create() error {
|
|
||||||
return m.CreateFunc()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains []string) error {
|
|
||||||
return m.CreateOnAndroidFunc(routeRange, ip, domains)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockWGIface) IsUserspaceBind() bool {
|
|
||||||
return m.IsUserspaceBindFunc()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockWGIface) Name() string {
|
|
||||||
return m.NameFunc()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockWGIface) Address() device.WGAddress {
|
|
||||||
return m.AddressFunc()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockWGIface) ToInterface() *net.Interface {
|
|
||||||
return m.ToInterfaceFunc()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
|
|
||||||
return m.UpFunc()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockWGIface) UpdateAddr(newAddr string) error {
|
|
||||||
return m.UpdateAddrFunc(newAddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockWGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
|
||||||
return m.UpdatePeerFunc(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockWGIface) RemovePeer(peerKey string) error {
|
|
||||||
return m.RemovePeerFunc(peerKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockWGIface) AddAllowedIP(peerKey string, allowedIP string) error {
|
|
||||||
return m.AddAllowedIPFunc(peerKey, allowedIP)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockWGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
|
|
||||||
return m.RemoveAllowedIPFunc(peerKey, allowedIP)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockWGIface) Close() error {
|
|
||||||
return m.CloseFunc()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockWGIface) SetFilter(filter device.PacketFilter) error {
|
|
||||||
return m.SetFilterFunc(filter)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockWGIface) GetFilter() device.PacketFilter {
|
|
||||||
return m.GetFilterFunc()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockWGIface) GetDevice() *device.FilteredDevice {
|
|
||||||
return m.GetDeviceFunc()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
|
||||||
return m.GetStatsFunc(peerKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockWGIface) GetProxy() wgproxy.Proxy {
|
|
||||||
//TODO implement me
|
|
||||||
panic("implement me")
|
|
||||||
}
|
|
@ -1,35 +0,0 @@
|
|||||||
package iface
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
|
||||||
)
|
|
||||||
|
|
||||||
type IWGIface interface {
|
|
||||||
Create() error
|
|
||||||
CreateOnAndroid(routeRange []string, ip string, domains []string) error
|
|
||||||
IsUserspaceBind() bool
|
|
||||||
Name() string
|
|
||||||
Address() device.WGAddress
|
|
||||||
ToInterface() *net.Interface
|
|
||||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
|
||||||
UpdateAddr(newAddr string) error
|
|
||||||
GetProxy() wgproxy.Proxy
|
|
||||||
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
|
||||||
RemovePeer(peerKey string) error
|
|
||||||
AddAllowedIP(peerKey string, allowedIP string) error
|
|
||||||
RemoveAllowedIP(peerKey string, allowedIP string) error
|
|
||||||
Close() error
|
|
||||||
SetFilter(filter device.PacketFilter) error
|
|
||||||
GetFilter() device.PacketFilter
|
|
||||||
GetDevice() *device.FilteredDevice
|
|
||||||
GetStats(peerKey string) (configurer.WGStats, error)
|
|
||||||
GetInterfaceGUIDString() (string, error)
|
|
||||||
}
|
|
@ -8,9 +8,11 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const EnvUseNetstackMode = "NB_USE_NETSTACK_MODE"
|
||||||
|
|
||||||
// IsEnabled todo: move these function to cmd layer
|
// IsEnabled todo: move these function to cmd layer
|
||||||
func IsEnabled() bool {
|
func IsEnabled() bool {
|
||||||
return os.Getenv("NB_USE_NETSTACK_MODE") == "true"
|
return os.Getenv(EnvUseNetstackMode) == "true"
|
||||||
}
|
}
|
||||||
|
|
||||||
func ListenAddr() string {
|
func ListenAddr() string {
|
||||||
|
@ -1,15 +1,22 @@
|
|||||||
package netstack
|
package netstack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY"
|
||||||
|
|
||||||
type NetStackTun struct { //nolint:revive
|
type NetStackTun struct { //nolint:revive
|
||||||
address string
|
address net.IP
|
||||||
|
dnsAddress net.IP
|
||||||
mtu int
|
mtu int
|
||||||
listenAddress string
|
listenAddress string
|
||||||
|
|
||||||
@ -17,29 +24,48 @@ type NetStackTun struct { //nolint:revive
|
|||||||
tundev tun.Device
|
tundev tun.Device
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNetStackTun(listenAddress string, address string, mtu int) *NetStackTun {
|
func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu int) *NetStackTun {
|
||||||
return &NetStackTun{
|
return &NetStackTun{
|
||||||
address: address,
|
address: address,
|
||||||
|
dnsAddress: dnsAddress,
|
||||||
mtu: mtu,
|
mtu: mtu,
|
||||||
listenAddress: listenAddress,
|
listenAddress: listenAddress,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *NetStackTun) Create() (tun.Device, error) {
|
func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
|
||||||
|
addr, ok := netip.AddrFromSlice(t.address)
|
||||||
|
if !ok {
|
||||||
|
return nil, nil, fmt.Errorf("convert address to netip.Addr: %v", t.address)
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsAddr, ok := netip.AddrFromSlice(t.dnsAddress)
|
||||||
|
if !ok {
|
||||||
|
return nil, nil, fmt.Errorf("convert dns address to netip.Addr: %v", t.dnsAddress)
|
||||||
|
}
|
||||||
|
|
||||||
nsTunDev, tunNet, err := netstack.CreateNetTUN(
|
nsTunDev, tunNet, err := netstack.CreateNetTUN(
|
||||||
[]netip.Addr{netip.MustParseAddr(t.address)},
|
[]netip.Addr{addr.Unmap()},
|
||||||
[]netip.Addr{},
|
[]netip.Addr{dnsAddr.Unmap()},
|
||||||
t.mtu)
|
t.mtu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
t.tundev = nsTunDev
|
t.tundev = nsTunDev
|
||||||
|
|
||||||
|
skipProxy, err := strconv.ParseBool(os.Getenv(EnvSkipProxy))
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to parse NB_ETSTACK_SKIP_PROXY: %s", err)
|
||||||
|
}
|
||||||
|
if skipProxy {
|
||||||
|
return nsTunDev, tunNet, nil
|
||||||
|
}
|
||||||
|
|
||||||
dialer := NewNSDialer(tunNet)
|
dialer := NewNSDialer(tunNet)
|
||||||
t.proxy, err = NewSocks5(dialer)
|
t.proxy, err = NewSocks5(dialer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = t.tundev.Close()
|
_ = t.tundev.Close()
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
@ -49,7 +75,7 @@ func (t *NetStackTun) Create() (tun.Device, error) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return nsTunDev, nil
|
return nsTunDev, tunNet, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *NetStackTun) Close() error {
|
func (t *NetStackTun) Close() error {
|
||||||
|
@ -151,7 +151,7 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
|||||||
d.rollBack(newRulePairs)
|
d.rollBack(newRulePairs)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if len(rules) > 0 {
|
if len(rulePair) > 0 {
|
||||||
d.peerRulesPairs[pairID] = rulePair
|
d.peerRulesPairs[pairID] = rulePair
|
||||||
newRulePairs[pairID] = rulePair
|
newRulePairs[pairID] = rulePair
|
||||||
}
|
}
|
||||||
@ -268,13 +268,16 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
|||||||
}
|
}
|
||||||
|
|
||||||
var port *firewall.Port
|
var port *firewall.Port
|
||||||
if r.Port != "" {
|
if !portInfoEmpty(r.PortInfo) {
|
||||||
|
port = convertPortInfo(r.PortInfo)
|
||||||
|
} else if r.Port != "" {
|
||||||
|
// old version of management, single port
|
||||||
value, err := strconv.Atoi(r.Port)
|
value, err := strconv.Atoi(r.Port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, fmt.Errorf("invalid port, skipping firewall rule")
|
return "", nil, fmt.Errorf("invalid port: %w", err)
|
||||||
}
|
}
|
||||||
port = &firewall.Port{
|
port = &firewall.Port{
|
||||||
Values: []int{value},
|
Values: []uint16{uint16(value)},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -288,6 +291,8 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
|||||||
case mgmProto.RuleDirection_IN:
|
case mgmProto.RuleDirection_IN:
|
||||||
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
|
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
|
||||||
case mgmProto.RuleDirection_OUT:
|
case mgmProto.RuleDirection_OUT:
|
||||||
|
// TODO: Remove this soon. Outbound rules are obsolete.
|
||||||
|
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already
|
||||||
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
|
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
|
||||||
default:
|
default:
|
||||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||||
@ -300,6 +305,22 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
|||||||
return ruleID, rules, nil
|
return ruleID, rules, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
|
||||||
|
if portInfo == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
switch portInfo.GetPortSelection().(type) {
|
||||||
|
case *mgmProto.PortInfo_Port:
|
||||||
|
return portInfo.GetPort() == 0
|
||||||
|
case *mgmProto.PortInfo_Range_:
|
||||||
|
r := portInfo.GetRange()
|
||||||
|
return r == nil || r.Start == 0 || r.End == 0
|
||||||
|
default:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) addInRules(
|
func (d *DefaultManager) addInRules(
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
protocol firewall.Protocol,
|
protocol firewall.Protocol,
|
||||||
@ -308,25 +329,12 @@ func (d *DefaultManager) addInRules(
|
|||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
comment string,
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
var rules []firewall.Rule
|
rule, err := d.firewall.AddPeerFiltering(ip, protocol, nil, port, action, ipsetName, comment)
|
||||||
rule, err := d.firewall.AddPeerFiltering(
|
|
||||||
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||||
}
|
|
||||||
rules = append(rules, rule...)
|
|
||||||
|
|
||||||
if shouldSkipInvertedRule(protocol, port) {
|
|
||||||
return rules, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rule, err = d.firewall.AddPeerFiltering(
|
return rule, nil
|
||||||
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return append(rules, rule...), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) addOutRules(
|
func (d *DefaultManager) addOutRules(
|
||||||
@ -337,25 +345,16 @@ func (d *DefaultManager) addOutRules(
|
|||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
comment string,
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
var rules []firewall.Rule
|
|
||||||
rule, err := d.firewall.AddPeerFiltering(
|
|
||||||
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
|
||||||
}
|
|
||||||
rules = append(rules, rule...)
|
|
||||||
|
|
||||||
if shouldSkipInvertedRule(protocol, port) {
|
if shouldSkipInvertedRule(protocol, port) {
|
||||||
return rules, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rule, err = d.firewall.AddPeerFiltering(
|
rule, err := d.firewall.AddPeerFiltering(ip, protocol, port, nil, action, ipsetName, comment)
|
||||||
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return append(rules, rule...), nil
|
return rule, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getPeerRuleID() returns unique ID for the rule based on its parameters.
|
// getPeerRuleID() returns unique ID for the rule based on its parameters.
|
||||||
@ -508,7 +507,7 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
|
|
||||||
// getRuleGroupingSelector takes all rule properties except IP address to build selector
|
// getRuleGroupingSelector takes all rule properties except IP address to build selector
|
||||||
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
|
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
|
||||||
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
|
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) {
|
func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) {
|
||||||
@ -559,14 +558,14 @@ func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port {
|
|||||||
|
|
||||||
if portInfo.GetPort() != 0 {
|
if portInfo.GetPort() != 0 {
|
||||||
return &firewall.Port{
|
return &firewall.Port{
|
||||||
Values: []int{int(portInfo.GetPort())},
|
Values: []uint16{uint16(int(portInfo.GetPort()))},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if portInfo.GetRange() != nil {
|
if portInfo.GetRange() != nil {
|
||||||
return &firewall.Port{
|
return &firewall.Port{
|
||||||
IsRange: true,
|
IsRange: true,
|
||||||
Values: []int{int(portInfo.GetRange().Start), int(portInfo.GetRange().End)},
|
Values: []uint16{uint16(portInfo.GetRange().Start), uint16(portInfo.GetRange().End)},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -49,9 +49,10 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
IP: ip,
|
IP: ip,
|
||||||
Network: network,
|
Network: network,
|
||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
// we receive one rule from the management so for testing purposes ignore it
|
||||||
fw, err := firewall.NewFirewall(ifaceMock, nil)
|
fw, err := firewall.NewFirewall(ifaceMock, nil, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create firewall: %v", err)
|
t.Errorf("create firewall: %v", err)
|
||||||
return
|
return
|
||||||
@ -119,8 +120,8 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
|
|
||||||
networkMap.FirewallRulesIsEmpty = false
|
networkMap.FirewallRulesIsEmpty = false
|
||||||
acl.ApplyFiltering(networkMap)
|
acl.ApplyFiltering(networkMap)
|
||||||
if len(acl.peerRulesPairs) != 2 {
|
if len(acl.peerRulesPairs) != 1 {
|
||||||
t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
|
t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@ -342,9 +343,10 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
|||||||
IP: ip,
|
IP: ip,
|
||||||
Network: network,
|
Network: network,
|
||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
// we receive one rule from the management so for testing purposes ignore it
|
||||||
fw, err := firewall.NewFirewall(ifaceMock, nil)
|
fw, err := firewall.NewFirewall(ifaceMock, nil, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create firewall: %v", err)
|
t.Errorf("create firewall: %v", err)
|
||||||
return
|
return
|
||||||
@ -356,8 +358,8 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
|||||||
|
|
||||||
acl.ApplyFiltering(networkMap)
|
acl.ApplyFiltering(networkMap)
|
||||||
|
|
||||||
if len(acl.peerRulesPairs) != 4 {
|
if len(acl.peerRulesPairs) != 3 {
|
||||||
t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
|
t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -8,6 +8,8 @@ import (
|
|||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
iface "github.com/netbirdio/netbird/client/iface"
|
iface "github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
)
|
)
|
||||||
@ -90,3 +92,31 @@ func (mr *MockIFaceMapperMockRecorder) SetFilter(arg0 interface{}) *gomock.Call
|
|||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetFilter", reflect.TypeOf((*MockIFaceMapper)(nil).SetFilter), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetFilter", reflect.TypeOf((*MockIFaceMapper)(nil).SetFilter), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetDevice mocks base method.
|
||||||
|
func (m *MockIFaceMapper) GetDevice() *device.FilteredDevice {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetDevice")
|
||||||
|
ret0, _ := ret[0].(*device.FilteredDevice)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDevice indicates an expected call of GetDevice.
|
||||||
|
func (mr *MockIFaceMapperMockRecorder) GetDevice() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDevice", reflect.TypeOf((*MockIFaceMapper)(nil).GetDevice))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetWGDevice mocks base method.
|
||||||
|
func (m *MockIFaceMapper) GetWGDevice() *wgdevice.Device {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetWGDevice")
|
||||||
|
ret0, _ := ret[0].(*wgdevice.Device)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetWGDevice indicates an expected call of GetWGDevice.
|
||||||
|
func (mr *MockIFaceMapperMockRecorder) GetWGDevice() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWGDevice", reflect.TypeOf((*MockIFaceMapper)(nil).GetWGDevice))
|
||||||
|
}
|
||||||
|
@ -2,6 +2,8 @@ package auth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -11,7 +13,10 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||||
)
|
)
|
||||||
|
|
||||||
// HostedGrantType grant type for device flow on Hosted
|
// HostedGrantType grant type for device flow on Hosted
|
||||||
@ -56,6 +61,18 @@ func NewDeviceAuthorizationFlow(config internal.DeviceAuthProviderConfig) (*Devi
|
|||||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
httpTransport.MaxIdleConns = 5
|
httpTransport.MaxIdleConns = 5
|
||||||
|
|
||||||
|
certPool, err := x509.SystemCertPool()
|
||||||
|
if err != nil || certPool == nil {
|
||||||
|
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
|
||||||
|
certPool = embeddedroots.Get()
|
||||||
|
} else {
|
||||||
|
log.Debug("Using system certificate pool.")
|
||||||
|
}
|
||||||
|
|
||||||
|
httpTransport.TLSClientConfig = &tls.Config{
|
||||||
|
RootCAs: certPool,
|
||||||
|
}
|
||||||
|
|
||||||
httpClient := &http.Client{
|
httpClient := &http.Client{
|
||||||
Timeout: 10 * time.Second,
|
Timeout: 10 * time.Second,
|
||||||
Transport: httpTransport,
|
Transport: httpTransport,
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -20,6 +21,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
mgm "github.com/netbirdio/netbird/management/client"
|
mgm "github.com/netbirdio/netbird/management/client"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -66,6 +68,12 @@ type ConfigInput struct {
|
|||||||
DisableServerRoutes *bool
|
DisableServerRoutes *bool
|
||||||
DisableDNS *bool
|
DisableDNS *bool
|
||||||
DisableFirewall *bool
|
DisableFirewall *bool
|
||||||
|
|
||||||
|
BlockLANAccess *bool
|
||||||
|
|
||||||
|
DisableNotifications *bool
|
||||||
|
|
||||||
|
DNSLabels domain.List
|
||||||
}
|
}
|
||||||
|
|
||||||
// Config Configuration type
|
// Config Configuration type
|
||||||
@ -89,6 +97,12 @@ type Config struct {
|
|||||||
DisableDNS bool
|
DisableDNS bool
|
||||||
DisableFirewall bool
|
DisableFirewall bool
|
||||||
|
|
||||||
|
BlockLANAccess bool
|
||||||
|
|
||||||
|
DisableNotifications *bool
|
||||||
|
|
||||||
|
DNSLabels domain.List
|
||||||
|
|
||||||
// SSHKey is a private SSH key in a PEM format
|
// SSHKey is a private SSH key in a PEM format
|
||||||
SSHKey string
|
SSHKey string
|
||||||
|
|
||||||
@ -455,6 +469,33 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
updated = true
|
updated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if input.BlockLANAccess != nil && *input.BlockLANAccess != config.BlockLANAccess {
|
||||||
|
if *input.BlockLANAccess {
|
||||||
|
log.Infof("blocking LAN access")
|
||||||
|
} else {
|
||||||
|
log.Infof("allowing LAN access")
|
||||||
|
}
|
||||||
|
config.BlockLANAccess = *input.BlockLANAccess
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.DisableNotifications != nil && input.DisableNotifications != config.DisableNotifications {
|
||||||
|
if *input.DisableNotifications {
|
||||||
|
log.Infof("disabling notifications")
|
||||||
|
} else {
|
||||||
|
log.Infof("enabling notifications")
|
||||||
|
}
|
||||||
|
config.DisableNotifications = input.DisableNotifications
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.DisableNotifications == nil {
|
||||||
|
disabled := true
|
||||||
|
config.DisableNotifications = &disabled
|
||||||
|
log.Infof("setting notifications to disabled by default")
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
if input.ClientCertKeyPath != "" {
|
if input.ClientCertKeyPath != "" {
|
||||||
config.ClientCertKeyPath = input.ClientCertKeyPath
|
config.ClientCertKeyPath = input.ClientCertKeyPath
|
||||||
updated = true
|
updated = true
|
||||||
@ -475,6 +516,14 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if input.DNSLabels != nil && !slices.Equal(config.DNSLabels, input.DNSLabels) {
|
||||||
|
log.Infof("updating DNS labels [ %s ] (old value: [ %s ])",
|
||||||
|
input.DNSLabels.SafeString(),
|
||||||
|
config.DNSLabels.SafeString())
|
||||||
|
config.DNSLabels = input.DNSLabels
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
return updated, nil
|
return updated, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -23,6 +23,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
cProto "github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
mgm "github.com/netbirdio/netbird/management/client"
|
mgm "github.com/netbirdio/netbird/management/client"
|
||||||
@ -31,6 +32,7 @@ import (
|
|||||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||||
signal "github.com/netbirdio/netbird/signal/client"
|
signal "github.com/netbirdio/netbird/signal/client"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -59,13 +61,8 @@ func NewConnectClient(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Run with main logic.
|
// Run with main logic.
|
||||||
func (c *ConnectClient) Run() error {
|
func (c *ConnectClient) Run(runningChan chan error) error {
|
||||||
return c.run(MobileDependency{}, nil, nil)
|
return c.run(MobileDependency{}, runningChan)
|
||||||
}
|
|
||||||
|
|
||||||
// RunWithProbes runs the client's main logic with probes attached
|
|
||||||
func (c *ConnectClient) RunWithProbes(probes *ProbeHolder, runningChan chan error) error {
|
|
||||||
return c.run(MobileDependency{}, probes, runningChan)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunOnAndroid with main logic on mobile system
|
// RunOnAndroid with main logic on mobile system
|
||||||
@ -84,7 +81,7 @@ func (c *ConnectClient) RunOnAndroid(
|
|||||||
HostDNSAddresses: dnsAddresses,
|
HostDNSAddresses: dnsAddresses,
|
||||||
DnsReadyListener: dnsReadyListener,
|
DnsReadyListener: dnsReadyListener,
|
||||||
}
|
}
|
||||||
return c.run(mobileDependency, nil, nil)
|
return c.run(mobileDependency, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ConnectClient) RunOniOS(
|
func (c *ConnectClient) RunOniOS(
|
||||||
@ -102,18 +99,30 @@ func (c *ConnectClient) RunOniOS(
|
|||||||
DnsManager: dnsManager,
|
DnsManager: dnsManager,
|
||||||
StateFilePath: stateFilePath,
|
StateFilePath: stateFilePath,
|
||||||
}
|
}
|
||||||
return c.run(mobileDependency, nil, nil)
|
return c.run(mobileDependency, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHolder, runningChan chan error) error {
|
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan error) error {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
|
rec := c.statusRecorder
|
||||||
|
if rec != nil {
|
||||||
|
rec.PublishEvent(
|
||||||
|
cProto.SystemEvent_CRITICAL, cProto.SystemEvent_SYSTEM,
|
||||||
|
"panic occurred",
|
||||||
|
"The Netbird service panicked. Please restart the service and submit a bug report with the client logs.",
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
|
log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
|
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
|
||||||
|
|
||||||
|
nbnet.Init()
|
||||||
|
|
||||||
backOff := &backoff.ExponentialBackOff{
|
backOff := &backoff.ExponentialBackOff{
|
||||||
InitialInterval: time.Second,
|
InitialInterval: time.Second,
|
||||||
RandomizationFactor: 1,
|
RandomizationFactor: 1,
|
||||||
@ -182,8 +191,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// connect (just a connection, no stream yet) and login to Management Service to get an initial global Wiretrustee config
|
// connect (just a connection, no stream yet) and login to Management Service to get an initial global Netbird config
|
||||||
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey)
|
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, c.config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debug(err)
|
log.Debug(err)
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||||
@ -204,8 +213,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
|||||||
c.statusRecorder.UpdateLocalPeerState(localPeerState)
|
c.statusRecorder.UpdateLocalPeerState(localPeerState)
|
||||||
|
|
||||||
signalURL := fmt.Sprintf("%s://%s",
|
signalURL := fmt.Sprintf("%s://%s",
|
||||||
strings.ToLower(loginResp.GetWiretrusteeConfig().GetSignal().GetProtocol().String()),
|
strings.ToLower(loginResp.GetNetbirdConfig().GetSignal().GetProtocol().String()),
|
||||||
loginResp.GetWiretrusteeConfig().GetSignal().GetUri(),
|
loginResp.GetNetbirdConfig().GetSignal().GetUri(),
|
||||||
)
|
)
|
||||||
|
|
||||||
c.statusRecorder.UpdateSignalAddress(signalURL)
|
c.statusRecorder.UpdateSignalAddress(signalURL)
|
||||||
@ -216,8 +225,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
|||||||
c.statusRecorder.MarkSignalDisconnected(err)
|
c.statusRecorder.MarkSignalDisconnected(err)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal
|
// with the global Netbird config in hand connect (just a connection, no stream yet) Signal
|
||||||
signalClient, err := connectToSignal(engineCtx, loginResp.GetWiretrusteeConfig(), myPrivateKey)
|
signalClient, err := connectToSignal(engineCtx, loginResp.GetNetbirdConfig(), myPrivateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
@ -261,7 +270,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
|||||||
checks := loginResp.GetChecks()
|
checks := loginResp.GetChecks()
|
||||||
|
|
||||||
c.engineMutex.Lock()
|
c.engineMutex.Lock()
|
||||||
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
|
c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
|
||||||
c.engine.SetNetworkMapPersistence(c.persistNetworkMap)
|
c.engine.SetNetworkMapPersistence(c.persistNetworkMap)
|
||||||
c.engineMutex.Unlock()
|
c.engineMutex.Unlock()
|
||||||
|
|
||||||
@ -316,7 +325,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
|||||||
}
|
}
|
||||||
|
|
||||||
func parseRelayInfo(loginResp *mgmProto.LoginResponse) ([]string, *hmac.Token) {
|
func parseRelayInfo(loginResp *mgmProto.LoginResponse) ([]string, *hmac.Token) {
|
||||||
relayCfg := loginResp.GetWiretrusteeConfig().GetRelay()
|
relayCfg := loginResp.GetNetbirdConfig().GetRelay()
|
||||||
if relayCfg == nil {
|
if relayCfg == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@ -420,6 +429,8 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
|
|||||||
DisableServerRoutes: config.DisableServerRoutes,
|
DisableServerRoutes: config.DisableServerRoutes,
|
||||||
DisableDNS: config.DisableDNS,
|
DisableDNS: config.DisableDNS,
|
||||||
DisableFirewall: config.DisableFirewall,
|
DisableFirewall: config.DisableFirewall,
|
||||||
|
|
||||||
|
BlockLANAccess: config.BlockLANAccess,
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.PreSharedKey != "" {
|
if config.PreSharedKey != "" {
|
||||||
@ -443,7 +454,7 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
|
|||||||
}
|
}
|
||||||
|
|
||||||
// connectToSignal creates Signal Service client and established a connection
|
// connectToSignal creates Signal Service client and established a connection
|
||||||
func connectToSignal(ctx context.Context, wtConfig *mgmProto.WiretrusteeConfig, ourPrivateKey wgtypes.Key) (*signal.GrpcClient, error) {
|
func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourPrivateKey wgtypes.Key) (*signal.GrpcClient, error) {
|
||||||
var sigTLSEnabled bool
|
var sigTLSEnabled bool
|
||||||
if wtConfig.Signal.Protocol == mgmProto.HostConfig_HTTPS {
|
if wtConfig.Signal.Protocol == mgmProto.HostConfig_HTTPS {
|
||||||
sigTLSEnabled = true
|
sigTLSEnabled = true
|
||||||
@ -460,8 +471,8 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.WiretrusteeConfig,
|
|||||||
return signalClient, nil
|
return signalClient, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// loginToManagement creates Management Services client, establishes a connection, logs-in and gets a global Wiretrustee config (signal, turn, stun hosts, etc)
|
// loginToManagement creates Management Services client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
|
||||||
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *Config) (*mgmProto.LoginResponse, error) {
|
||||||
|
|
||||||
serverPublicKey, err := client.GetServerPublicKey()
|
serverPublicKey, err := client.GetServerPublicKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -469,7 +480,16 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte)
|
|||||||
}
|
}
|
||||||
|
|
||||||
sysInfo := system.GetInfo(ctx)
|
sysInfo := system.GetInfo(ctx)
|
||||||
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey)
|
sysInfo.SetFlags(
|
||||||
|
config.RosenpassEnabled,
|
||||||
|
config.RosenpassPermissive,
|
||||||
|
config.ServerSSHAllowed,
|
||||||
|
config.DisableClientRoutes,
|
||||||
|
config.DisableServerRoutes,
|
||||||
|
config.DisableDNS,
|
||||||
|
config.DisableFirewall,
|
||||||
|
)
|
||||||
|
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
111
client/internal/dns.go
Normal file
111
client/internal/dns.go
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.SimpleRecord, bool) {
|
||||||
|
ip := net.ParseIP(aRecord.RData)
|
||||||
|
if ip == nil || ip.To4() == nil {
|
||||||
|
return nbdns.SimpleRecord{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ipNet.Contains(ip) {
|
||||||
|
return nbdns.SimpleRecord{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
ipOctets := strings.Split(ip.String(), ".")
|
||||||
|
slices.Reverse(ipOctets)
|
||||||
|
rdnsName := dns.Fqdn(strings.Join(ipOctets, ".") + ".in-addr.arpa")
|
||||||
|
|
||||||
|
return nbdns.SimpleRecord{
|
||||||
|
Name: rdnsName,
|
||||||
|
Type: int(dns.TypePTR),
|
||||||
|
Class: aRecord.Class,
|
||||||
|
TTL: aRecord.TTL,
|
||||||
|
RData: dns.Fqdn(aRecord.Name),
|
||||||
|
}, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateReverseZoneName creates the reverse DNS zone name for a given network
|
||||||
|
func generateReverseZoneName(ipNet *net.IPNet) (string, error) {
|
||||||
|
networkIP := ipNet.IP.Mask(ipNet.Mask)
|
||||||
|
maskOnes, _ := ipNet.Mask.Size()
|
||||||
|
|
||||||
|
// round up to nearest byte
|
||||||
|
octetsToUse := (maskOnes + 7) / 8
|
||||||
|
|
||||||
|
octets := strings.Split(networkIP.String(), ".")
|
||||||
|
if octetsToUse > len(octets) {
|
||||||
|
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", maskOnes)
|
||||||
|
}
|
||||||
|
|
||||||
|
reverseOctets := make([]string, octetsToUse)
|
||||||
|
for i := 0; i < octetsToUse; i++ {
|
||||||
|
reverseOctets[octetsToUse-1-i] = octets[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
return dns.Fqdn(strings.Join(reverseOctets, ".") + ".in-addr.arpa"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// zoneExists checks if a zone with the given name already exists in the configuration
|
||||||
|
func zoneExists(config *nbdns.Config, zoneName string) bool {
|
||||||
|
for _, zone := range config.CustomZones {
|
||||||
|
if zone.Domain == zoneName {
|
||||||
|
log.Debugf("reverse DNS zone %s already exists", zoneName)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// collectPTRRecords gathers all PTR records for the given network from A records
|
||||||
|
func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRecord {
|
||||||
|
var records []nbdns.SimpleRecord
|
||||||
|
|
||||||
|
for _, zone := range config.CustomZones {
|
||||||
|
for _, record := range zone.Records {
|
||||||
|
if record.Type != int(dns.TypeA) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if ptrRecord, ok := createPTRRecord(record, ipNet); ok {
|
||||||
|
records = append(records, ptrRecord)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return records
|
||||||
|
}
|
||||||
|
|
||||||
|
// addReverseZone adds a reverse DNS zone to the configuration for the given network
|
||||||
|
func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) {
|
||||||
|
zoneName, err := generateReverseZoneName(ipNet)
|
||||||
|
if err != nil {
|
||||||
|
log.Warn(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if zoneExists(config, zoneName) {
|
||||||
|
log.Debugf("reverse DNS zone %s already exists", zoneName)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
records := collectPTRRecords(config, ipNet)
|
||||||
|
|
||||||
|
reverseZone := nbdns.CustomZone{
|
||||||
|
Domain: zoneName,
|
||||||
|
Records: records,
|
||||||
|
}
|
||||||
|
|
||||||
|
config.CustomZones = append(config.CustomZones, reverseZone)
|
||||||
|
log.Debugf("added reverse DNS zone: %s with %d records", zoneName, len(records))
|
||||||
|
}
|
@ -58,7 +58,7 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *st
|
|||||||
return fmt.Errorf("restoring the original resolv.conf file return err: %w", err)
|
return fmt.Errorf("restoring the original resolv.conf file return err: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
|
return ErrRouteAllWithoutNameserverGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
if !backupFileExist {
|
if !backupFileExist {
|
||||||
@ -121,6 +121,10 @@ func (f *fileConfigurator) restoreHostDNS() error {
|
|||||||
return f.restore()
|
return f.restore()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *fileConfigurator) string() string {
|
||||||
|
return "file"
|
||||||
|
}
|
||||||
|
|
||||||
func (f *fileConfigurator) backup() error {
|
func (f *fileConfigurator) backup() error {
|
||||||
stats, err := os.Stat(defaultResolvConfPath)
|
stats, err := os.Stat(defaultResolvConfPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -12,7 +12,7 @@ import (
|
|||||||
const (
|
const (
|
||||||
PriorityDNSRoute = 100
|
PriorityDNSRoute = 100
|
||||||
PriorityMatchDomain = 50
|
PriorityMatchDomain = 50
|
||||||
PriorityDefault = 0
|
PriorityDefault = 1
|
||||||
)
|
)
|
||||||
|
|
||||||
type SubdomainMatcher interface {
|
type SubdomainMatcher interface {
|
||||||
@ -26,7 +26,6 @@ type HandlerEntry struct {
|
|||||||
Pattern string
|
Pattern string
|
||||||
OrigPattern string
|
OrigPattern string
|
||||||
IsWildcard bool
|
IsWildcard bool
|
||||||
StopHandler handlerWithStop
|
|
||||||
MatchSubdomains bool
|
MatchSubdomains bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -64,7 +63,7 @@ func (w *ResponseWriterChain) GetOrigPattern() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
|
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
|
||||||
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) {
|
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
@ -78,9 +77,6 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
|||||||
// First remove any existing handler with same pattern (case-insensitive) and priority
|
// First remove any existing handler with same pattern (case-insensitive) and priority
|
||||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||||
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
|
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
|
||||||
if c.handlers[i].StopHandler != nil {
|
|
||||||
c.handlers[i].StopHandler.stop()
|
|
||||||
}
|
|
||||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -101,21 +97,33 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
|||||||
Pattern: pattern,
|
Pattern: pattern,
|
||||||
OrigPattern: origPattern,
|
OrigPattern: origPattern,
|
||||||
IsWildcard: isWildcard,
|
IsWildcard: isWildcard,
|
||||||
StopHandler: stopHandler,
|
|
||||||
MatchSubdomains: matchSubdomains,
|
MatchSubdomains: matchSubdomains,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert handler in priority order
|
pos := c.findHandlerPosition(entry)
|
||||||
pos := 0
|
c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
|
||||||
for i, h := range c.handlers {
|
|
||||||
if h.Priority < priority {
|
|
||||||
pos = i
|
|
||||||
break
|
|
||||||
}
|
|
||||||
pos = i + 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
|
// findHandlerPosition determines where to insert a new handler based on priority and specificity
|
||||||
|
func (c *HandlerChain) findHandlerPosition(newEntry HandlerEntry) int {
|
||||||
|
for i, h := range c.handlers {
|
||||||
|
// prio first
|
||||||
|
if h.Priority < newEntry.Priority {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
|
||||||
|
// domain specificity next
|
||||||
|
if h.Priority == newEntry.Priority {
|
||||||
|
newDots := strings.Count(newEntry.Pattern, ".")
|
||||||
|
existingDots := strings.Count(h.Pattern, ".")
|
||||||
|
if newDots > existingDots {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// add at end
|
||||||
|
return len(c.handlers)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveHandler removes a handler for the given pattern and priority
|
// RemoveHandler removes a handler for the given pattern and priority
|
||||||
@ -129,9 +137,6 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
|
|||||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||||
entry := c.handlers[i]
|
entry := c.handlers[i]
|
||||||
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
||||||
if entry.StopHandler != nil {
|
|
||||||
entry.StopHandler.stop()
|
|
||||||
}
|
|
||||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -167,8 +172,8 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
if log.IsLevelEnabled(log.TraceLevel) {
|
if log.IsLevelEnabled(log.TraceLevel) {
|
||||||
log.Tracef("current handlers (%d):", len(handlers))
|
log.Tracef("current handlers (%d):", len(handlers))
|
||||||
for _, h := range handlers {
|
for _, h := range handlers {
|
||||||
log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v priority=%d",
|
log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d",
|
||||||
h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority)
|
h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -193,13 +198,13 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !matched {
|
if !matched {
|
||||||
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v matched=false",
|
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d matched=false",
|
||||||
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard)
|
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard, entry.Priority)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v",
|
log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d",
|
||||||
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains)
|
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
|
||||||
|
|
||||||
chainWriter := &ResponseWriterChain{
|
chainWriter := &ResponseWriterChain{
|
||||||
ResponseWriter: w,
|
ResponseWriter: w,
|
||||||
|
@ -21,9 +21,9 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
|
|||||||
dnsRouteHandler := &nbdns.MockHandler{}
|
dnsRouteHandler := &nbdns.MockHandler{}
|
||||||
|
|
||||||
// Setup handlers with different priorities
|
// Setup handlers with different priorities
|
||||||
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil)
|
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault)
|
||||||
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain, nil)
|
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain)
|
||||||
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute, nil)
|
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute)
|
||||||
|
|
||||||
// Create test request
|
// Create test request
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
@ -138,7 +138,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
|||||||
pattern = "*." + tt.handlerDomain[2:]
|
pattern = "*." + tt.handlerDomain[2:]
|
||||||
}
|
}
|
||||||
|
|
||||||
chain.AddHandler(pattern, handler, nbdns.PriorityDefault, nil)
|
chain.AddHandler(pattern, handler, nbdns.PriorityDefault)
|
||||||
|
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
||||||
@ -253,7 +253,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
|||||||
handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe()
|
handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe()
|
||||||
}
|
}
|
||||||
|
|
||||||
chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority, nil)
|
chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create and execute request
|
// Create and execute request
|
||||||
@ -280,9 +280,9 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
|
|||||||
handler3 := &nbdns.MockHandler{}
|
handler3 := &nbdns.MockHandler{}
|
||||||
|
|
||||||
// Add handlers in priority order
|
// Add handlers in priority order
|
||||||
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil)
|
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute)
|
||||||
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain, nil)
|
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain)
|
||||||
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault, nil)
|
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault)
|
||||||
|
|
||||||
// Create test request
|
// Create test request
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
@ -416,7 +416,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
|||||||
if op.action == "add" {
|
if op.action == "add" {
|
||||||
handler := &nbdns.MockHandler{}
|
handler := &nbdns.MockHandler{}
|
||||||
handlers[op.priority] = handler
|
handlers[op.priority] = handler
|
||||||
chain.AddHandler(op.pattern, handler, op.priority, nil)
|
chain.AddHandler(op.pattern, handler, op.priority)
|
||||||
} else {
|
} else {
|
||||||
chain.RemoveHandler(op.pattern, op.priority)
|
chain.RemoveHandler(op.pattern, op.priority)
|
||||||
}
|
}
|
||||||
@ -471,9 +471,9 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
|||||||
r.SetQuestion(testQuery, dns.TypeA)
|
r.SetQuestion(testQuery, dns.TypeA)
|
||||||
|
|
||||||
// Add handlers in mixed order
|
// Add handlers in mixed order
|
||||||
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault, nil)
|
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
|
||||||
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute, nil)
|
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
|
||||||
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain, nil)
|
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
|
||||||
|
|
||||||
// Test 1: Initial state with all three handlers
|
// Test 1: Initial state with all three handlers
|
||||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
@ -653,7 +653,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
|||||||
handler = mockHandler
|
handler = mockHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
chain.AddHandler(pattern, handler, h.priority, nil)
|
chain.AddHandler(pattern, handler, h.priority)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute request
|
// Execute request
|
||||||
@ -677,3 +677,156 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
scenario string
|
||||||
|
ops []struct {
|
||||||
|
action string
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
subdomain bool
|
||||||
|
}
|
||||||
|
query string
|
||||||
|
expectedMatch string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "more specific domain matches first",
|
||||||
|
scenario: "sub.example.com should match before example.com",
|
||||||
|
ops: []struct {
|
||||||
|
action string
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
subdomain bool
|
||||||
|
}{
|
||||||
|
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
|
||||||
|
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
|
||||||
|
},
|
||||||
|
query: "sub.example.com.",
|
||||||
|
expectedMatch: "sub.example.com.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "more specific domain matches first, both match subdomains",
|
||||||
|
scenario: "sub.example.com should match before example.com",
|
||||||
|
ops: []struct {
|
||||||
|
action string
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
subdomain bool
|
||||||
|
}{
|
||||||
|
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
|
||||||
|
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, true},
|
||||||
|
},
|
||||||
|
query: "sub.example.com.",
|
||||||
|
expectedMatch: "sub.example.com.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "maintain specificity order after removal",
|
||||||
|
scenario: "after removing most specific, should fall back to less specific",
|
||||||
|
ops: []struct {
|
||||||
|
action string
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
subdomain bool
|
||||||
|
}{
|
||||||
|
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
|
||||||
|
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, true},
|
||||||
|
{"add", "test.sub.example.com.", nbdns.PriorityMatchDomain, false},
|
||||||
|
{"remove", "test.sub.example.com.", nbdns.PriorityMatchDomain, false},
|
||||||
|
},
|
||||||
|
query: "test.sub.example.com.",
|
||||||
|
expectedMatch: "sub.example.com.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "priority overrides specificity",
|
||||||
|
scenario: "less specific domain with higher priority should match first",
|
||||||
|
ops: []struct {
|
||||||
|
action string
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
subdomain bool
|
||||||
|
}{
|
||||||
|
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
|
||||||
|
{"add", "example.com.", nbdns.PriorityDNSRoute, true},
|
||||||
|
},
|
||||||
|
query: "sub.example.com.",
|
||||||
|
expectedMatch: "example.com.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "equal priority respects specificity",
|
||||||
|
scenario: "with equal priority, more specific domain should match",
|
||||||
|
ops: []struct {
|
||||||
|
action string
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
subdomain bool
|
||||||
|
}{
|
||||||
|
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
|
||||||
|
{"add", "other.example.com.", nbdns.PriorityMatchDomain, true},
|
||||||
|
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
|
||||||
|
},
|
||||||
|
query: "sub.example.com.",
|
||||||
|
expectedMatch: "sub.example.com.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "specific matches before wildcard",
|
||||||
|
scenario: "specific domain should match before wildcard at same priority",
|
||||||
|
ops: []struct {
|
||||||
|
action string
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
subdomain bool
|
||||||
|
}{
|
||||||
|
{"add", "*.example.com.", nbdns.PriorityDNSRoute, false},
|
||||||
|
{"add", "sub.example.com.", nbdns.PriorityDNSRoute, false},
|
||||||
|
},
|
||||||
|
query: "sub.example.com.",
|
||||||
|
expectedMatch: "sub.example.com.",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
handlers := make(map[string]*nbdns.MockSubdomainHandler)
|
||||||
|
|
||||||
|
for _, op := range tt.ops {
|
||||||
|
if op.action == "add" {
|
||||||
|
handler := &nbdns.MockSubdomainHandler{Subdomains: op.subdomain}
|
||||||
|
handlers[op.pattern] = handler
|
||||||
|
chain.AddHandler(op.pattern, handler, op.priority)
|
||||||
|
} else {
|
||||||
|
chain.RemoveHandler(op.pattern, op.priority)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
r := new(dns.Msg)
|
||||||
|
r.SetQuestion(tt.query, dns.TypeA)
|
||||||
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
|
|
||||||
|
// Setup handler expectations
|
||||||
|
for pattern, handler := range handlers {
|
||||||
|
if pattern == tt.expectedMatch {
|
||||||
|
handler.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
|
||||||
|
w := args.Get(0).(dns.ResponseWriter)
|
||||||
|
r := args.Get(1).(*dns.Msg)
|
||||||
|
resp := new(dns.Msg)
|
||||||
|
resp.SetReply(r)
|
||||||
|
assert.NoError(t, w.WriteMsg(resp))
|
||||||
|
}).Once()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
chain.ServeDNS(w, r)
|
||||||
|
|
||||||
|
for pattern, handler := range handlers {
|
||||||
|
if pattern == tt.expectedMatch {
|
||||||
|
handler.AssertNumberOfCalls(t, "ServeDNS", 1)
|
||||||
|
} else {
|
||||||
|
handler.AssertNumberOfCalls(t, "ServeDNS", 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user