Merge branch 'main' into routes-get-account-refactoring

# Conflicts:
#	.github/workflows/golang-test-linux.yml
#	go.mod
#	go.sum
#	management/server/account.go
#	management/server/account_test.go
#	management/server/http/handler.go
#	management/server/http/handlers/peers/peers_handler.go
#	management/server/http/middleware/auth_middleware.go
#	management/server/http/middleware/auth_middleware_test.go
#	management/server/http/testing/benchmarks/peers_handler_benchmark_test.go
#	management/server/http/testing/benchmarks/users_handler_benchmark_test.go
#	management/server/integrated_validator.go
#	management/server/mock_server/account_mock.go
#	management/server/peer.go
#	management/server/status/error.go
#	management/server/store/sql_store.go
#	management/server/store/sql_store_test.go
#	management/server/user.go
This commit is contained in:
bcmmbaga 2025-02-24 13:39:38 +00:00
commit e326cc7412
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
411 changed files with 23713 additions and 8519 deletions

View File

@ -1,4 +1,4 @@
name: Test Code Darwin name: "Darwin"
on: on:
push: push:
@ -12,9 +12,7 @@ concurrency:
jobs: jobs:
test: test:
strategy: name: "Client / Unit"
matrix:
store: ['sqlite']
runs-on: macos-latest runs-on: macos-latest
steps: steps:
- name: Install Go - name: Install Go
@ -44,4 +42,5 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Test - name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management) run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)

View File

@ -1,5 +1,4 @@
name: "FreeBSD"
name: Test Code FreeBSD
on: on:
push: push:
@ -13,6 +12,7 @@ concurrency:
jobs: jobs:
test: test:
name: "Client / Unit"
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
@ -24,7 +24,7 @@ jobs:
copyback: false copyback: false
release: "14.1" release: "14.1"
prepare: | prepare: |
pkg install -y go pkg install -y go pkgconf xorg
# -x - to print all executed commands # -x - to print all executed commands
# -e - to faile on first error # -e - to faile on first error
@ -33,7 +33,7 @@ jobs:
time go build -o netbird client/main.go time go build -o netbird client/main.go
# check all component except management, since we do not support management server on freebsd # check all component except management, since we do not support management server on freebsd
time go test -timeout 1m -failfast ./base62/... time go test -timeout 1m -failfast ./base62/...
# NOTE: without -p1 `client/internal/dns` will fail becasue of `listen udp4 :33100: bind: address already in use` # NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
time go test -timeout 8m -failfast -p 1 ./client/... time go test -timeout 8m -failfast -p 1 ./client/...
time go test -timeout 1m -failfast ./dns/... time go test -timeout 1m -failfast ./dns/...
time go test -timeout 1m -failfast ./encryption/... time go test -timeout 1m -failfast ./encryption/...

View File

@ -1,4 +1,4 @@
name: Test Code Linux name: Linux
on: on:
push: push:
@ -12,11 +12,21 @@ concurrency:
jobs: jobs:
build-cache: build-cache:
name: "Build Cache"
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
outputs:
management: ${{ steps.filter.outputs.management }}
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
- uses: dorny/paths-filter@v3
id: filter
with:
filters: |
management:
- 'management/**'
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
@ -38,7 +48,6 @@ jobs:
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: | restore-keys: |
${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
- name: Install dependencies - name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true' if: steps.cache.outputs.cache-hit != 'true'
@ -89,6 +98,7 @@ jobs:
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 . run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
test: test:
name: "Client / Unit"
needs: [build-cache] needs: [build-cache]
strategy: strategy:
fail-fast: false fail-fast: false
@ -134,9 +144,116 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Test - name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management) run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
test_relay:
name: "Relay / Unit"
needs: [build-cache]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test \
-exec 'sudo' \
-timeout 10m ./signal/...
test_signal:
name: "Signal / Unit"
needs: [build-cache]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test \
-exec 'sudo' \
-timeout 10m ./signal/...
test_management: test_management:
name: "Management / Unit"
needs: [ build-cache ] needs: [ build-cache ]
strategy: strategy:
fail-fast: false fail-fast: false
@ -194,10 +311,17 @@ jobs:
run: docker pull mlsmaycon/warmed-mysql:8 run: docker pull mlsmaycon/warmed-mysql:8
- name: Test - name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management) run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
go test -tags=devcert \
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
-timeout 10m ./management/...
benchmark: benchmark:
name: "Management / Benchmark"
needs: [ build-cache ] needs: [ build-cache ]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
@ -254,10 +378,17 @@ jobs:
run: docker pull mlsmaycon/warmed-mysql:8 run: docker pull mlsmaycon/warmed-mysql:8
- name: Test - name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m ./... run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
go test -tags devcert -run=^$ -bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./...
api_benchmark: api_benchmark:
name: "Management / Benchmark (API)"
needs: [ build-cache ] needs: [ build-cache ]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
@ -312,12 +443,21 @@ jobs:
- name: download mysql image - name: download mysql image
if: matrix.store == 'mysql' if: matrix.store == 'mysql'
run: docker pull mlsmaycon/warmed-mysql:8 run: docker pull mlsmaycon/warmed-mysql:8
- name: Test - name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -tags=benchmark -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 30m $(go list -tags=benchmark ./... | grep /management) run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
go test -tags=benchmark \
-run=^$ \
-bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./management/...
api_integration_test: api_integration_test:
name: "Management / Integration"
needs: [ build-cache ] needs: [ build-cache ]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
@ -363,9 +503,15 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Test - name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=integration -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 30m $(go list -tags=integration ./... | grep /management) run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
go test -tags=integration \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 10m ./management/...
test_client_on_docker: test_client_on_docker:
name: "Client (Docker) / Unit"
needs: [ build-cache ] needs: [ build-cache ]
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04
steps: steps:

View File

@ -1,4 +1,4 @@
name: Test Code Windows name: "Windows"
on: on:
push: push:
@ -14,6 +14,7 @@ concurrency:
jobs: jobs:
test: test:
name: "Client / Unit"
runs-on: windows-latest runs-on: windows-latest
steps: steps:
- name: Checkout code - name: Checkout code
@ -65,7 +66,7 @@ jobs:
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV - run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV
- name: test - name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1" run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
- name: test output - name: test output
if: ${{ always() }} if: ${{ always() }}
run: Get-Content test-out.txt run: Get-Content test-out.txt

View File

@ -1,4 +1,4 @@
name: golangci-lint name: Lint
on: [pull_request] on: [pull_request]
permissions: permissions:
@ -27,7 +27,14 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
os: [macos-latest, windows-latest, ubuntu-latest] os: [macos-latest, windows-latest, ubuntu-latest]
name: lint include:
- os: macos-latest
display_name: Darwin
- os: windows-latest
display_name: Windows
- os: ubuntu-latest
display_name: Linux
name: ${{ matrix.display_name }}
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
timeout-minutes: 15 timeout-minutes: 15
steps: steps:

View File

@ -1,4 +1,4 @@
name: Mobile build validation name: Mobile
on: on:
push: push:
@ -12,6 +12,7 @@ concurrency:
jobs: jobs:
android_build: android_build:
name: "Android / Build"
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout repository - name: Checkout repository
@ -47,6 +48,7 @@ jobs:
CGO_ENABLED: 0 CGO_ENABLED: 0
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620 ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620
ios_build: ios_build:
name: "iOS / Build"
runs-on: macos-latest runs-on: macos-latest
steps: steps:
- name: Checkout repository - name: Checkout repository

View File

@ -9,10 +9,10 @@ on:
pull_request: pull_request:
env: env:
SIGN_PIPE_VER: "v0.0.17" SIGN_PIPE_VER: "v0.0.18"
GORELEASER_VER: "v2.3.2" GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird" PRODUCT_NAME: "NetBird"
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)" COPYRIGHT: "NetBird GmbH"
concurrency: concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }} group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}

1
.gitignore vendored
View File

@ -29,3 +29,4 @@ infrastructure_files/setup.env
infrastructure_files/setup-*.env infrastructure_files/setup-*.env
.vscode .vscode
.DS_Store .DS_Store
vendor/

View File

@ -103,7 +103,7 @@ linters:
- predeclared # predeclared finds code that shadows one of Go's predeclared identifiers - predeclared # predeclared finds code that shadows one of Go's predeclared identifiers
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint. - revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed - sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
- thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers. # - thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
- wastedassign # wastedassign finds wasted assignment statements - wastedassign # wastedassign finds wasted assignment statements
issues: issues:
# Maximum count of issues with the same text. # Maximum count of issues with the same text.

View File

@ -50,10 +50,12 @@ nfpms:
- netbird-ui - netbird-ui
formats: formats:
- deb - deb
scripts:
postinstall: "release_files/ui-post-install.sh"
contents: contents:
- src: client/ui/netbird.desktop - src: client/ui/netbird.desktop
dst: /usr/share/applications/netbird.desktop dst: /usr/share/applications/netbird.desktop
- src: client/ui/netbird-systemtray-connected.png - src: client/ui/netbird.png
dst: /usr/share/pixmaps/netbird.png dst: /usr/share/pixmaps/netbird.png
dependencies: dependencies:
- netbird - netbird
@ -67,10 +69,12 @@ nfpms:
- netbird-ui - netbird-ui
formats: formats:
- rpm - rpm
scripts:
postinstall: "release_files/ui-post-install.sh"
contents: contents:
- src: client/ui/netbird.desktop - src: client/ui/netbird.desktop
dst: /usr/share/applications/netbird.desktop dst: /usr/share/applications/netbird.desktop
- src: client/ui/netbird-systemtray-connected.png - src: client/ui/netbird.png
dst: /usr/share/pixmaps/netbird.png dst: /usr/share/pixmaps/netbird.png
dependencies: dependencies:
- netbird - netbird

View File

@ -1,3 +1,3 @@
Mikhail Bragin (https://github.com/braginini) Mikhail Bragin (https://github.com/braginini)
Maycon Santos (https://github.com/mlsmaycon) Maycon Santos (https://github.com/mlsmaycon)
Wiretrustee UG (haftungsbeschränkt) NetBird GmbH

View File

@ -3,10 +3,10 @@
We are incredibly thankful for the contributions we receive from the community. We are incredibly thankful for the contributions we receive from the community.
We require our external contributors to sign a Contributor License Agreement ("CLA") in We require our external contributors to sign a Contributor License Agreement ("CLA") in
order to ensure that our projects remain licensed under Free and Open Source licenses such order to ensure that our projects remain licensed under Free and Open Source licenses such
as BSD-3 while allowing Wiretrustee to build a sustainable business. as BSD-3 while allowing NetBird to build a sustainable business.
Wiretrustee is committed to having a true Open Source Software ("OSS") license for NetBird is committed to having a true Open Source Software ("OSS") license for
our software. A CLA enables Wiretrustee to safely commercialize our products our software. A CLA enables NetBird to safely commercialize our products
while keeping a standard OSS license with all the rights that license grants to users: the while keeping a standard OSS license with all the rights that license grants to users: the
ability to use the project in their own projects or businesses, to republish modified ability to use the project in their own projects or businesses, to republish modified
source, or to completely fork the project. source, or to completely fork the project.
@ -20,11 +20,11 @@ This is a human-readable summary of (and not a substitute for) the full agreemen
This highlights only some of key terms of the CLA. It has no legal value and you should This highlights only some of key terms of the CLA. It has no legal value and you should
carefully review all the terms of the actual CLA before agreeing. carefully review all the terms of the actual CLA before agreeing.
<li>Grant of copyright license. You give Wiretrustee permission to use your copyrighted work <li>Grant of copyright license. You give NetBird permission to use your copyrighted work
in commercial products. in commercial products.
</li> </li>
<li>Grant of patent license. If your contributed work uses a patent, you give Wiretrustee a <li>Grant of patent license. If your contributed work uses a patent, you give NetBird a
license to use that patent including within commercial products. You also agree that you license to use that patent including within commercial products. You also agree that you
have permission to grant this license. have permission to grant this license.
</li> </li>
@ -45,7 +45,7 @@ more.
# Why require a CLA? # Why require a CLA?
Agreeing to a CLA explicitly states that you are entitled to provide a contribution, that you cannot withdraw permission Agreeing to a CLA explicitly states that you are entitled to provide a contribution, that you cannot withdraw permission
to use your contribution at a later date, and that Wiretrustee has permission to use your contribution in our commercial to use your contribution at a later date, and that NetBird has permission to use your contribution in our commercial
products. products.
This removes any ambiguities or uncertainties caused by not having a CLA and allows users and customers to confidently This removes any ambiguities or uncertainties caused by not having a CLA and allows users and customers to confidently
@ -65,25 +65,25 @@ Follow the steps given by the bot to sign the CLA. This will require you to log
information from your account) and to fill in a few additional details such as your name and email address. We will only information from your account) and to fill in a few additional details such as your name and email address. We will only
use this information for CLA tracking; none of your submitted information will be used for marketing purposes. use this information for CLA tracking; none of your submitted information will be used for marketing purposes.
You only have to sign the CLA once. Once you've signed the CLA, future contributions to any Wiretrustee project will not You only have to sign the CLA once. Once you've signed the CLA, future contributions to any NetBird project will not
require you to sign again. require you to sign again.
# Legal Terms and Agreement # Legal Terms and Agreement
In order to clarify the intellectual property license granted with Contributions from any person or entity, Wiretrustee In order to clarify the intellectual property license granted with Contributions from any person or entity, NetBird
UG (haftungsbeschränkt) ("Wiretrustee") must have a Contributor License Agreement ("CLA") on file that has been signed GmbH ("NetBird") must have a Contributor License Agreement ("CLA") on file that has been signed
by each Contributor, indicating agreement to the license terms below. This license does not change your rights to use by each Contributor, indicating agreement to the license terms below. This license does not change your rights to use
your own Contributions for any other purpose. your own Contributions for any other purpose.
You accept and agree to the following terms and conditions for Your present and future Contributions submitted to You accept and agree to the following terms and conditions for Your present and future Contributions submitted to
Wiretrustee. Except for the license granted herein to Wiretrustee and recipients of software distributed by Wiretrustee, NetBird. Except for the license granted herein to NetBird and recipients of software distributed by NetBird,
You reserve all right, title, and interest in and to Your Contributions. You reserve all right, title, and interest in and to Your Contributions.
1. Definitions. 1. Definitions.
``` ```
"You" (or "Your") shall mean the copyright owner or legal entity authorized by the copyright owner "You" (or "Your") shall mean the copyright owner or legal entity authorized by the copyright owner
that is making this Agreement with Wiretrustee. For legal entities, the entity making a Contribution and all other that is making this Agreement with NetBird. For legal entities, the entity making a Contribution and all other
entities that control, are controlled by, or are under common control with that entity are considered entities that control, are controlled by, or are under common control with that entity are considered
to be a single Contributor. For the purposes of this definition, "control" means (i) the power, direct or indirect, to be a single Contributor. For the purposes of this definition, "control" means (i) the power, direct or indirect,
to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty
@ -91,23 +91,23 @@ You reserve all right, title, and interest in and to Your Contributions.
``` ```
``` ```
"Contribution" shall mean any original work of authorship, including any modifications or additions to "Contribution" shall mean any original work of authorship, including any modifications or additions to
an existing work, that is or previously has been intentionally submitted by You to Wiretrustee for inclusion in, an existing work, that is or previously has been intentionally submitted by You to NetBird for inclusion in,
or documentation of, any of the products owned or managed by Wiretrustee (the "Work"). or documentation of, any of the products owned or managed by NetBird (the "Work").
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication
sent to Wiretrustee or its representatives, including but not limited to communication on electronic mailing lists, sent to NetBird or its representatives, including but not limited to communication on electronic mailing lists,
source code control systems, and issue tracking systems that are managed by, or on behalf of, source code control systems, and issue tracking systems that are managed by, or on behalf of,
Wiretrustee for the purpose of discussing and improving the Work, but excluding communication that is conspicuously NetBird for the purpose of discussing and improving the Work, but excluding communication that is conspicuously
marked or otherwise designated in writing by You as "Not a Contribution." marked or otherwise designated in writing by You as "Not a Contribution."
``` ```
2. Grant of Copyright License. Subject to the terms and conditions of this Agreement, You hereby grant to Wiretrustee 2. Grant of Copyright License. Subject to the terms and conditions of this Agreement, You hereby grant to NetBird
and to recipients of software distributed by Wiretrustee a perpetual, worldwide, non-exclusive, no-charge, and to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge,
royalty-free, irrevocable copyright license to reproduce, prepare derivative works of, publicly display, publicly royalty-free, irrevocable copyright license to reproduce, prepare derivative works of, publicly display, publicly
perform, sublicense, and distribute Your Contributions and such derivative works. perform, sublicense, and distribute Your Contributions and such derivative works.
3. Grant of Patent License. Subject to the terms and conditions of this Agreement, You hereby grant to Wiretrustee and 3. Grant of Patent License. Subject to the terms and conditions of this Agreement, You hereby grant to NetBird and
to recipients of software distributed by Wiretrustee a perpetual, worldwide, non-exclusive, no-charge, royalty-free, to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import,
and otherwise transfer the Work, where such license applies only to those patent claims licensable by You that are and otherwise transfer the Work, where such license applies only to those patent claims licensable by You that are
necessarily infringed by Your Contribution(s) alone or by combination of Your Contribution(s) with the Work to which necessarily infringed by Your Contribution(s) alone or by combination of Your Contribution(s) with the Work to which
@ -121,8 +121,8 @@ You reserve all right, title, and interest in and to Your Contributions.
intellectual property that you create that includes your Contributions, you represent that you have received intellectual property that you create that includes your Contributions, you represent that you have received
permission to make Contributions on behalf of that employer, that you will have received permission from your current permission to make Contributions on behalf of that employer, that you will have received permission from your current
and future employers for all future Contributions, that your applicable employer has waived such rights for all of and future employers for all future Contributions, that your applicable employer has waived such rights for all of
your current and future Contributions to Wiretrustee, or that your employer has executed a separate Corporate CLA your current and future Contributions to NetBird, or that your employer has executed a separate Corporate CLA
with Wiretrustee. with NetBird.
5. You represent that each of Your Contributions is Your original creation (see section 7 for submissions on behalf of 5. You represent that each of Your Contributions is Your original creation (see section 7 for submissions on behalf of
@ -138,11 +138,11 @@ You reserve all right, title, and interest in and to Your Contributions.
MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE.
7. Should You wish to submit work that is not Your original creation, You may submit it to Wiretrustee separately from 7. Should You wish to submit work that is not Your original creation, You may submit it to NetBird separately from
any Contribution, identifying the complete details of its source and of any license or other restriction (including, any Contribution, identifying the complete details of its source and of any license or other restriction (including,
but not limited to, related patents, trademarks, and license agreements) of which you are personally aware, and but not limited to, related patents, trademarks, and license agreements) of which you are personally aware, and
conspicuously marking the work as "Submitted on behalf of a third-party: [named here]". conspicuously marking the work as "Submitted on behalf of a third-party: [named here]".
8. You agree to notify Wiretrustee of any facts or circumstances of which you become aware that would make these 8. You agree to notify NetBird of any facts or circumstances of which you become aware that would make these
representations inaccurate in any respect. representations inaccurate in any respect.

View File

@ -1,6 +1,6 @@
BSD 3-Clause License BSD 3-Clause License
Copyright (c) 2022 Wiretrustee UG (haftungsbeschränkt) & AUTHORS Copyright (c) 2022 NetBird GmbH & AUTHORS
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
@ -10,4 +10,4 @@ Redistribution and use in source and binary forms, with or without modification,
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -1,4 +1,9 @@
<div align="center"> <div align="center">
<a href="https://netbird.io/webinars/achieve-zero-trust-access-to-k8s?utm_source=github&utm_campaign=2502%20-%20webinar%20-%20How%20to%20Achieve%20Zero%20Trust%20Access%20to%20Kubernetes%20-%20Effortlessly&utm_medium=github">
Webinar: How to Achieve Zero Trust Access to Kubernetes — Effortlessly
</a>
<br/>
<br/>
<p align="center"> <p align="center">
<img width="234" src="docs/media/logo-full.png"/> <img width="234" src="docs/media/logo-full.png"/>
</p> </p>

View File

@ -1,4 +1,4 @@
FROM alpine:3.21.0 FROM alpine:3.21.3
RUN apk add --no-cache ca-certificates iptables ip6tables RUN apk add --no-cache ca-certificates iptables ip6tables
ENV NB_FOREGROUND_MODE=true ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"] ENTRYPOINT [ "/usr/local/bin/netbird","up"]

View File

@ -9,6 +9,7 @@ USER netbird:netbird
ENV NB_FOREGROUND_MODE=true ENV NB_FOREGROUND_MODE=true
ENV NB_USE_NETSTACK_MODE=true ENV NB_USE_NETSTACK_MODE=true
ENV NB_ENABLE_NETSTACK_LOCAL_FORWARDING=true
ENV NB_CONFIG=config.json ENV NB_CONFIG=config.json
ENV NB_DAEMON_ADDR=unix://netbird.sock ENV NB_DAEMON_ADDR=unix://netbird.sock
ENV NB_DISABLE_DNS=true ENV NB_DISABLE_DNS=true

View File

@ -162,7 +162,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
// check if we need to generate JWT token // check if we need to generate JWT token
err := a.withBackOff(a.ctx, func() (err error) { err := a.withBackOff(a.ctx, func() (err error) {
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config.SSHKey) needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
return return
}) })
if err != nil { if err != nil {

View File

@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server" "github.com/netbirdio/netbird/client/server"
nbstatus "github.com/netbirdio/netbird/client/status"
) )
const errCloseConnection = "Failed to close connection: %v" const errCloseConnection = "Failed to close connection: %v"
@ -85,7 +86,7 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{ resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag, Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd), Status: getStatusOutput(cmd, anonymizeFlag),
SystemInfo: debugSystemInfoFlag, SystemInfo: debugSystemInfoFlag,
}) })
if err != nil { if err != nil {
@ -196,7 +197,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
time.Sleep(3 * time.Second) time.Sleep(3 * time.Second)
headerPostUp := fmt.Sprintf("----- Netbird post-up - Timestamp: %s", time.Now().Format(time.RFC3339)) headerPostUp := fmt.Sprintf("----- Netbird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd)) statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd, anonymizeFlag))
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil { if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
return waitErr return waitErr
@ -206,7 +207,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
cmd.Println("Creating debug bundle...") cmd.Println("Creating debug bundle...")
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration) headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd)) statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{ resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag, Anonymize: anonymizeFlag,
@ -271,13 +272,15 @@ func setNetworkMapPersistence(cmd *cobra.Command, args []string) error {
return nil return nil
} }
func getStatusOutput(cmd *cobra.Command) string { func getStatusOutput(cmd *cobra.Command, anon bool) string {
var statusOutputString string var statusOutputString string
statusResp, err := getStatus(cmd.Context()) statusResp, err := getStatus(cmd.Context())
if err != nil { if err != nil {
cmd.PrintErrf("Failed to get status: %v\n", err) cmd.PrintErrf("Failed to get status: %v\n", err)
} else { } else {
statusOutputString = parseToFullDetailSummary(convertToStatusOutputOverview(statusResp)) statusOutputString = nbstatus.ParseToFullDetailSummary(
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil),
)
} }
return statusOutputString return statusOutputString
} }

View File

@ -85,11 +85,17 @@ var loginCmd = &cobra.Command{
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
var dnsLabelsReq []string
if dnsLabelsValidated != nil {
dnsLabelsReq = dnsLabelsValidated.ToSafeStringList()
}
loginRequest := proto.LoginRequest{ loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey, SetupKey: providedSetupKey,
ManagementUrl: managementURL, ManagementUrl: managementURL,
IsLinuxDesktopClient: isLinuxRunningDesktop(), IsLinuxDesktopClient: isLinuxRunningDesktop(),
Hostname: hostName, Hostname: hostName,
DnsLabels: dnsLabelsReq,
} }
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) { if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {

View File

@ -38,6 +38,7 @@ const (
extraIFaceBlackListFlag = "extra-iface-blacklist" extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval" dnsRouteIntervalFlag = "dns-router-interval"
systemInfoFlag = "system-info" systemInfoFlag = "system-info"
blockLANAccessFlag = "block-lan-access"
) )
var ( var (
@ -73,6 +74,7 @@ var (
anonymizeFlag bool anonymizeFlag bool
debugSystemInfoFlag bool debugSystemInfoFlag bool
dnsRouteInterval time.Duration dnsRouteInterval time.Duration
blockLANAccess bool
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
Use: "netbird", Use: "netbird",

View File

@ -9,7 +9,6 @@ import (
"strings" "strings"
"syscall" "syscall"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
@ -73,7 +72,7 @@ var sshCmd = &cobra.Command{
go func() { go func() {
// blocking // blocking
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil { if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
log.Debug(err) cmd.Printf("Error: %v\n", err)
os.Exit(1) os.Exit(1)
} }
cancel() cancel()

View File

@ -2,107 +2,20 @@ package cmd
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"os"
"runtime"
"sort"
"strings" "strings"
"time"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"gopkg.in/yaml.v3"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
) )
type peerStateDetailOutput struct {
FQDN string `json:"fqdn" yaml:"fqdn"`
IP string `json:"netbirdIp" yaml:"netbirdIp"`
PubKey string `json:"publicKey" yaml:"publicKey"`
Status string `json:"status" yaml:"status"`
LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"`
ConnType string `json:"connectionType" yaml:"connectionType"`
IceCandidateType iceCandidateType `json:"iceCandidateType" yaml:"iceCandidateType"`
IceCandidateEndpoint iceCandidateType `json:"iceCandidateEndpoint" yaml:"iceCandidateEndpoint"`
RelayAddress string `json:"relayAddress" yaml:"relayAddress"`
LastWireguardHandshake time.Time `json:"lastWireguardHandshake" yaml:"lastWireguardHandshake"`
TransferReceived int64 `json:"transferReceived" yaml:"transferReceived"`
TransferSent int64 `json:"transferSent" yaml:"transferSent"`
Latency time.Duration `json:"latency" yaml:"latency"`
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
Routes []string `json:"routes" yaml:"routes"`
Networks []string `json:"networks" yaml:"networks"`
}
type peersStateOutput struct {
Total int `json:"total" yaml:"total"`
Connected int `json:"connected" yaml:"connected"`
Details []peerStateDetailOutput `json:"details" yaml:"details"`
}
type signalStateOutput struct {
URL string `json:"url" yaml:"url"`
Connected bool `json:"connected" yaml:"connected"`
Error string `json:"error" yaml:"error"`
}
type managementStateOutput struct {
URL string `json:"url" yaml:"url"`
Connected bool `json:"connected" yaml:"connected"`
Error string `json:"error" yaml:"error"`
}
type relayStateOutputDetail struct {
URI string `json:"uri" yaml:"uri"`
Available bool `json:"available" yaml:"available"`
Error string `json:"error" yaml:"error"`
}
type relayStateOutput struct {
Total int `json:"total" yaml:"total"`
Available int `json:"available" yaml:"available"`
Details []relayStateOutputDetail `json:"details" yaml:"details"`
}
type iceCandidateType struct {
Local string `json:"local" yaml:"local"`
Remote string `json:"remote" yaml:"remote"`
}
type nsServerGroupStateOutput struct {
Servers []string `json:"servers" yaml:"servers"`
Domains []string `json:"domains" yaml:"domains"`
Enabled bool `json:"enabled" yaml:"enabled"`
Error string `json:"error" yaml:"error"`
}
type statusOutputOverview struct {
Peers peersStateOutput `json:"peers" yaml:"peers"`
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
DaemonVersion string `json:"daemonVersion" yaml:"daemonVersion"`
ManagementState managementStateOutput `json:"management" yaml:"management"`
SignalState signalStateOutput `json:"signal" yaml:"signal"`
Relays relayStateOutput `json:"relays" yaml:"relays"`
IP string `json:"netbirdIp" yaml:"netbirdIp"`
PubKey string `json:"publicKey" yaml:"publicKey"`
KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"`
FQDN string `json:"fqdn" yaml:"fqdn"`
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"`
Routes []string `json:"routes" yaml:"routes"`
Networks []string `json:"networks" yaml:"networks"`
NSServerGroups []nsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"`
}
var ( var (
detailFlag bool detailFlag bool
ipv4Flag bool ipv4Flag bool
@ -173,18 +86,17 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil return nil
} }
outputInformationHolder := convertToStatusOutputOverview(resp) var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap)
var statusOutputString string var statusOutputString string
switch { switch {
case detailFlag: case detailFlag:
statusOutputString = parseToFullDetailSummary(outputInformationHolder) statusOutputString = nbstatus.ParseToFullDetailSummary(outputInformationHolder)
case jsonFlag: case jsonFlag:
statusOutputString, err = parseToJSON(outputInformationHolder) statusOutputString, err = nbstatus.ParseToJSON(outputInformationHolder)
case yamlFlag: case yamlFlag:
statusOutputString, err = parseToYAML(outputInformationHolder) statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
default: default:
statusOutputString = parseGeneralSummary(outputInformationHolder, false, false, false) statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false)
} }
if err != nil { if err != nil {
@ -214,7 +126,6 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
} }
func parseFilters() error { func parseFilters() error {
switch strings.ToLower(statusFilter) { switch strings.ToLower(statusFilter) {
case "", "disconnected", "connected": case "", "disconnected", "connected":
if strings.ToLower(statusFilter) != "" { if strings.ToLower(statusFilter) != "" {
@ -251,175 +162,6 @@ func enableDetailFlagWhenFilterFlag() {
} }
} }
func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverview {
pbFullStatus := resp.GetFullStatus()
managementState := pbFullStatus.GetManagementState()
managementOverview := managementStateOutput{
URL: managementState.GetURL(),
Connected: managementState.GetConnected(),
Error: managementState.Error,
}
signalState := pbFullStatus.GetSignalState()
signalOverview := signalStateOutput{
URL: signalState.GetURL(),
Connected: signalState.GetConnected(),
Error: signalState.Error,
}
relayOverview := mapRelays(pbFullStatus.GetRelays())
peersOverview := mapPeers(resp.GetFullStatus().GetPeers())
overview := statusOutputOverview{
Peers: peersOverview,
CliVersion: version.NetbirdVersion(),
DaemonVersion: resp.GetDaemonVersion(),
ManagementState: managementOverview,
SignalState: signalOverview,
Relays: relayOverview,
IP: pbFullStatus.GetLocalPeerState().GetIP(),
PubKey: pbFullStatus.GetLocalPeerState().GetPubKey(),
KernelInterface: pbFullStatus.GetLocalPeerState().GetKernelInterface(),
FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(),
RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(),
Routes: pbFullStatus.GetLocalPeerState().GetNetworks(),
Networks: pbFullStatus.GetLocalPeerState().GetNetworks(),
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
}
if anonymizeFlag {
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
anonymizeOverview(anonymizer, &overview)
}
return overview
}
func mapRelays(relays []*proto.RelayState) relayStateOutput {
var relayStateDetail []relayStateOutputDetail
var relaysAvailable int
for _, relay := range relays {
available := relay.GetAvailable()
relayStateDetail = append(relayStateDetail,
relayStateOutputDetail{
URI: relay.URI,
Available: available,
Error: relay.GetError(),
},
)
if available {
relaysAvailable++
}
}
return relayStateOutput{
Total: len(relays),
Available: relaysAvailable,
Details: relayStateDetail,
}
}
func mapNSGroups(servers []*proto.NSGroupState) []nsServerGroupStateOutput {
mappedNSGroups := make([]nsServerGroupStateOutput, 0, len(servers))
for _, pbNsGroupServer := range servers {
mappedNSGroups = append(mappedNSGroups, nsServerGroupStateOutput{
Servers: pbNsGroupServer.GetServers(),
Domains: pbNsGroupServer.GetDomains(),
Enabled: pbNsGroupServer.GetEnabled(),
Error: pbNsGroupServer.GetError(),
})
}
return mappedNSGroups
}
func mapPeers(peers []*proto.PeerState) peersStateOutput {
var peersStateDetail []peerStateDetailOutput
peersConnected := 0
for _, pbPeerState := range peers {
localICE := ""
remoteICE := ""
localICEEndpoint := ""
remoteICEEndpoint := ""
relayServerAddress := ""
connType := ""
lastHandshake := time.Time{}
transferReceived := int64(0)
transferSent := int64(0)
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
if skipDetailByFilters(pbPeerState, isPeerConnected) {
continue
}
if isPeerConnected {
peersConnected++
localICE = pbPeerState.GetLocalIceCandidateType()
remoteICE = pbPeerState.GetRemoteIceCandidateType()
localICEEndpoint = pbPeerState.GetLocalIceCandidateEndpoint()
remoteICEEndpoint = pbPeerState.GetRemoteIceCandidateEndpoint()
connType = "P2P"
if pbPeerState.Relayed {
connType = "Relayed"
}
relayServerAddress = pbPeerState.GetRelayAddress()
lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local()
transferReceived = pbPeerState.GetBytesRx()
transferSent = pbPeerState.GetBytesTx()
}
timeLocal := pbPeerState.GetConnStatusUpdate().AsTime().Local()
peerState := peerStateDetailOutput{
IP: pbPeerState.GetIP(),
PubKey: pbPeerState.GetPubKey(),
Status: pbPeerState.GetConnStatus(),
LastStatusUpdate: timeLocal,
ConnType: connType,
IceCandidateType: iceCandidateType{
Local: localICE,
Remote: remoteICE,
},
IceCandidateEndpoint: iceCandidateType{
Local: localICEEndpoint,
Remote: remoteICEEndpoint,
},
RelayAddress: relayServerAddress,
FQDN: pbPeerState.GetFqdn(),
LastWireguardHandshake: lastHandshake,
TransferReceived: transferReceived,
TransferSent: transferSent,
Latency: pbPeerState.GetLatency().AsDuration(),
RosenpassEnabled: pbPeerState.GetRosenpassEnabled(),
Routes: pbPeerState.GetNetworks(),
Networks: pbPeerState.GetNetworks(),
}
peersStateDetail = append(peersStateDetail, peerState)
}
sortPeersByIP(peersStateDetail)
peersOverview := peersStateOutput{
Total: len(peersStateDetail),
Connected: peersConnected,
Details: peersStateDetail,
}
return peersOverview
}
func sortPeersByIP(peersStateDetail []peerStateDetailOutput) {
if len(peersStateDetail) > 0 {
sort.SliceStable(peersStateDetail, func(i, j int) bool {
iAddr, _ := netip.ParseAddr(peersStateDetail[i].IP)
jAddr, _ := netip.ParseAddr(peersStateDetail[j].IP)
return iAddr.Compare(jAddr) == -1
})
}
}
func parseInterfaceIP(interfaceIP string) string { func parseInterfaceIP(interfaceIP string) string {
ip, _, err := net.ParseCIDR(interfaceIP) ip, _, err := net.ParseCIDR(interfaceIP)
if err != nil { if err != nil {
@ -427,452 +169,3 @@ func parseInterfaceIP(interfaceIP string) string {
} }
return fmt.Sprintf("%s\n", ip) return fmt.Sprintf("%s\n", ip)
} }
func parseToJSON(overview statusOutputOverview) (string, error) {
jsonBytes, err := json.Marshal(overview)
if err != nil {
return "", fmt.Errorf("json marshal failed")
}
return string(jsonBytes), err
}
func parseToYAML(overview statusOutputOverview) (string, error) {
yamlBytes, err := yaml.Marshal(overview)
if err != nil {
return "", fmt.Errorf("yaml marshal failed")
}
return string(yamlBytes), nil
}
func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays bool, showNameServers bool) string {
var managementConnString string
if overview.ManagementState.Connected {
managementConnString = "Connected"
if showURL {
managementConnString = fmt.Sprintf("%s to %s", managementConnString, overview.ManagementState.URL)
}
} else {
managementConnString = "Disconnected"
if overview.ManagementState.Error != "" {
managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, overview.ManagementState.Error)
}
}
var signalConnString string
if overview.SignalState.Connected {
signalConnString = "Connected"
if showURL {
signalConnString = fmt.Sprintf("%s to %s", signalConnString, overview.SignalState.URL)
}
} else {
signalConnString = "Disconnected"
if overview.SignalState.Error != "" {
signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, overview.SignalState.Error)
}
}
interfaceTypeString := "Userspace"
interfaceIP := overview.IP
if overview.KernelInterface {
interfaceTypeString = "Kernel"
} else if overview.IP == "" {
interfaceTypeString = "N/A"
interfaceIP = "N/A"
}
var relaysString string
if showRelays {
for _, relay := range overview.Relays.Details {
available := "Available"
reason := ""
if !relay.Available {
available = "Unavailable"
reason = fmt.Sprintf(", reason: %s", relay.Error)
}
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
}
} else {
relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
}
networks := "-"
if len(overview.Networks) > 0 {
sort.Strings(overview.Networks)
networks = strings.Join(overview.Networks, ", ")
}
var dnsServersString string
if showNameServers {
for _, nsServerGroup := range overview.NSServerGroups {
enabled := "Available"
if !nsServerGroup.Enabled {
enabled = "Unavailable"
}
errorString := ""
if nsServerGroup.Error != "" {
errorString = fmt.Sprintf(", reason: %s", nsServerGroup.Error)
errorString = strings.TrimSpace(errorString)
}
domainsString := strings.Join(nsServerGroup.Domains, ", ")
if domainsString == "" {
domainsString = "." // Show "." for the default zone
}
dnsServersString += fmt.Sprintf(
"\n [%s] for [%s] is %s%s",
strings.Join(nsServerGroup.Servers, ", "),
domainsString,
enabled,
errorString,
)
}
} else {
dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(overview.NSServerGroups), len(overview.NSServerGroups))
}
rosenpassEnabledStatus := "false"
if overview.RosenpassEnabled {
rosenpassEnabledStatus = "true"
if overview.RosenpassPermissive {
rosenpassEnabledStatus = "true (permissive)" //nolint:gosec
}
}
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
goos := runtime.GOOS
goarch := runtime.GOARCH
goarm := ""
if goarch == "arm" {
goarm = fmt.Sprintf(" (ARMv%s)", os.Getenv("GOARM"))
}
summary := fmt.Sprintf(
"OS: %s\n"+
"Daemon version: %s\n"+
"CLI version: %s\n"+
"Management: %s\n"+
"Signal: %s\n"+
"Relays: %s\n"+
"Nameservers: %s\n"+
"FQDN: %s\n"+
"NetBird IP: %s\n"+
"Interface type: %s\n"+
"Quantum resistance: %s\n"+
"Routes: %s\n"+
"Networks: %s\n"+
"Peers count: %s\n",
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
overview.DaemonVersion,
version.NetbirdVersion(),
managementConnString,
signalConnString,
relaysString,
dnsServersString,
overview.FQDN,
interfaceIP,
interfaceTypeString,
rosenpassEnabledStatus,
networks,
networks,
peersCountString,
)
return summary
}
func parseToFullDetailSummary(overview statusOutputOverview) string {
parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive)
summary := parseGeneralSummary(overview, true, true, true)
return fmt.Sprintf(
"Peers detail:"+
"%s\n"+
"%s",
parsedPeersString,
summary,
)
}
func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bool) string {
var (
peersString = ""
)
for _, peerState := range peers.Details {
localICE := "-"
if peerState.IceCandidateType.Local != "" {
localICE = peerState.IceCandidateType.Local
}
remoteICE := "-"
if peerState.IceCandidateType.Remote != "" {
remoteICE = peerState.IceCandidateType.Remote
}
localICEEndpoint := "-"
if peerState.IceCandidateEndpoint.Local != "" {
localICEEndpoint = peerState.IceCandidateEndpoint.Local
}
remoteICEEndpoint := "-"
if peerState.IceCandidateEndpoint.Remote != "" {
remoteICEEndpoint = peerState.IceCandidateEndpoint.Remote
}
rosenpassEnabledStatus := "false"
if rosenpassEnabled {
if peerState.RosenpassEnabled {
rosenpassEnabledStatus = "true"
} else {
if rosenpassPermissive {
rosenpassEnabledStatus = "false (remote didn't enable quantum resistance)"
} else {
rosenpassEnabledStatus = "false (connection won't work without a permissive mode)"
}
}
} else {
if peerState.RosenpassEnabled {
rosenpassEnabledStatus = "false (connection might not work without a remote permissive mode)"
}
}
networks := "-"
if len(peerState.Networks) > 0 {
sort.Strings(peerState.Networks)
networks = strings.Join(peerState.Networks, ", ")
}
peerString := fmt.Sprintf(
"\n %s:\n"+
" NetBird IP: %s\n"+
" Public key: %s\n"+
" Status: %s\n"+
" -- detail --\n"+
" Connection type: %s\n"+
" ICE candidate (Local/Remote): %s/%s\n"+
" ICE candidate endpoints (Local/Remote): %s/%s\n"+
" Relay server address: %s\n"+
" Last connection update: %s\n"+
" Last WireGuard handshake: %s\n"+
" Transfer status (received/sent) %s/%s\n"+
" Quantum resistance: %s\n"+
" Routes: %s\n"+
" Networks: %s\n"+
" Latency: %s\n",
peerState.FQDN,
peerState.IP,
peerState.PubKey,
peerState.Status,
peerState.ConnType,
localICE,
remoteICE,
localICEEndpoint,
remoteICEEndpoint,
peerState.RelayAddress,
timeAgo(peerState.LastStatusUpdate),
timeAgo(peerState.LastWireguardHandshake),
toIEC(peerState.TransferReceived),
toIEC(peerState.TransferSent),
rosenpassEnabledStatus,
networks,
networks,
peerState.Latency.String(),
)
peersString += peerString
}
return peersString
}
func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
statusEval := false
ipEval := false
nameEval := true
if statusFilter != "" {
lowerStatusFilter := strings.ToLower(statusFilter)
if lowerStatusFilter == "disconnected" && isConnected {
statusEval = true
} else if lowerStatusFilter == "connected" && !isConnected {
statusEval = true
}
}
if len(ipsFilter) > 0 {
_, ok := ipsFilterMap[peerState.IP]
if !ok {
ipEval = true
}
}
if len(prefixNamesFilter) > 0 {
for prefixNameFilter := range prefixNamesFilterMap {
if strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
nameEval = false
break
}
}
} else {
nameEval = false
}
return statusEval || ipEval || nameEval
}
func toIEC(b int64) string {
const unit = 1024
if b < unit {
return fmt.Sprintf("%d B", b)
}
div, exp := int64(unit), 0
for n := b / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %ciB",
float64(b)/float64(div), "KMGTPE"[exp])
}
func countEnabled(dnsServers []nsServerGroupStateOutput) int {
count := 0
for _, server := range dnsServers {
if server.Enabled {
count++
}
}
return count
}
// timeAgo returns a string representing the duration since the provided time in a human-readable format.
func timeAgo(t time.Time) string {
if t.IsZero() || t.Equal(time.Unix(0, 0)) {
return "-"
}
duration := time.Since(t)
switch {
case duration < time.Second:
return "Now"
case duration < time.Minute:
seconds := int(duration.Seconds())
if seconds == 1 {
return "1 second ago"
}
return fmt.Sprintf("%d seconds ago", seconds)
case duration < time.Hour:
minutes := int(duration.Minutes())
seconds := int(duration.Seconds()) % 60
if minutes == 1 {
if seconds == 1 {
return "1 minute, 1 second ago"
} else if seconds > 0 {
return fmt.Sprintf("1 minute, %d seconds ago", seconds)
}
return "1 minute ago"
}
if seconds > 0 {
return fmt.Sprintf("%d minutes, %d seconds ago", minutes, seconds)
}
return fmt.Sprintf("%d minutes ago", minutes)
case duration < 24*time.Hour:
hours := int(duration.Hours())
minutes := int(duration.Minutes()) % 60
if hours == 1 {
if minutes == 1 {
return "1 hour, 1 minute ago"
} else if minutes > 0 {
return fmt.Sprintf("1 hour, %d minutes ago", minutes)
}
return "1 hour ago"
}
if minutes > 0 {
return fmt.Sprintf("%d hours, %d minutes ago", hours, minutes)
}
return fmt.Sprintf("%d hours ago", hours)
}
days := int(duration.Hours()) / 24
hours := int(duration.Hours()) % 24
if days == 1 {
if hours == 1 {
return "1 day, 1 hour ago"
} else if hours > 0 {
return fmt.Sprintf("1 day, %d hours ago", hours)
}
return "1 day ago"
}
if hours > 0 {
return fmt.Sprintf("%d days, %d hours ago", days, hours)
}
return fmt.Sprintf("%d days ago", days)
}
func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
peer.FQDN = a.AnonymizeDomain(peer.FQDN)
if localIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Local); err == nil {
peer.IceCandidateEndpoint.Local = fmt.Sprintf("%s:%s", a.AnonymizeIPString(localIP), port)
}
if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
}
peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress)
for i, route := range peer.Networks {
peer.Networks[i] = a.AnonymizeIPString(route)
}
for i, route := range peer.Networks {
peer.Networks[i] = a.AnonymizeRoute(route)
}
for i, route := range peer.Routes {
peer.Routes[i] = a.AnonymizeIPString(route)
}
for i, route := range peer.Routes {
peer.Routes[i] = a.AnonymizeRoute(route)
}
}
func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview) {
for i, peer := range overview.Peers.Details {
peer := peer
anonymizePeerDetail(a, &peer)
overview.Peers.Details[i] = peer
}
overview.ManagementState.URL = a.AnonymizeURI(overview.ManagementState.URL)
overview.ManagementState.Error = a.AnonymizeString(overview.ManagementState.Error)
overview.SignalState.URL = a.AnonymizeURI(overview.SignalState.URL)
overview.SignalState.Error = a.AnonymizeString(overview.SignalState.Error)
overview.IP = a.AnonymizeIPString(overview.IP)
for i, detail := range overview.Relays.Details {
detail.URI = a.AnonymizeURI(detail.URI)
detail.Error = a.AnonymizeString(detail.Error)
overview.Relays.Details[i] = detail
}
for i, nsGroup := range overview.NSServerGroups {
for j, domain := range nsGroup.Domains {
overview.NSServerGroups[i].Domains[j] = a.AnonymizeDomain(domain)
}
for j, ns := range nsGroup.Servers {
host, port, err := net.SplitHostPort(ns)
if err == nil {
overview.NSServerGroups[i].Servers[j] = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
}
}
}
for i, route := range overview.Networks {
overview.Networks[i] = a.AnonymizeRoute(route)
}
for i, route := range overview.Routes {
overview.Routes[i] = a.AnonymizeRoute(route)
}
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
}

View File

@ -1,597 +1,11 @@
package cmd package cmd
import ( import (
"bytes"
"encoding/json"
"fmt"
"runtime"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/version"
) )
func init() {
loc, err := time.LoadLocation("UTC")
if err != nil {
panic(err)
}
time.Local = loc
}
var resp = &proto.StatusResponse{
Status: "Connected",
FullStatus: &proto.FullStatus{
Peers: []*proto.PeerState{
{
IP: "192.168.178.101",
PubKey: "Pubkey1",
Fqdn: "peer-1.awesome-domain.com",
ConnStatus: "Connected",
ConnStatusUpdate: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 1, 0, time.UTC)),
Relayed: false,
LocalIceCandidateType: "",
RemoteIceCandidateType: "",
LocalIceCandidateEndpoint: "",
RemoteIceCandidateEndpoint: "",
LastWireguardHandshake: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 2, 0, time.UTC)),
BytesRx: 200,
BytesTx: 100,
Networks: []string{
"10.1.0.0/24",
},
Latency: durationpb.New(time.Duration(10000000)),
},
{
IP: "192.168.178.102",
PubKey: "Pubkey2",
Fqdn: "peer-2.awesome-domain.com",
ConnStatus: "Connected",
ConnStatusUpdate: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 2, 0, time.UTC)),
Relayed: true,
LocalIceCandidateType: "relay",
RemoteIceCandidateType: "prflx",
LocalIceCandidateEndpoint: "10.0.0.1:10001",
RemoteIceCandidateEndpoint: "10.0.10.1:10002",
LastWireguardHandshake: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 3, 0, time.UTC)),
BytesRx: 2000,
BytesTx: 1000,
Latency: durationpb.New(time.Duration(10000000)),
},
},
ManagementState: &proto.ManagementState{
URL: "my-awesome-management.com:443",
Connected: true,
Error: "",
},
SignalState: &proto.SignalState{
URL: "my-awesome-signal.com:443",
Connected: true,
Error: "",
},
Relays: []*proto.RelayState{
{
URI: "stun:my-awesome-stun.com:3478",
Available: true,
Error: "",
},
{
URI: "turns:my-awesome-turn.com:443?transport=tcp",
Available: false,
Error: "context: deadline exceeded",
},
},
LocalPeerState: &proto.LocalPeerState{
IP: "192.168.178.100/16",
PubKey: "Some-Pub-Key",
KernelInterface: true,
Fqdn: "some-localhost.awesome-domain.com",
Networks: []string{
"10.10.0.0/24",
},
},
DnsServers: []*proto.NSGroupState{
{
Servers: []string{
"8.8.8.8:53",
},
Domains: nil,
Enabled: true,
Error: "",
},
{
Servers: []string{
"1.1.1.1:53",
"2.2.2.2:53",
},
Domains: []string{
"example.com",
"example.net",
},
Enabled: false,
Error: "timeout",
},
},
},
DaemonVersion: "0.14.1",
}
var overview = statusOutputOverview{
Peers: peersStateOutput{
Total: 2,
Connected: 2,
Details: []peerStateDetailOutput{
{
IP: "192.168.178.101",
PubKey: "Pubkey1",
FQDN: "peer-1.awesome-domain.com",
Status: "Connected",
LastStatusUpdate: time.Date(2001, 1, 1, 1, 1, 1, 0, time.UTC),
ConnType: "P2P",
IceCandidateType: iceCandidateType{
Local: "",
Remote: "",
},
IceCandidateEndpoint: iceCandidateType{
Local: "",
Remote: "",
},
LastWireguardHandshake: time.Date(2001, 1, 1, 1, 1, 2, 0, time.UTC),
TransferReceived: 200,
TransferSent: 100,
Routes: []string{
"10.1.0.0/24",
},
Networks: []string{
"10.1.0.0/24",
},
Latency: time.Duration(10000000),
},
{
IP: "192.168.178.102",
PubKey: "Pubkey2",
FQDN: "peer-2.awesome-domain.com",
Status: "Connected",
LastStatusUpdate: time.Date(2002, 2, 2, 2, 2, 2, 0, time.UTC),
ConnType: "Relayed",
IceCandidateType: iceCandidateType{
Local: "relay",
Remote: "prflx",
},
IceCandidateEndpoint: iceCandidateType{
Local: "10.0.0.1:10001",
Remote: "10.0.10.1:10002",
},
LastWireguardHandshake: time.Date(2002, 2, 2, 2, 2, 3, 0, time.UTC),
TransferReceived: 2000,
TransferSent: 1000,
Latency: time.Duration(10000000),
},
},
},
CliVersion: version.NetbirdVersion(),
DaemonVersion: "0.14.1",
ManagementState: managementStateOutput{
URL: "my-awesome-management.com:443",
Connected: true,
Error: "",
},
SignalState: signalStateOutput{
URL: "my-awesome-signal.com:443",
Connected: true,
Error: "",
},
Relays: relayStateOutput{
Total: 2,
Available: 1,
Details: []relayStateOutputDetail{
{
URI: "stun:my-awesome-stun.com:3478",
Available: true,
Error: "",
},
{
URI: "turns:my-awesome-turn.com:443?transport=tcp",
Available: false,
Error: "context: deadline exceeded",
},
},
},
IP: "192.168.178.100/16",
PubKey: "Some-Pub-Key",
KernelInterface: true,
FQDN: "some-localhost.awesome-domain.com",
NSServerGroups: []nsServerGroupStateOutput{
{
Servers: []string{
"8.8.8.8:53",
},
Domains: nil,
Enabled: true,
Error: "",
},
{
Servers: []string{
"1.1.1.1:53",
"2.2.2.2:53",
},
Domains: []string{
"example.com",
"example.net",
},
Enabled: false,
Error: "timeout",
},
},
Routes: []string{
"10.10.0.0/24",
},
Networks: []string{
"10.10.0.0/24",
},
}
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
convertedResult := convertToStatusOutputOverview(resp)
assert.Equal(t, overview, convertedResult)
}
func TestSortingOfPeers(t *testing.T) {
peers := []peerStateDetailOutput{
{
IP: "192.168.178.104",
},
{
IP: "192.168.178.102",
},
{
IP: "192.168.178.101",
},
{
IP: "192.168.178.105",
},
{
IP: "192.168.178.103",
},
}
sortPeersByIP(peers)
assert.Equal(t, peers[3].IP, "192.168.178.104")
}
func TestParsingToJSON(t *testing.T) {
jsonString, _ := parseToJSON(overview)
//@formatter:off
expectedJSONString := `
{
"peers": {
"total": 2,
"connected": 2,
"details": [
{
"fqdn": "peer-1.awesome-domain.com",
"netbirdIp": "192.168.178.101",
"publicKey": "Pubkey1",
"status": "Connected",
"lastStatusUpdate": "2001-01-01T01:01:01Z",
"connectionType": "P2P",
"iceCandidateType": {
"local": "",
"remote": ""
},
"iceCandidateEndpoint": {
"local": "",
"remote": ""
},
"relayAddress": "",
"lastWireguardHandshake": "2001-01-01T01:01:02Z",
"transferReceived": 200,
"transferSent": 100,
"latency": 10000000,
"quantumResistance": false,
"routes": [
"10.1.0.0/24"
],
"networks": [
"10.1.0.0/24"
]
},
{
"fqdn": "peer-2.awesome-domain.com",
"netbirdIp": "192.168.178.102",
"publicKey": "Pubkey2",
"status": "Connected",
"lastStatusUpdate": "2002-02-02T02:02:02Z",
"connectionType": "Relayed",
"iceCandidateType": {
"local": "relay",
"remote": "prflx"
},
"iceCandidateEndpoint": {
"local": "10.0.0.1:10001",
"remote": "10.0.10.1:10002"
},
"relayAddress": "",
"lastWireguardHandshake": "2002-02-02T02:02:03Z",
"transferReceived": 2000,
"transferSent": 1000,
"latency": 10000000,
"quantumResistance": false,
"routes": null,
"networks": null
}
]
},
"cliVersion": "development",
"daemonVersion": "0.14.1",
"management": {
"url": "my-awesome-management.com:443",
"connected": true,
"error": ""
},
"signal": {
"url": "my-awesome-signal.com:443",
"connected": true,
"error": ""
},
"relays": {
"total": 2,
"available": 1,
"details": [
{
"uri": "stun:my-awesome-stun.com:3478",
"available": true,
"error": ""
},
{
"uri": "turns:my-awesome-turn.com:443?transport=tcp",
"available": false,
"error": "context: deadline exceeded"
}
]
},
"netbirdIp": "192.168.178.100/16",
"publicKey": "Some-Pub-Key",
"usesKernelInterface": true,
"fqdn": "some-localhost.awesome-domain.com",
"quantumResistance": false,
"quantumResistancePermissive": false,
"routes": [
"10.10.0.0/24"
],
"networks": [
"10.10.0.0/24"
],
"dnsServers": [
{
"servers": [
"8.8.8.8:53"
],
"domains": null,
"enabled": true,
"error": ""
},
{
"servers": [
"1.1.1.1:53",
"2.2.2.2:53"
],
"domains": [
"example.com",
"example.net"
],
"enabled": false,
"error": "timeout"
}
]
}`
// @formatter:on
var expectedJSON bytes.Buffer
require.NoError(t, json.Compact(&expectedJSON, []byte(expectedJSONString)))
assert.Equal(t, expectedJSON.String(), jsonString)
}
func TestParsingToYAML(t *testing.T) {
yaml, _ := parseToYAML(overview)
expectedYAML :=
`peers:
total: 2
connected: 2
details:
- fqdn: peer-1.awesome-domain.com
netbirdIp: 192.168.178.101
publicKey: Pubkey1
status: Connected
lastStatusUpdate: 2001-01-01T01:01:01Z
connectionType: P2P
iceCandidateType:
local: ""
remote: ""
iceCandidateEndpoint:
local: ""
remote: ""
relayAddress: ""
lastWireguardHandshake: 2001-01-01T01:01:02Z
transferReceived: 200
transferSent: 100
latency: 10ms
quantumResistance: false
routes:
- 10.1.0.0/24
networks:
- 10.1.0.0/24
- fqdn: peer-2.awesome-domain.com
netbirdIp: 192.168.178.102
publicKey: Pubkey2
status: Connected
lastStatusUpdate: 2002-02-02T02:02:02Z
connectionType: Relayed
iceCandidateType:
local: relay
remote: prflx
iceCandidateEndpoint:
local: 10.0.0.1:10001
remote: 10.0.10.1:10002
relayAddress: ""
lastWireguardHandshake: 2002-02-02T02:02:03Z
transferReceived: 2000
transferSent: 1000
latency: 10ms
quantumResistance: false
routes: []
networks: []
cliVersion: development
daemonVersion: 0.14.1
management:
url: my-awesome-management.com:443
connected: true
error: ""
signal:
url: my-awesome-signal.com:443
connected: true
error: ""
relays:
total: 2
available: 1
details:
- uri: stun:my-awesome-stun.com:3478
available: true
error: ""
- uri: turns:my-awesome-turn.com:443?transport=tcp
available: false
error: 'context: deadline exceeded'
netbirdIp: 192.168.178.100/16
publicKey: Some-Pub-Key
usesKernelInterface: true
fqdn: some-localhost.awesome-domain.com
quantumResistance: false
quantumResistancePermissive: false
routes:
- 10.10.0.0/24
networks:
- 10.10.0.0/24
dnsServers:
- servers:
- 8.8.8.8:53
domains: []
enabled: true
error: ""
- servers:
- 1.1.1.1:53
- 2.2.2.2:53
domains:
- example.com
- example.net
enabled: false
error: timeout
`
assert.Equal(t, expectedYAML, yaml)
}
func TestParsingToDetail(t *testing.T) {
// Calculate time ago based on the fixture dates
lastConnectionUpdate1 := timeAgo(overview.Peers.Details[0].LastStatusUpdate)
lastHandshake1 := timeAgo(overview.Peers.Details[0].LastWireguardHandshake)
lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate)
lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake)
detail := parseToFullDetailSummary(overview)
expectedDetail := fmt.Sprintf(
`Peers detail:
peer-1.awesome-domain.com:
NetBird IP: 192.168.178.101
Public key: Pubkey1
Status: Connected
-- detail --
Connection type: P2P
ICE candidate (Local/Remote): -/-
ICE candidate endpoints (Local/Remote): -/-
Relay server address:
Last connection update: %s
Last WireGuard handshake: %s
Transfer status (received/sent) 200 B/100 B
Quantum resistance: false
Routes: 10.1.0.0/24
Networks: 10.1.0.0/24
Latency: 10ms
peer-2.awesome-domain.com:
NetBird IP: 192.168.178.102
Public key: Pubkey2
Status: Connected
-- detail --
Connection type: Relayed
ICE candidate (Local/Remote): relay/prflx
ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002
Relay server address:
Last connection update: %s
Last WireGuard handshake: %s
Transfer status (received/sent) 2.0 KiB/1000 B
Quantum resistance: false
Routes: -
Networks: -
Latency: 10ms
OS: %s/%s
Daemon version: 0.14.1
CLI version: %s
Management: Connected to my-awesome-management.com:443
Signal: Connected to my-awesome-signal.com:443
Relays:
[stun:my-awesome-stun.com:3478] is Available
[turns:my-awesome-turn.com:443?transport=tcp] is Unavailable, reason: context: deadline exceeded
Nameservers:
[8.8.8.8:53] for [.] is Available
[1.1.1.1:53, 2.2.2.2:53] for [example.com, example.net] is Unavailable, reason: timeout
FQDN: some-localhost.awesome-domain.com
NetBird IP: 192.168.178.100/16
Interface type: Kernel
Quantum resistance: false
Routes: 10.10.0.0/24
Networks: 10.10.0.0/24
Peers count: 2/2 Connected
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
assert.Equal(t, expectedDetail, detail)
}
func TestParsingToShortVersion(t *testing.T) {
shortVersion := parseGeneralSummary(overview, false, false, false)
expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
Daemon version: 0.14.1
CLI version: development
Management: Connected
Signal: Connected
Relays: 1/2 Available
Nameservers: 1/2 Available
FQDN: some-localhost.awesome-domain.com
NetBird IP: 192.168.178.100/16
Interface type: Kernel
Quantum resistance: false
Routes: 10.10.0.0/24
Networks: 10.10.0.0/24
Peers count: 2/2 Connected
`
assert.Equal(t, expectedString, shortVersion)
}
func TestParsingOfIP(t *testing.T) { func TestParsingOfIP(t *testing.T) {
InterfaceIP := "192.168.178.123/16" InterfaceIP := "192.168.178.123/16"
@ -599,31 +13,3 @@ func TestParsingOfIP(t *testing.T) {
assert.Equal(t, "192.168.178.123\n", parsedIP) assert.Equal(t, "192.168.178.123\n", parsedIP)
} }
func TestTimeAgo(t *testing.T) {
now := time.Now()
cases := []struct {
name string
input time.Time
expected string
}{
{"Now", now, "Now"},
{"Seconds ago", now.Add(-10 * time.Second), "10 seconds ago"},
{"One minute ago", now.Add(-1 * time.Minute), "1 minute ago"},
{"Minutes and seconds ago", now.Add(-(1*time.Minute + 30*time.Second)), "1 minute, 30 seconds ago"},
{"One hour ago", now.Add(-1 * time.Hour), "1 hour ago"},
{"Hours and minutes ago", now.Add(-(2*time.Hour + 15*time.Minute)), "2 hours, 15 minutes ago"},
{"One day ago", now.Add(-24 * time.Hour), "1 day ago"},
{"Multiple days ago", now.Add(-(72*time.Hour + 20*time.Minute)), "3 days ago"},
{"Zero time", time.Time{}, "-"},
{"Unix zero time", time.Unix(0, 0), "-"},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
result := timeAgo(tc.input)
assert.Equal(t, tc.expected, result, "Failed %s", tc.name)
})
}
}

View File

@ -95,7 +95,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
} }
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil) mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

137
client/cmd/trace.go Normal file
View File

@ -0,0 +1,137 @@
package cmd
import (
"fmt"
"math/rand"
"strings"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var traceCmd = &cobra.Command{
Use: "trace <direction> <source-ip> <dest-ip>",
Short: "Trace a packet through the firewall",
Example: `
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
Args: cobra.ExactArgs(3),
RunE: tracePacket,
}
func init() {
debugCmd.AddCommand(traceCmd)
traceCmd.Flags().StringP("protocol", "p", "tcp", "Protocol (tcp/udp/icmp)")
traceCmd.Flags().Uint16("sport", 0, "Source port")
traceCmd.Flags().Uint16("dport", 0, "Destination port")
traceCmd.Flags().Uint8("icmp-type", 0, "ICMP type")
traceCmd.Flags().Uint8("icmp-code", 0, "ICMP code")
traceCmd.Flags().Bool("syn", false, "TCP SYN flag")
traceCmd.Flags().Bool("ack", false, "TCP ACK flag")
traceCmd.Flags().Bool("fin", false, "TCP FIN flag")
traceCmd.Flags().Bool("rst", false, "TCP RST flag")
traceCmd.Flags().Bool("psh", false, "TCP PSH flag")
traceCmd.Flags().Bool("urg", false, "TCP URG flag")
}
func tracePacket(cmd *cobra.Command, args []string) error {
direction := strings.ToLower(args[0])
if direction != "in" && direction != "out" {
return fmt.Errorf("invalid direction: use 'in' or 'out'")
}
protocol := cmd.Flag("protocol").Value.String()
if protocol != "tcp" && protocol != "udp" && protocol != "icmp" {
return fmt.Errorf("invalid protocol: use tcp/udp/icmp")
}
sport, err := cmd.Flags().GetUint16("sport")
if err != nil {
return fmt.Errorf("invalid source port: %v", err)
}
dport, err := cmd.Flags().GetUint16("dport")
if err != nil {
return fmt.Errorf("invalid destination port: %v", err)
}
// For TCP/UDP, generate random ephemeral port (49152-65535) if not specified
if protocol != "icmp" {
if sport == 0 {
sport = uint16(rand.Intn(16383) + 49152)
}
if dport == 0 {
dport = uint16(rand.Intn(16383) + 49152)
}
}
var tcpFlags *proto.TCPFlags
if protocol == "tcp" {
syn, _ := cmd.Flags().GetBool("syn")
ack, _ := cmd.Flags().GetBool("ack")
fin, _ := cmd.Flags().GetBool("fin")
rst, _ := cmd.Flags().GetBool("rst")
psh, _ := cmd.Flags().GetBool("psh")
urg, _ := cmd.Flags().GetBool("urg")
tcpFlags = &proto.TCPFlags{
Syn: syn,
Ack: ack,
Fin: fin,
Rst: rst,
Psh: psh,
Urg: urg,
}
}
icmpType, _ := cmd.Flags().GetUint32("icmp-type")
icmpCode, _ := cmd.Flags().GetUint32("icmp-code")
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.TracePacket(cmd.Context(), &proto.TracePacketRequest{
SourceIp: args[1],
DestinationIp: args[2],
Protocol: protocol,
SourcePort: uint32(sport),
DestinationPort: uint32(dport),
Direction: direction,
TcpFlags: tcpFlags,
IcmpType: &icmpType,
IcmpCode: &icmpCode,
})
if err != nil {
return fmt.Errorf("trace failed: %v", status.Convert(err).Message())
}
printTrace(cmd, args[1], args[2], protocol, sport, dport, resp)
return nil
}
func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) {
cmd.Printf("Packet trace %s:%d -> %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
for _, stage := range resp.Stages {
if stage.ForwardingDetails != nil {
cmd.Printf("%s: %s [%s]\n", stage.Name, stage.Message, *stage.ForwardingDetails)
} else {
cmd.Printf("%s: %s\n", stage.Name, stage.Message)
}
}
disposition := map[bool]string{
true: "\033[32mALLOWED\033[0m", // Green
false: "\033[31mDENIED\033[0m", // Red
}[resp.FinalDisposition]
cmd.Printf("\nFinal disposition: %s\n", disposition)
}

View File

@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@ -29,9 +30,16 @@ const (
interfaceInputType interfaceInputType
) )
const (
dnsLabelsFlag = "extra-dns-labels"
)
var ( var (
foregroundMode bool foregroundMode bool
upCmd = &cobra.Command{ dnsLabels []string
dnsLabelsValidated domain.List
upCmd = &cobra.Command{
Use: "up", Use: "up",
Short: "install, login and start Netbird client", Short: "install, login and start Netbird client",
RunE: upFunc, RunE: upFunc,
@ -48,6 +56,15 @@ func init() {
) )
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening") upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval") upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false, "Block access to local networks (LAN) when using this peer as a router or exit node")
upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil,
`Sets DNS labels`+
`You can specify a comma-separated list of up to 32 labels. `+
`An empty string "" clears the previous configuration. `+
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
`or --extra-dns-labels ""`,
)
} }
func upFunc(cmd *cobra.Command, args []string) error { func upFunc(cmd *cobra.Command, args []string) error {
@ -66,6 +83,11 @@ func upFunc(cmd *cobra.Command, args []string) error {
return err return err
} }
dnsLabelsValidated, err = validateDnsLabels(dnsLabels)
if err != nil {
return err
}
ctx := internal.CtxInitState(cmd.Context()) ctx := internal.CtxInitState(cmd.Context())
if hostName != "" { if hostName != "" {
@ -97,6 +119,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
NATExternalIPs: natExternalIPs, NATExternalIPs: natExternalIPs,
CustomDNSAddress: customDNSAddressConverted, CustomDNSAddress: customDNSAddressConverted,
ExtraIFaceBlackList: extraIFaceBlackList, ExtraIFaceBlackList: extraIFaceBlackList,
DNSLabels: dnsLabelsValidated,
} }
if cmd.Flag(enableRosenpassFlag).Changed { if cmd.Flag(enableRosenpassFlag).Changed {
@ -160,6 +183,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
ic.DisableFirewall = &disableFirewall ic.DisableFirewall = &disableFirewall
} }
if cmd.Flag(blockLANAccessFlag).Changed {
ic.BlockLANAccess = &blockLANAccess
}
providedSetupKey, err := getSetupKey() providedSetupKey, err := getSetupKey()
if err != nil { if err != nil {
return err return err
@ -185,7 +212,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
r.GetFullStatus() r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r) connectClient := internal.NewConnectClient(ctx, config, r)
return connectClient.Run() return connectClient.Run(nil)
} }
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error { func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
@ -235,6 +262,8 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
IsLinuxDesktopClient: isLinuxRunningDesktop(), IsLinuxDesktopClient: isLinuxRunningDesktop(),
Hostname: hostName, Hostname: hostName,
ExtraIFaceBlacklist: extraIFaceBlackList, ExtraIFaceBlacklist: extraIFaceBlackList,
DnsLabels: dnsLabels,
CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0,
} }
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) { if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
@ -290,6 +319,10 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
loginRequest.DisableFirewall = &disableFirewall loginRequest.DisableFirewall = &disableFirewall
} }
if cmd.Flag(blockLANAccessFlag).Changed {
loginRequest.BlockLanAccess = &blockLANAccess
}
var loginErr error var loginErr error
var loginResp *proto.LoginResponse var loginResp *proto.LoginResponse
@ -421,6 +454,24 @@ func parseCustomDNSAddress(modified bool) ([]byte, error) {
return parsed, nil return parsed, nil
} }
func validateDnsLabels(labels []string) (domain.List, error) {
var (
domains domain.List
err error
)
if len(labels) == 0 {
return domains, nil
}
domains, err = domain.ValidateDomains(labels)
if err != nil {
return nil, fmt.Errorf("failed to validate dns labels: %v", err)
}
return domains, nil
}
func isValidAddrPort(input string) bool { func isValidAddrPort(input string) bool {
if input == "" { if input == "" {
return true return true

167
client/embed/doc.go Normal file
View File

@ -0,0 +1,167 @@
// Package embed provides a way to embed the NetBird client directly
// into Go programs without requiring a separate NetBird client installation.
package embed
// Basic Usage:
//
// client, err := embed.New(embed.Options{
// DeviceName: "my-service",
// SetupKey: os.Getenv("NB_SETUP_KEY"),
// ManagementURL: os.Getenv("NB_MANAGEMENT_URL"),
// })
// if err != nil {
// log.Fatal(err)
// }
//
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
// defer cancel()
// if err := client.Start(ctx); err != nil {
// log.Fatal(err)
// }
//
// Complete HTTP Server Example:
//
// package main
//
// import (
// "context"
// "fmt"
// "log"
// "net/http"
// "os"
// "os/signal"
// "syscall"
// "time"
//
// netbird "github.com/netbirdio/netbird/client/embed"
// )
//
// func main() {
// // Create client with setup key and device name
// client, err := netbird.New(netbird.Options{
// DeviceName: "http-server",
// SetupKey: os.Getenv("NB_SETUP_KEY"),
// ManagementURL: os.Getenv("NB_MANAGEMENT_URL"),
// LogOutput: io.Discard,
// })
// if err != nil {
// log.Fatal(err)
// }
//
// // Start with timeout
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
// defer cancel()
// if err := client.Start(ctx); err != nil {
// log.Fatal(err)
// }
//
// // Create HTTP server
// mux := http.NewServeMux()
// mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
// fmt.Printf("Request from %s: %s %s\n", r.RemoteAddr, r.Method, r.URL.Path)
// fmt.Fprintf(w, "Hello from netbird!")
// })
//
// // Listen on netbird network
// l, err := client.ListenTCP(":8080")
// if err != nil {
// log.Fatal(err)
// }
//
// server := &http.Server{Handler: mux}
// go func() {
// if err := server.Serve(l); !errors.Is(err, http.ErrServerClosed) {
// log.Printf("HTTP server error: %v", err)
// }
// }()
//
// log.Printf("HTTP server listening on netbird network port 8080")
//
// // Handle shutdown
// stop := make(chan os.Signal, 1)
// signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM)
// <-stop
//
// shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
// defer cancel()
//
// if err := server.Shutdown(shutdownCtx); err != nil {
// log.Printf("HTTP shutdown error: %v", err)
// }
// if err := client.Stop(shutdownCtx); err != nil {
// log.Printf("Netbird shutdown error: %v", err)
// }
// }
//
// Complete HTTP Client Example:
//
// package main
//
// import (
// "context"
// "fmt"
// "io"
// "log"
// "os"
// "time"
//
// netbird "github.com/netbirdio/netbird/client/embed"
// )
//
// func main() {
// // Create client with setup key and device name
// client, err := netbird.New(netbird.Options{
// DeviceName: "http-client",
// SetupKey: os.Getenv("NB_SETUP_KEY"),
// ManagementURL: os.Getenv("NB_MANAGEMENT_URL"),
// LogOutput: io.Discard,
// })
// if err != nil {
// log.Fatal(err)
// }
//
// // Start with timeout
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
// defer cancel()
//
// if err := client.Start(ctx); err != nil {
// log.Fatal(err)
// }
//
// // Create HTTP client that uses netbird network
// httpClient := client.NewHTTPClient()
// httpClient.Timeout = 10 * time.Second
//
// // Make request to server in netbird network
// target := os.Getenv("NB_TARGET")
// resp, err := httpClient.Get(target)
// if err != nil {
// log.Fatal(err)
// }
// defer resp.Body.Close()
//
// // Read and print response
// body, err := io.ReadAll(resp.Body)
// if err != nil {
// log.Fatal(err)
// }
//
// fmt.Printf("Response from server: %s\n", string(body))
//
// // Clean shutdown
// shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
// defer cancel()
//
// if err := client.Stop(shutdownCtx); err != nil {
// log.Printf("Netbird shutdown error: %v", err)
// }
// }
//
// The package provides several methods for network operations:
// - Dial: Creates outbound connections
// - ListenTCP: Creates TCP listeners
// - ListenUDP: Creates UDP listeners
//
// By default, the embed package uses userspace networking mode, which doesn't
// require root/admin privileges. For production deployments, consider setting
// appropriate config and state paths for persistence.

296
client/embed/embed.go Normal file
View File

@ -0,0 +1,296 @@
package embed
import (
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"os"
"sync"
"github.com/sirupsen/logrus"
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/system"
)
var ErrClientAlreadyStarted = errors.New("client already started")
var ErrClientNotStarted = errors.New("client not started")
// Client manages a netbird embedded client instance
type Client struct {
deviceName string
config *internal.Config
mu sync.Mutex
cancel context.CancelFunc
setupKey string
connect *internal.ConnectClient
}
// Options configures a new Client
type Options struct {
// DeviceName is this peer's name in the network
DeviceName string
// SetupKey is used for authentication
SetupKey string
// ManagementURL overrides the default management server URL
ManagementURL string
// PreSharedKey is the pre-shared key for the WireGuard interface
PreSharedKey string
// LogOutput is the output destination for logs (defaults to os.Stderr if nil)
LogOutput io.Writer
// LogLevel sets the logging level (defaults to info if empty)
LogLevel string
// NoUserspace disables the userspace networking mode. Needs admin/root privileges
NoUserspace bool
// ConfigPath is the path to the netbird config file. If empty, the config will be stored in memory and not persisted.
ConfigPath string
// StatePath is the path to the netbird state file
StatePath string
// DisableClientRoutes disables the client routes
DisableClientRoutes bool
}
// New creates a new netbird embedded client
func New(opts Options) (*Client, error) {
if opts.LogOutput != nil {
logrus.SetOutput(opts.LogOutput)
}
if opts.LogLevel != "" {
level, err := logrus.ParseLevel(opts.LogLevel)
if err != nil {
return nil, fmt.Errorf("parse log level: %w", err)
}
logrus.SetLevel(level)
}
if !opts.NoUserspace {
if err := os.Setenv(netstack.EnvUseNetstackMode, "true"); err != nil {
return nil, fmt.Errorf("setenv: %w", err)
}
if err := os.Setenv(netstack.EnvSkipProxy, "true"); err != nil {
return nil, fmt.Errorf("setenv: %w", err)
}
}
if opts.StatePath != "" {
// TODO: Disable state if path not provided
if err := os.Setenv("NB_DNS_STATE_FILE", opts.StatePath); err != nil {
return nil, fmt.Errorf("setenv: %w", err)
}
}
t := true
var config *internal.Config
var err error
input := internal.ConfigInput{
ConfigPath: opts.ConfigPath,
ManagementURL: opts.ManagementURL,
PreSharedKey: &opts.PreSharedKey,
DisableServerRoutes: &t,
DisableClientRoutes: &opts.DisableClientRoutes,
}
if opts.ConfigPath != "" {
config, err = internal.UpdateOrCreateConfig(input)
} else {
config, err = internal.CreateInMemoryConfig(input)
}
if err != nil {
return nil, fmt.Errorf("create config: %w", err)
}
return &Client{
deviceName: opts.DeviceName,
setupKey: opts.SetupKey,
config: config,
}, nil
}
// Start begins client operation and blocks until the engine has been started successfully or a startup error occurs.
// Pass a context with a deadline to limit the time spent waiting for the engine to start.
func (c *Client) Start(startCtx context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.cancel != nil {
return ErrClientAlreadyStarted
}
ctx := internal.CtxInitState(context.Background())
// nolint:staticcheck
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
if err := internal.Login(ctx, c.config, c.setupKey, ""); err != nil {
return fmt.Errorf("login: %w", err)
}
recorder := peer.NewRecorder(c.config.ManagementURL.String())
client := internal.NewConnectClient(ctx, c.config, recorder)
// either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available
run := make(chan error, 1)
go func() {
if err := client.Run(run); err != nil {
run <- err
}
}()
select {
case <-startCtx.Done():
if stopErr := client.Stop(); stopErr != nil {
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
}
return startCtx.Err()
case err := <-run:
if err != nil {
if stopErr := client.Stop(); stopErr != nil {
return fmt.Errorf("stop error after failed to startup. Stop error: %w. Start error: %w", stopErr, err)
}
return fmt.Errorf("startup: %w", err)
}
}
c.connect = client
return nil
}
// Stop gracefully stops the client.
// Pass a context with a deadline to limit the time spent waiting for the engine to stop.
func (c *Client) Stop(ctx context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.connect == nil {
return ErrClientNotStarted
}
done := make(chan error, 1)
go func() {
done <- c.connect.Stop()
}()
select {
case <-ctx.Done():
c.cancel = nil
return ctx.Err()
case err := <-done:
c.cancel = nil
if err != nil {
return fmt.Errorf("stop: %w", err)
}
return nil
}
}
// Dial dials a network address in the netbird network.
// Not applicable if the userspace networking mode is disabled.
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
c.mu.Lock()
connect := c.connect
if connect == nil {
c.mu.Unlock()
return nil, ErrClientNotStarted
}
c.mu.Unlock()
engine := connect.Engine()
if engine == nil {
return nil, errors.New("engine not started")
}
nsnet, err := engine.GetNet()
if err != nil {
return nil, fmt.Errorf("get net: %w", err)
}
return nsnet.DialContext(ctx, network, address)
}
// ListenTCP listens on the given address in the netbird network
// Not applicable if the userspace networking mode is disabled.
func (c *Client) ListenTCP(address string) (net.Listener, error) {
nsnet, addr, err := c.getNet()
if err != nil {
return nil, err
}
_, port, err := net.SplitHostPort(address)
if err != nil {
return nil, fmt.Errorf("split host port: %w", err)
}
listenAddr := fmt.Sprintf("%s:%s", addr, port)
tcpAddr, err := net.ResolveTCPAddr("tcp", listenAddr)
if err != nil {
return nil, fmt.Errorf("resolve: %w", err)
}
return nsnet.ListenTCP(tcpAddr)
}
// ListenUDP listens on the given address in the netbird network
// Not applicable if the userspace networking mode is disabled.
func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
nsnet, addr, err := c.getNet()
if err != nil {
return nil, err
}
_, port, err := net.SplitHostPort(address)
if err != nil {
return nil, fmt.Errorf("split host port: %w", err)
}
listenAddr := fmt.Sprintf("%s:%s", addr, port)
udpAddr, err := net.ResolveUDPAddr("udp", listenAddr)
if err != nil {
return nil, fmt.Errorf("resolve: %w", err)
}
return nsnet.ListenUDP(udpAddr)
}
// NewHTTPClient returns a configured http.Client that uses the netbird network for requests.
// Not applicable if the userspace networking mode is disabled.
func (c *Client) NewHTTPClient() *http.Client {
transport := &http.Transport{
DialContext: c.Dial,
}
return &http.Client{
Transport: transport,
}
}
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
c.mu.Lock()
connect := c.connect
if connect == nil {
c.mu.Unlock()
return nil, netip.Addr{}, errors.New("client not started")
}
c.mu.Unlock()
engine := connect.Engine()
if engine == nil {
return nil, netip.Addr{}, errors.New("engine not started")
}
addr, err := engine.Address()
if err != nil {
return nil, netip.Addr{}, fmt.Errorf("engine address: %w", err)
}
nsnet, err := engine.GetNet()
if err != nil {
return nil, netip.Addr{}, fmt.Errorf("get net: %w", err)
}
return nsnet, addr, nil
}

View File

@ -14,13 +14,13 @@ import (
) )
// NewFirewall creates a firewall manager instance // NewFirewall creates a firewall manager instance
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) { func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
if !iface.IsUserspaceBind() { if !iface.IsUserspaceBind() {
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
} }
// use userspace packet filtering firewall // use userspace packet filtering firewall
fm, err := uspfilter.Create(iface) fm, err := uspfilter.Create(iface, disableServerRoutes)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -33,12 +33,12 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type // FWType is the type for the firewall type
type FWType int type FWType int
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) { func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
// on the linux system we try to user nftables or iptables // on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic // in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers // so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall // for the userspace packet filtering firewall
fm, err := createNativeFirewall(iface, stateManager) fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes)
if !iface.IsUserspaceBind() { if !iface.IsUserspaceBind() {
return fm, err return fm, err
@ -47,10 +47,10 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewal
if err != nil { if err != nil {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err) log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
} }
return createUserspaceFirewall(iface, fm) return createUserspaceFirewall(iface, fm, disableServerRoutes)
} }
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) { func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
fm, err := createFW(iface) fm, err := createFW(iface)
if err != nil { if err != nil {
return nil, fmt.Errorf("create firewall: %s", err) return nil, fmt.Errorf("create firewall: %s", err)
@ -77,12 +77,12 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) {
} }
} }
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) { func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool) (firewall.Manager, error) {
var errUsp error var errUsp error
if fm != nil { if fm != nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm) fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes)
} else { } else {
fm, errUsp = uspfilter.Create(iface) fm, errUsp = uspfilter.Create(iface, disableServerRoutes)
} }
if errUsp != nil { if errUsp != nil {

View File

@ -1,6 +1,8 @@
package firewall package firewall
import ( import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
) )
@ -10,4 +12,6 @@ type IFaceMapper interface {
Address() device.WGAddress Address() device.WGAddress
IsUserspaceBind() bool IsUserspaceBind() bool
SetFilter(device.PacketFilter) error SetFilter(device.PacketFilter) error
GetDevice() *device.FilteredDevice
GetWGDevice() *wgdevice.Device
} }

View File

@ -3,7 +3,7 @@ package iptables
import ( import (
"fmt" "fmt"
"net" "net"
"strconv" "slices"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/google/uuid" "github.com/google/uuid"
@ -19,8 +19,7 @@ const (
tableName = "filter" tableName = "filter"
// rules chains contains the effective ACL rules // rules chains contains the effective ACL rules
chainNameInputRules = "NETBIRD-ACL-INPUT" chainNameInputRules = "NETBIRD-ACL-INPUT"
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
) )
type aclEntries map[string][][]string type aclEntries map[string][][]string
@ -84,28 +83,22 @@ func (m *aclManager) AddPeerFiltering(
protocol firewall.Protocol, protocol firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action, action firewall.Action,
ipsetName string, ipsetName string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
var dPortVal, sPortVal string chain := chainNameInputRules
if dPort != nil && dPort.Values != nil {
// TODO: we support only one port per rule in current implementation of ACLs
dPortVal = strconv.Itoa(dPort.Values[0])
}
if sPort != nil && sPort.Values != nil {
sPortVal = strconv.Itoa(sPort.Values[0])
}
var chain string ipsetName = transformIPsetName(ipsetName, sPort, dPort)
if direction == firewall.RuleDirectionOUT { specs := filterRuleSpecs(ip, string(protocol), sPort, dPort, action, ipsetName)
chain = chainNameOutputRules
} else {
chain = chainNameInputRules
}
ipsetName = transformIPsetName(ipsetName, sPortVal, dPortVal) mangleSpecs := slices.Clone(specs)
specs := filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, direction, action, ipsetName) mangleSpecs = append(mangleSpecs,
"-i", m.wgIface.Name(),
"-m", "addrtype", "--dst-type", "LOCAL",
"-j", "MARK", "--set-xmark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
)
specs = append(specs, "-j", actionToStr(action))
if ipsetName != "" { if ipsetName != "" {
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists { if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
if err := ipset.Add(ipsetName, ip.String()); err != nil { if err := ipset.Add(ipsetName, ip.String()); err != nil {
@ -137,7 +130,7 @@ func (m *aclManager) AddPeerFiltering(
m.ipsetStore.addIpList(ipsetName, ipList) m.ipsetStore.addIpList(ipsetName, ipList)
} }
ok, err := m.iptablesClient.Exists("filter", chain, specs...) ok, err := m.iptablesClient.Exists(tableFilter, chain, specs...)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to check rule: %w", err) return nil, fmt.Errorf("failed to check rule: %w", err)
} }
@ -145,16 +138,22 @@ func (m *aclManager) AddPeerFiltering(
return nil, fmt.Errorf("rule already exists") return nil, fmt.Errorf("rule already exists")
} }
if err := m.iptablesClient.Append("filter", chain, specs...); err != nil { if err := m.iptablesClient.Append(tableFilter, chain, specs...); err != nil {
return nil, err return nil, err
} }
if err := m.iptablesClient.Append(tableMangle, chainRTPRE, mangleSpecs...); err != nil {
log.Errorf("failed to add mangle rule: %v", err)
mangleSpecs = nil
}
rule := &Rule{ rule := &Rule{
ruleID: uuid.New().String(), ruleID: uuid.New().String(),
specs: specs, specs: specs,
ipsetName: ipsetName, mangleSpecs: mangleSpecs,
ip: ip.String(), ipsetName: ipsetName,
chain: chain, ip: ip.String(),
chain: chain,
} }
m.updateState() m.updateState()
@ -197,6 +196,12 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err) return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err)
} }
if r.mangleSpecs != nil {
if err := m.iptablesClient.Delete(tableMangle, chainRTPRE, r.mangleSpecs...); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
m.updateState() m.updateState()
return nil return nil
@ -214,28 +219,7 @@ func (m *aclManager) Reset() error {
// todo write less destructive cleanup mechanism // todo write less destructive cleanup mechanism
func (m *aclManager) cleanChains() error { func (m *aclManager) cleanChains() error {
ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules) ok, err := m.iptablesClient.ChainExists(tableName, chainNameInputRules)
if err != nil {
log.Debugf("failed to list chains: %s", err)
return err
}
if ok {
rules := m.entries["OUTPUT"]
for _, rule := range rules {
err := m.iptablesClient.DeleteIfExists(tableName, "OUTPUT", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameOutputRules)
if err != nil {
log.Debugf("failed to clear and delete %s chain: %s", chainNameOutputRules, err)
return err
}
}
ok, err = m.iptablesClient.ChainExists(tableName, chainNameInputRules)
if err != nil { if err != nil {
log.Debugf("failed to list chains: %s", err) log.Debugf("failed to list chains: %s", err)
return err return err
@ -295,12 +279,6 @@ func (m *aclManager) createDefaultChains() error {
return err return err
} }
// chain netbird-acl-output-rules
if err := m.iptablesClient.NewChain(tableName, chainNameOutputRules); err != nil {
log.Debugf("failed to create '%s' chain: %s", chainNameOutputRules, err)
return err
}
for chainName, rules := range m.entries { for chainName, rules := range m.entries {
for _, rule := range rules { for _, rule := range rules {
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil { if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
@ -329,8 +307,6 @@ func (m *aclManager) createDefaultChains() error {
// The existing FORWARD rules/policies decide outbound traffic towards our interface. // The existing FORWARD rules/policies decide outbound traffic towards our interface.
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule. // In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
// The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule.
func (m *aclManager) seedInitialEntries() { func (m *aclManager) seedInitialEntries() {
established := getConntrackEstablished() established := getConntrackEstablished()
@ -346,17 +322,10 @@ func (m *aclManager) seedInitialEntries() {
func (m *aclManager) seedInitialOptionalEntries() { func (m *aclManager) seedInitialOptionalEntries() {
m.optionalEntries["FORWARD"] = []entry{ m.optionalEntries["FORWARD"] = []entry{
{ {
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", chainNameInputRules}, spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", "ACCEPT"},
position: 2, position: 2,
}, },
} }
m.optionalEntries["PREROUTING"] = []entry{
{
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected)},
position: 1,
},
}
} }
func (m *aclManager) appendToEntries(chainName string, spec []string) { func (m *aclManager) appendToEntries(chainName string, spec []string) {
@ -390,42 +359,26 @@ func (m *aclManager) updateState() {
} }
// filterRuleSpecs returns the specs of a filtering rule // filterRuleSpecs returns the specs of a filtering rule
func filterRuleSpecs( func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string,
) (specs []string) {
matchByIP := true matchByIP := true
// don't use IP matching if IP is ip 0.0.0.0 // don't use IP matching if IP is ip 0.0.0.0
if ip.String() == "0.0.0.0" { if ip.String() == "0.0.0.0" {
matchByIP = false matchByIP = false
} }
switch direction {
case firewall.RuleDirectionIN: if matchByIP {
if matchByIP { if ipsetName != "" {
if ipsetName != "" { specs = append(specs, "-m", "set", "--set", ipsetName, "src")
specs = append(specs, "-m", "set", "--set", ipsetName, "src") } else {
} else { specs = append(specs, "-s", ip.String())
specs = append(specs, "-s", ip.String())
}
}
case firewall.RuleDirectionOUT:
if matchByIP {
if ipsetName != "" {
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
} else {
specs = append(specs, "-d", ip.String())
}
} }
} }
if protocol != "all" { if protocol != "all" {
specs = append(specs, "-p", protocol) specs = append(specs, "-p", protocol)
} }
if sPort != "" { specs = append(specs, applyPort("--sport", sPort)...)
specs = append(specs, "--sport", sPort) specs = append(specs, applyPort("--dport", dPort)...)
} return specs
if dPort != "" {
specs = append(specs, "--dport", dPort)
}
return append(specs, "-j", actionToStr(action))
} }
func actionToStr(action firewall.Action) string { func actionToStr(action firewall.Action) string {
@ -435,15 +388,15 @@ func actionToStr(action firewall.Action) string {
return "DROP" return "DROP"
} }
func transformIPsetName(ipsetName string, sPort, dPort string) string { func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port) string {
switch { switch {
case ipsetName == "": case ipsetName == "":
return "" return ""
case sPort != "" && dPort != "": case sPort != nil && dPort != nil:
return ipsetName + "-sport-dport" return ipsetName + "-sport-dport"
case sPort != "": case sPort != nil:
return ipsetName + "-sport" return ipsetName + "-sport"
case dPort != "": case dPort != nil:
return ipsetName + "-dport" return ipsetName + "-dport"
default: default:
return ipsetName return ipsetName

View File

@ -100,15 +100,14 @@ func (m *Manager) AddPeerFiltering(
protocol firewall.Protocol, protocol firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action, action firewall.Action,
ipsetName string, ipsetName string,
comment string, _ string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName) return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, action, ipsetName)
} }
func (m *Manager) AddRouteFiltering( func (m *Manager) AddRouteFiltering(
@ -201,7 +200,6 @@ func (m *Manager) AllowNetbird() error {
"all", "all",
nil, nil,
nil, nil,
firewall.RuleDirectionIN,
firewall.ActionAccept, firewall.ActionAccept,
"", "",
"", "",
@ -215,6 +213,19 @@ func (m *Manager) AllowNetbird() error {
// Flush doesn't need to be implemented for this manager // Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil } func (m *Manager) Flush() error { return nil }
// SetLogLevel sets the log level for the firewall manager
func (m *Manager) SetLogLevel(log.Level) {
// not supported
}
func (m *Manager) EnableRouting() error {
return nil
}
func (m *Manager) DisableRouting() error {
return nil
}
func getConntrackEstablished() []string { func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"} return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
} }

View File

@ -68,27 +68,14 @@ func TestIptablesManager(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
var rule1 []fw.Rule
t.Run("add first rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule")
for _, r := range rule1 {
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
}
})
var rule2 []fw.Rule var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) { t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3") ip := net.ParseIP("10.20.0.3")
port := &fw.Port{ port := &fw.Port{
Values: []int{8043: 8046}, IsRange: true,
Values: []uint16{8043, 8046},
} }
rule2, err = manager.AddPeerFiltering( rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
for _, r := range rule2 { for _, r := range rule2 {
@ -97,15 +84,6 @@ func TestIptablesManager(t *testing.T) {
} }
}) })
t.Run("delete first rule", func(t *testing.T) {
for _, r := range rule1 {
err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...)
}
})
t.Run("delete second rule", func(t *testing.T) { t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 { for _, r := range rule2 {
err := manager.DeletePeerRule(r) err := manager.DeletePeerRule(r)
@ -118,8 +96,8 @@ func TestIptablesManager(t *testing.T) {
t.Run("reset check", func(t *testing.T) { t.Run("reset check", func(t *testing.T) {
// add second rule // add second rule
ip := net.ParseIP("10.20.0.3") ip := net.ParseIP("10.20.0.3")
port := &fw.Port{Values: []int{5353}} port := &fw.Port{Values: []uint16{5353}}
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic") _, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.ActionAccept, "", "accept Fake DNS traffic")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Reset(nil) err = manager.Reset(nil)
@ -135,9 +113,6 @@ func TestIptablesManager(t *testing.T) {
} }
func TestIptablesManagerIPSet(t *testing.T) { func TestIptablesManagerIPSet(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err)
mock := &iFaceMock{ mock := &iFaceMock{
NameFunc: func() string { NameFunc: func() string {
return "lo" return "lo"
@ -167,33 +142,13 @@ func TestIptablesManagerIPSet(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
var rule1 []fw.Rule
t.Run("add first rule with set", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddPeerFiltering(
ip, "tcp", nil, port, fw.RuleDirectionOUT,
fw.ActionAccept, "default", "accept HTTP traffic",
)
require.NoError(t, err, "failed to add rule")
for _, r := range rule1 {
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
require.Equal(t, r.(*Rule).ipsetName, "default-dport", "ipset name must be set")
require.Equal(t, r.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
}
})
var rule2 []fw.Rule var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) { t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3") ip := net.ParseIP("10.20.0.3")
port := &fw.Port{ port := &fw.Port{
Values: []int{443}, Values: []uint16{443},
} }
rule2, err = manager.AddPeerFiltering( rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "default", "accept HTTPS traffic from ports range")
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
"default", "accept HTTPS traffic from ports range",
)
for _, r := range rule2 { for _, r := range rule2 {
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set") require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
@ -201,15 +156,6 @@ func TestIptablesManagerIPSet(t *testing.T) {
} }
}) })
t.Run("delete first rule", func(t *testing.T) {
for _, r := range rule1 {
err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index")
}
})
t.Run("delete second rule", func(t *testing.T) { t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 { for _, r := range rule2 {
err := manager.DeletePeerRule(r) err := manager.DeletePeerRule(r)
@ -269,12 +215,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
ip := net.ParseIP("10.20.0.100") ip := net.ParseIP("10.20.0.100")
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}} port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
if i%2 == 0 { _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
}
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
} }

View File

@ -135,7 +135,16 @@ func (r *router) AddRouteFiltering(
} }
rule := genRouteFilteringRuleSpec(params) rule := genRouteFilteringRuleSpec(params)
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil { // Insert DROP rules at the beginning, append ACCEPT rules at the end
var err error
if action == firewall.ActionDrop {
// after the established rule
err = r.iptablesClient.Insert(tableFilter, chainRTFWD, 2, rule...)
} else {
err = r.iptablesClient.Append(tableFilter, chainRTFWD, rule...)
}
if err != nil {
return nil, fmt.Errorf("add route rule: %v", err) return nil, fmt.Errorf("add route rule: %v", err)
} }
@ -590,10 +599,10 @@ func applyPort(flag string, port *firewall.Port) []string {
if len(port.Values) > 1 { if len(port.Values) > 1 {
portList := make([]string, len(port.Values)) portList := make([]string, len(port.Values))
for i, p := range port.Values { for i, p := range port.Values {
portList[i] = strconv.Itoa(p) portList[i] = strconv.Itoa(int(p))
} }
return []string{"-m", "multiport", flag, strings.Join(portList, ",")} return []string{"-m", "multiport", flag, strings.Join(portList, ",")}
} }
return []string{flag, strconv.Itoa(port.Values[0])} return []string{flag, strconv.Itoa(int(port.Values[0]))}
} }

View File

@ -239,7 +239,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
destination: netip.MustParsePrefix("10.0.0.0/24"), destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolTCP, proto: firewall.ProtocolTCP,
sPort: nil, sPort: nil,
dPort: &firewall.Port{Values: []int{80}}, dPort: &firewall.Port{Values: []uint16{80}},
direction: firewall.RuleDirectionIN, direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept, action: firewall.ActionAccept,
expectSet: false, expectSet: false,
@ -252,7 +252,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
}, },
destination: netip.MustParsePrefix("10.0.0.0/8"), destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolUDP, proto: firewall.ProtocolUDP,
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true}, sPort: &firewall.Port{Values: []uint16{1024, 2048}, IsRange: true},
dPort: nil, dPort: nil,
direction: firewall.RuleDirectionOUT, direction: firewall.RuleDirectionOUT,
action: firewall.ActionDrop, action: firewall.ActionDrop,
@ -285,7 +285,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")}, sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
destination: netip.MustParsePrefix("192.168.0.0/16"), destination: netip.MustParsePrefix("192.168.0.0/16"),
proto: firewall.ProtocolTCP, proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{80, 443, 8080}}, sPort: &firewall.Port{Values: []uint16{80, 443, 8080}},
dPort: nil, dPort: nil,
direction: firewall.RuleDirectionOUT, direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept, action: firewall.ActionAccept,
@ -297,7 +297,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
destination: netip.MustParsePrefix("10.0.0.0/24"), destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolUDP, proto: firewall.ProtocolUDP,
sPort: nil, sPort: nil,
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true}, dPort: &firewall.Port{Values: []uint16{5000, 5100}, IsRange: true},
direction: firewall.RuleDirectionIN, direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop, action: firewall.ActionDrop,
expectSet: false, expectSet: false,
@ -307,8 +307,8 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}, sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
destination: netip.MustParsePrefix("172.16.0.0/16"), destination: netip.MustParsePrefix("172.16.0.0/16"),
proto: firewall.ProtocolTCP, proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true}, sPort: &firewall.Port{Values: []uint16{1024, 65535}, IsRange: true},
dPort: &firewall.Port{Values: []int{22}}, dPort: &firewall.Port{Values: []uint16{22}},
direction: firewall.RuleDirectionOUT, direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept, action: firewall.ActionAccept,
expectSet: false, expectSet: false,

View File

@ -5,9 +5,10 @@ type Rule struct {
ruleID string ruleID string
ipsetName string ipsetName string
specs []string specs []string
ip string mangleSpecs []string
chain string ip string
chain string
} }
// GetRuleID returns the rule id // GetRuleID returns the rule id

View File

@ -69,7 +69,6 @@ type Manager interface {
proto Protocol, proto Protocol,
sPort *Port, sPort *Port,
dPort *Port, dPort *Port,
direction RuleDirection,
action Action, action Action,
ipsetName string, ipsetName string,
comment string, comment string,
@ -100,6 +99,12 @@ type Manager interface {
// Flush the changes to firewall controller // Flush the changes to firewall controller
Flush() error Flush() error
SetLogLevel(log.Level)
EnableRouting() error
DisableRouting() error
} }
func GenKey(format string, pair RouterPair) string { func GenKey(format string, pair RouterPair) string {

View File

@ -30,7 +30,7 @@ type Port struct {
IsRange bool IsRange bool
// Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports // Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports
Values []int Values []uint16
} }
// String interface implementation // String interface implementation
@ -40,7 +40,11 @@ func (p *Port) String() string {
if ports != "" { if ports != "" {
ports += "," ports += ","
} }
ports += strconv.Itoa(port) ports += strconv.Itoa(int(port))
} }
if p.IsRange {
ports = "range:" + ports
}
return ports return ports
} }

View File

@ -2,9 +2,9 @@ package nftables
import ( import (
"bytes" "bytes"
"encoding/binary"
"fmt" "fmt"
"net" "net"
"slices"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -22,8 +22,7 @@ import (
const ( const (
// rules chains contains the effective ACL rules // rules chains contains the effective ACL rules
chainNameInputRules = "netbird-acl-input-rules" chainNameInputRules = "netbird-acl-input-rules"
chainNameOutputRules = "netbird-acl-output-rules"
// filter chains contains the rules that jump to the rules chains // filter chains contains the rules that jump to the rules chains
chainNameInputFilter = "netbird-acl-input-filter" chainNameInputFilter = "netbird-acl-input-filter"
@ -45,9 +44,9 @@ type AclManager struct {
wgIface iFaceMapper wgIface iFaceMapper
routingFwChainName string routingFwChainName string
workTable *nftables.Table workTable *nftables.Table
chainInputRules *nftables.Chain chainInputRules *nftables.Chain
chainOutputRules *nftables.Chain chainPrerouting *nftables.Chain
ipsetStore *ipsetStore ipsetStore *ipsetStore
rules map[string]*Rule rules map[string]*Rule
@ -89,7 +88,6 @@ func (m *AclManager) AddPeerFiltering(
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action, action firewall.Action,
ipsetName string, ipsetName string,
comment string, comment string,
@ -104,7 +102,7 @@ func (m *AclManager) AddPeerFiltering(
} }
newRules := make([]firewall.Rule, 0, 2) newRules := make([]firewall.Rule, 0, 2)
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, direction, action, ipset, comment) ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset, comment)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -121,23 +119,32 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
} }
if r.nftSet == nil { if r.nftSet == nil {
err := m.rConn.DelRule(r.nftRule) if err := m.rConn.DelRule(r.nftRule); err != nil {
if err != nil {
log.Errorf("failed to delete rule: %v", err) log.Errorf("failed to delete rule: %v", err)
} }
if r.mangleRule != nil {
if err := m.rConn.DelRule(r.mangleRule); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
delete(m.rules, r.GetRuleID()) delete(m.rules, r.GetRuleID())
return m.rConn.Flush() return m.rConn.Flush()
} }
ips, ok := m.ipsetStore.ips(r.nftSet.Name) ips, ok := m.ipsetStore.ips(r.nftSet.Name)
if !ok { if !ok {
err := m.rConn.DelRule(r.nftRule) if err := m.rConn.DelRule(r.nftRule); err != nil {
if err != nil {
log.Errorf("failed to delete rule: %v", err) log.Errorf("failed to delete rule: %v", err)
} }
if r.mangleRule != nil {
if err := m.rConn.DelRule(r.mangleRule); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
delete(m.rules, r.GetRuleID()) delete(m.rules, r.GetRuleID())
return m.rConn.Flush() return m.rConn.Flush()
} }
if _, ok := ips[r.ip.String()]; ok { if _, ok := ips[r.ip.String()]; ok {
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: r.ip.To4()}}) err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: r.ip.To4()}})
if err != nil { if err != nil {
@ -156,12 +163,16 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
return nil return nil
} }
err := m.rConn.DelRule(r.nftRule) if err := m.rConn.DelRule(r.nftRule); err != nil {
if err != nil {
log.Errorf("failed to delete rule: %v", err) log.Errorf("failed to delete rule: %v", err)
} }
err = m.rConn.Flush() if r.mangleRule != nil {
if err != nil { if err := m.rConn.DelRule(r.mangleRule); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
if err := m.rConn.Flush(); err != nil {
return err return err
} }
@ -214,38 +225,6 @@ func (m *AclManager) createDefaultAllowRules() error {
Exprs: expIn, Exprs: expIn,
}) })
expOut := []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
// mask
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: []byte{0, 0, 0, 0},
Xor: []byte{0, 0, 0, 0},
},
// net address
&expr.Cmp{
Register: 1,
Data: []byte{0, 0, 0, 0},
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
_ = m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainOutputRules,
Position: 0,
Exprs: expOut,
})
if err := m.rConn.Flush(); err != nil { if err := m.rConn.Flush(); err != nil {
return fmt.Errorf(flushError, err) return fmt.Errorf(flushError, err)
} }
@ -260,25 +239,33 @@ func (m *AclManager) Flush() error {
return err return err
} }
if err := m.refreshRuleHandles(m.chainInputRules); err != nil { if err := m.refreshRuleHandles(m.chainInputRules, false); err != nil {
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err) log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
} }
if err := m.refreshRuleHandles(m.chainPrerouting, true); err != nil {
if err := m.refreshRuleHandles(m.chainOutputRules); err != nil { log.Errorf("failed to refresh rule handles prerouting chain: %v", err)
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
} }
return nil return nil
} }
func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, ipset *nftables.Set, comment string) (*Rule, error) { func (m *AclManager) addIOFiltering(
ruleId := generatePeerRuleId(ip, sPort, dPort, direction, action, ipset) ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
ipset *nftables.Set,
comment string,
) (*Rule, error) {
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
if r, ok := m.rules[ruleId]; ok { if r, ok := m.rules[ruleId]; ok {
return &Rule{ return &Rule{
r.nftRule, nftRule: r.nftRule,
r.nftSet, mangleRule: r.mangleRule,
r.ruleID, nftSet: r.nftSet,
ip, ruleID: r.ruleID,
ip: ip,
}, nil }, nil
} }
@ -310,9 +297,6 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
if !bytes.HasPrefix(anyIP, rawIP) { if !bytes.HasPrefix(anyIP, rawIP) {
// source address position // source address position
addrOffset := uint32(12) addrOffset := uint32(12)
if direction == firewall.RuleDirectionOUT {
addrOffset += 4 // is ipv4 address length
}
expressions = append(expressions, expressions = append(expressions,
&expr.Payload{ &expr.Payload{
@ -342,73 +326,100 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
} }
} }
if sPort != nil && len(sPort.Values) != 0 { expressions = append(expressions, applyPort(sPort, true)...)
expressions = append(expressions, expressions = append(expressions, applyPort(dPort, false)...)
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 0,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: encodePort(*sPort),
},
)
}
if dPort != nil && len(dPort.Values) != 0 { mainExpressions := slices.Clone(expressions)
expressions = append(expressions,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: encodePort(*dPort),
},
)
}
switch action { switch action {
case firewall.ActionAccept: case firewall.ActionAccept:
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictAccept}) mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictAccept})
case firewall.ActionDrop: case firewall.ActionDrop:
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop}) mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop})
} }
userData := []byte(strings.Join([]string{ruleId, comment}, " ")) userData := []byte(strings.Join([]string{ruleId, comment}, " "))
var chain *nftables.Chain chain := m.chainInputRules
if direction == firewall.RuleDirectionIN {
chain = m.chainInputRules
} else {
chain = m.chainOutputRules
}
nftRule := m.rConn.AddRule(&nftables.Rule{ nftRule := m.rConn.AddRule(&nftables.Rule{
Table: m.workTable, Table: m.workTable,
Chain: chain, Chain: chain,
Exprs: expressions, Exprs: mainExpressions,
UserData: userData, UserData: userData,
}) })
if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf(flushError, err)
}
rule := &Rule{ rule := &Rule{
nftRule: nftRule, nftRule: nftRule,
nftSet: ipset, mangleRule: m.createPreroutingRule(expressions, userData),
ruleID: ruleId, nftSet: ipset,
ip: ip, ruleID: ruleId,
ip: ip,
} }
m.rules[ruleId] = rule m.rules[ruleId] = rule
if ipset != nil { if ipset != nil {
m.ipsetStore.AddReferenceToIpset(ipset.Name) m.ipsetStore.AddReferenceToIpset(ipset.Name)
} }
return rule, nil return rule, nil
} }
func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule {
if m.chainPrerouting == nil {
log.Warn("prerouting chain is not created")
return nil
}
preroutingExprs := slices.Clone(expressions)
// interface
preroutingExprs = append([]expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
}, preroutingExprs...)
// local destination and mark
preroutingExprs = append(preroutingExprs,
&expr.Fib{
Register: 1,
ResultADDRTYPE: true,
FlagDADDR: true,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
SourceRegister: true,
},
)
return m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainPrerouting,
Exprs: preroutingExprs,
UserData: userData,
})
}
func (m *AclManager) createDefaultChains() (err error) { func (m *AclManager) createDefaultChains() (err error) {
// chainNameInputRules // chainNameInputRules
chain := m.createChain(chainNameInputRules) chain := m.createChain(chainNameInputRules)
@ -419,15 +430,6 @@ func (m *AclManager) createDefaultChains() (err error) {
} }
m.chainInputRules = chain m.chainInputRules = chain
// chainNameOutputRules
chain = m.createChain(chainNameOutputRules)
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chainNameOutputRules, err)
return err
}
m.chainOutputRules = chain
// netbird-acl-input-filter // netbird-acl-input-filter
// type filter hook input priority filter; policy accept; // type filter hook input priority filter; policy accept;
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput) chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
@ -461,7 +463,7 @@ func (m *AclManager) createDefaultChains() (err error) {
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the // go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
// netbird peer IP. // netbird peer IP.
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error { func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
preroutingChain := m.rConn.AddChain(&nftables.Chain{ m.chainPrerouting = m.rConn.AddChain(&nftables.Chain{
Name: chainNamePrerouting, Name: chainNamePrerouting,
Table: m.workTable, Table: m.workTable,
Type: nftables.ChainTypeFilter, Type: nftables.ChainTypeFilter,
@ -469,8 +471,6 @@ func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error
Priority: nftables.ChainPriorityMangle, Priority: nftables.ChainPriorityMangle,
}) })
m.addPreroutingRule(preroutingChain)
m.addFwmarkToForward(chainFwFilter) m.addFwmarkToForward(chainFwFilter)
if err := m.rConn.Flush(); err != nil { if err := m.rConn.Flush(); err != nil {
@ -480,43 +480,6 @@ func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error
return nil return nil
} }
func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) {
m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: preroutingChain,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Fib{
Register: 1,
ResultADDRTYPE: true,
FlagDADDR: true,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
SourceRegister: true,
},
},
})
}
func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) { func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
m.rConn.InsertRule(&nftables.Rule{ m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable, Table: m.workTable,
@ -532,8 +495,7 @@ func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected), Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
}, },
&expr.Verdict{ &expr.Verdict{
Kind: expr.VerdictJump, Kind: expr.VerdictAccept,
Chain: m.chainInputRules.Name,
}, },
}, },
}) })
@ -680,6 +642,7 @@ func (m *AclManager) flushWithBackoff() (err error) {
for i := 0; ; i++ { for i := 0; ; i++ {
err = m.rConn.Flush() err = m.rConn.Flush()
if err != nil { if err != nil {
log.Debugf("failed to flush nftables: %v", err)
if !strings.Contains(err.Error(), "busy") { if !strings.Contains(err.Error(), "busy") {
return return
} }
@ -696,7 +659,7 @@ func (m *AclManager) flushWithBackoff() (err error) {
return return
} }
func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error { func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) error {
if m.workTable == nil || chain == nil { if m.workTable == nil || chain == nil {
return nil return nil
} }
@ -713,22 +676,19 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error {
split := bytes.Split(rule.UserData, []byte(" ")) split := bytes.Split(rule.UserData, []byte(" "))
r, ok := m.rules[string(split[0])] r, ok := m.rules[string(split[0])]
if ok { if ok {
*r.nftRule = *rule if mangle {
*r.mangleRule = *rule
} else {
*r.nftRule = *rule
}
} }
} }
return nil return nil
} }
func generatePeerRuleId( func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
ip net.IP, rulesetID := ":"
sPort *firewall.Port,
dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action,
ipset *nftables.Set,
) string {
rulesetID := ":" + strconv.Itoa(int(direction)) + ":"
if sPort != nil { if sPort != nil {
rulesetID += sPort.String() rulesetID += sPort.String()
} }
@ -744,12 +704,6 @@ func generatePeerRuleId(
return "set:" + ipset.Name + rulesetID return "set:" + ipset.Name + rulesetID
} }
func encodePort(port firewall.Port) []byte {
bs := make([]byte, 2)
binary.BigEndian.PutUint16(bs, uint16(port.Values[0]))
return bs
}
func ifname(n string) []byte { func ifname(n string) []byte {
b := make([]byte, 16) b := make([]byte, 16)
copy(b, n+"\x00") copy(b, n+"\x00")

View File

@ -117,7 +117,6 @@ func (m *Manager) AddPeerFiltering(
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action, action firewall.Action,
ipsetName string, ipsetName string,
comment string, comment string,
@ -130,10 +129,17 @@ func (m *Manager) AddPeerFiltering(
return nil, fmt.Errorf("unsupported IP version: %s", ip.String()) return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
} }
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment) return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, action, ipsetName, comment)
} }
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) { func (m *Manager) AddRouteFiltering(
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@ -312,6 +318,19 @@ func (m *Manager) cleanupNetbirdTables() error {
return nil return nil
} }
// SetLogLevel sets the log level for the firewall manager
func (m *Manager) SetLogLevel(log.Level) {
// not supported
}
func (m *Manager) EnableRouting() error {
return nil
}
func (m *Manager) DisableRouting() error {
return nil
}
// Flush rule/chain/set operations from the buffer // Flush rule/chain/set operations from the buffer
// //
// Method also get all rules after flush and refreshes handle values in the rulesets // Method also get all rules after flush and refreshes handle values in the rulesets

View File

@ -74,16 +74,7 @@ func TestNftablesManager(t *testing.T) {
testClient := &nftables.Conn{} testClient := &nftables.Conn{}
rule, err := manager.AddPeerFiltering( rule, err := manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "", "")
ip,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []int{53}},
fw.RuleDirectionIN,
fw.ActionDrop,
"",
"",
)
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Flush() err = manager.Flush()
@ -116,7 +107,7 @@ func TestNftablesManager(t *testing.T) {
Kind: expr.VerdictAccept, Kind: expr.VerdictAccept,
}, },
} }
require.ElementsMatch(t, rules[0].Exprs, expectedExprs1, "expected the same expressions") compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
ipToAdd, _ := netip.AddrFromSlice(ip) ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap() add := ipToAdd.Unmap()
@ -209,12 +200,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
ip := net.ParseIP("10.20.0.100") ip := net.ParseIP("10.20.0.100")
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}} port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
if i%2 == 0 { _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
}
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
if i%100 == 0 { if i%100 == 0 {
@ -296,16 +283,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
}) })
ip := net.ParseIP("100.96.0.1") ip := net.ParseIP("100.96.0.1")
_, err = manager.AddPeerFiltering( _, err = manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "", "test rule")
ip,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []int{80}},
fw.RuleDirectionIN,
fw.ActionAccept,
"",
"test rule",
)
require.NoError(t, err, "failed to add peer filtering rule") require.NoError(t, err, "failed to add peer filtering rule")
_, err = manager.AddRouteFiltering( _, err = manager.AddRouteFiltering(
@ -313,7 +291,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
netip.MustParsePrefix("10.1.0.0/24"), netip.MustParsePrefix("10.1.0.0/24"),
fw.ProtocolTCP, fw.ProtocolTCP,
nil, nil,
&fw.Port{Values: []int{443}}, &fw.Port{Values: []uint16{443}},
fw.ActionAccept, fw.ActionAccept,
) )
require.NoError(t, err, "failed to add route filtering rule") require.NoError(t, err, "failed to add route filtering rule")
@ -329,3 +307,18 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
stdout, stderr = runIptablesSave(t) stdout, stderr = runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr) verifyIptablesOutput(t, stdout, stderr)
} }
func compareExprsIgnoringCounters(t *testing.T, got, want []expr.Any) {
t.Helper()
require.Equal(t, len(got), len(want), "expression count mismatch")
for i := range got {
if _, isCounter := got[i].(*expr.Counter); isCounter {
_, wantIsCounter := want[i].(*expr.Counter)
require.True(t, wantIsCounter, "expected Counter at index %d", i)
continue
}
require.Equal(t, got[i], want[i], "expression mismatch at index %d", i)
}
}

View File

@ -233,7 +233,13 @@ func (r *router) AddRouteFiltering(
UserData: []byte(ruleKey), UserData: []byte(ruleKey),
} }
rule = r.conn.AddRule(rule) // Insert DROP rules at the beginning, append ACCEPT rules at the end
if action == firewall.ActionDrop {
// TODO: Insert after the established rule
rule = r.conn.InsertRule(rule)
} else {
rule = r.conn.AddRule(rule)
}
log.Tracef("Adding route rule %s", spew.Sdump(rule)) log.Tracef("Adding route rule %s", spew.Sdump(rule))
if err := r.conn.Flush(); err != nil { if err := r.conn.Flush(); err != nil {
@ -956,12 +962,12 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
&expr.Cmp{ &expr.Cmp{
Op: expr.CmpOpGte, Op: expr.CmpOpGte,
Register: 1, Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[0])), Data: binaryutil.BigEndian.PutUint16(port.Values[0]),
}, },
&expr.Cmp{ &expr.Cmp{
Op: expr.CmpOpLte, Op: expr.CmpOpLte,
Register: 1, Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[1])), Data: binaryutil.BigEndian.PutUint16(port.Values[1]),
}, },
) )
} else { } else {
@ -980,7 +986,7 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
exprs = append(exprs, &expr.Cmp{ exprs = append(exprs, &expr.Cmp{
Op: expr.CmpOpEq, Op: expr.CmpOpEq,
Register: 1, Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(p)), Data: binaryutil.BigEndian.PutUint16(p),
}) })
} }
} }

View File

@ -222,7 +222,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
destination: netip.MustParsePrefix("10.0.0.0/24"), destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolTCP, proto: firewall.ProtocolTCP,
sPort: nil, sPort: nil,
dPort: &firewall.Port{Values: []int{80}}, dPort: &firewall.Port{Values: []uint16{80}},
direction: firewall.RuleDirectionIN, direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept, action: firewall.ActionAccept,
expectSet: false, expectSet: false,
@ -235,7 +235,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
}, },
destination: netip.MustParsePrefix("10.0.0.0/8"), destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolUDP, proto: firewall.ProtocolUDP,
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true}, sPort: &firewall.Port{Values: []uint16{1024, 2048}, IsRange: true},
dPort: nil, dPort: nil,
direction: firewall.RuleDirectionOUT, direction: firewall.RuleDirectionOUT,
action: firewall.ActionDrop, action: firewall.ActionDrop,
@ -268,7 +268,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")}, sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
destination: netip.MustParsePrefix("192.168.0.0/16"), destination: netip.MustParsePrefix("192.168.0.0/16"),
proto: firewall.ProtocolTCP, proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{80, 443, 8080}}, sPort: &firewall.Port{Values: []uint16{80, 443, 8080}},
dPort: nil, dPort: nil,
direction: firewall.RuleDirectionOUT, direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept, action: firewall.ActionAccept,
@ -280,7 +280,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
destination: netip.MustParsePrefix("10.0.0.0/24"), destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolUDP, proto: firewall.ProtocolUDP,
sPort: nil, sPort: nil,
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true}, dPort: &firewall.Port{Values: []uint16{5000, 5100}, IsRange: true},
direction: firewall.RuleDirectionIN, direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop, action: firewall.ActionDrop,
expectSet: false, expectSet: false,
@ -290,8 +290,8 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}, sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
destination: netip.MustParsePrefix("172.16.0.0/16"), destination: netip.MustParsePrefix("172.16.0.0/16"),
proto: firewall.ProtocolTCP, proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true}, sPort: &firewall.Port{Values: []uint16{1024, 65535}, IsRange: true},
dPort: &firewall.Port{Values: []int{22}}, dPort: &firewall.Port{Values: []uint16{22}},
direction: firewall.RuleDirectionOUT, direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept, action: firewall.ActionAccept,
expectSet: false, expectSet: false,

View File

@ -8,10 +8,11 @@ import (
// Rule to handle management of rules // Rule to handle management of rules
type Rule struct { type Rule struct {
nftRule *nftables.Rule nftRule *nftables.Rule
nftSet *nftables.Set mangleRule *nftables.Rule
ruleID string nftSet *nftables.Set
ip net.IP ruleID string
ip net.IP
} }
// GetRuleID returns the rule id // GetRuleID returns the rule id

View File

@ -3,6 +3,11 @@
package uspfilter package uspfilter
import ( import (
"context"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@ -17,17 +22,29 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
if m.udpTracker != nil { if m.udpTracker != nil {
m.udpTracker.Close() m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
} }
if m.icmpTracker != nil { if m.icmpTracker != nil {
m.icmpTracker.Close() m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
} }
if m.tcpTracker != nil { if m.tcpTracker != nil {
m.tcpTracker.Close() m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
}
if m.forwarder != nil {
m.forwarder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
} }
if m.nativeFirewall != nil { if m.nativeFirewall != nil {

View File

@ -1,9 +1,11 @@
package uspfilter package uspfilter
import ( import (
"context"
"fmt" "fmt"
"os/exec" "os/exec"
"syscall" "syscall"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -29,17 +31,29 @@ func (m *Manager) Reset(*statemanager.Manager) error {
if m.udpTracker != nil { if m.udpTracker != nil {
m.udpTracker.Close() m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
} }
if m.icmpTracker != nil { if m.icmpTracker != nil {
m.icmpTracker.Close() m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
} }
if m.tcpTracker != nil { if m.tcpTracker != nil {
m.tcpTracker.Close() m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
}
if m.forwarder != nil {
m.forwarder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
} }
if !isWindowsFirewallReachable() { if !isWindowsFirewallReachable() {

View File

@ -0,0 +1,16 @@
package common
import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
SetFilter(device.PacketFilter) error
Address() iface.WGAddress
GetWGDevice() *wgdevice.Device
GetDevice() *device.FilteredDevice
}

View File

@ -10,12 +10,11 @@ import (
// BaseConnTrack provides common fields and locking for all connection types // BaseConnTrack provides common fields and locking for all connection types
type BaseConnTrack struct { type BaseConnTrack struct {
SourceIP net.IP SourceIP net.IP
DestIP net.IP DestIP net.IP
SourcePort uint16 SourcePort uint16
DestPort uint16 DestPort uint16
lastSeen atomic.Int64 // Unix nano for atomic access lastSeen atomic.Int64 // Unix nano for atomic access
established atomic.Bool
} }
// these small methods will be inlined by the compiler // these small methods will be inlined by the compiler
@ -25,16 +24,6 @@ func (b *BaseConnTrack) UpdateLastSeen() {
b.lastSeen.Store(time.Now().UnixNano()) b.lastSeen.Store(time.Now().UnixNano())
} }
// IsEstablished safely checks if connection is established
func (b *BaseConnTrack) IsEstablished() bool {
return b.established.Load()
}
// SetEstablished safely sets the established state
func (b *BaseConnTrack) SetEstablished(state bool) {
b.established.Store(state)
}
// GetLastSeen safely gets the last seen timestamp // GetLastSeen safely gets the last seen timestamp
func (b *BaseConnTrack) GetLastSeen() time.Time { func (b *BaseConnTrack) GetLastSeen() time.Time {
return time.Unix(0, b.lastSeen.Load()) return time.Unix(0, b.lastSeen.Load())

View File

@ -3,8 +3,14 @@ package conntrack
import ( import (
"net" "net"
"testing" "testing"
"github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
) )
var logger = log.NewFromLogrus(logrus.StandardLogger())
func BenchmarkIPOperations(b *testing.B) { func BenchmarkIPOperations(b *testing.B) {
b.Run("MakeIPAddr", func(b *testing.B) { b.Run("MakeIPAddr", func(b *testing.B) {
ip := net.ParseIP("192.168.1.1") ip := net.ParseIP("192.168.1.1")
@ -34,37 +40,11 @@ func BenchmarkIPOperations(b *testing.B) {
}) })
} }
func BenchmarkAtomicOperations(b *testing.B) {
conn := &BaseConnTrack{}
b.Run("UpdateLastSeen", func(b *testing.B) {
for i := 0; i < b.N; i++ {
conn.UpdateLastSeen()
}
})
b.Run("IsEstablished", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = conn.IsEstablished()
}
})
b.Run("SetEstablished", func(b *testing.B) {
for i := 0; i < b.N; i++ {
conn.SetEstablished(i%2 == 0)
}
})
b.Run("GetLastSeen", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = conn.GetLastSeen()
}
})
}
// Memory pressure tests // Memory pressure tests
func BenchmarkMemoryPressure(b *testing.B) { func BenchmarkMemoryPressure(b *testing.B) {
b.Run("TCPHighLoad", func(b *testing.B) { b.Run("TCPHighLoad", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout) tracker := NewTCPTracker(DefaultTCPTimeout, logger)
defer tracker.Close() defer tracker.Close()
// Generate different IPs // Generate different IPs
@ -89,7 +69,7 @@ func BenchmarkMemoryPressure(b *testing.B) {
}) })
b.Run("UDPHighLoad", func(b *testing.B) { b.Run("UDPHighLoad", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout) tracker := NewUDPTracker(DefaultUDPTimeout, logger)
defer tracker.Close() defer tracker.Close()
// Generate different IPs // Generate different IPs

View File

@ -6,6 +6,8 @@ import (
"time" "time"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
) )
const ( const (
@ -33,6 +35,7 @@ type ICMPConnTrack struct {
// ICMPTracker manages ICMP connection states // ICMPTracker manages ICMP connection states
type ICMPTracker struct { type ICMPTracker struct {
logger *nblog.Logger
connections map[ICMPConnKey]*ICMPConnTrack connections map[ICMPConnKey]*ICMPConnTrack
timeout time.Duration timeout time.Duration
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
@ -42,12 +45,13 @@ type ICMPTracker struct {
} }
// NewICMPTracker creates a new ICMP connection tracker // NewICMPTracker creates a new ICMP connection tracker
func NewICMPTracker(timeout time.Duration) *ICMPTracker { func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker {
if timeout == 0 { if timeout == 0 {
timeout = DefaultICMPTimeout timeout = DefaultICMPTimeout
} }
tracker := &ICMPTracker{ tracker := &ICMPTracker{
logger: logger,
connections: make(map[ICMPConnKey]*ICMPConnTrack), connections: make(map[ICMPConnKey]*ICMPConnTrack),
timeout: timeout, timeout: timeout,
cleanupTicker: time.NewTicker(ICMPCleanupInterval), cleanupTicker: time.NewTicker(ICMPCleanupInterval),
@ -62,7 +66,6 @@ func NewICMPTracker(timeout time.Duration) *ICMPTracker {
// TrackOutbound records an outbound ICMP Echo Request // TrackOutbound records an outbound ICMP Echo Request
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) { func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
key := makeICMPKey(srcIP, dstIP, id, seq) key := makeICMPKey(srcIP, dstIP, id, seq)
now := time.Now().UnixNano()
t.mutex.Lock() t.mutex.Lock()
conn, exists := t.connections[key] conn, exists := t.connections[key]
@ -80,24 +83,19 @@ func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq u
ID: id, ID: id,
Sequence: seq, Sequence: seq,
} }
conn.lastSeen.Store(now) conn.UpdateLastSeen()
conn.established.Store(true)
t.connections[key] = conn t.connections[key] = conn
t.logger.Trace("New ICMP connection %v", key)
} }
t.mutex.Unlock() t.mutex.Unlock()
conn.lastSeen.Store(now) conn.UpdateLastSeen()
} }
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request // IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool { func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
switch icmpType { if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
case uint8(layers.ICMPv4TypeDestinationUnreachable),
uint8(layers.ICMPv4TypeTimeExceeded):
return true
case uint8(layers.ICMPv4TypeEchoReply):
// continue processing
default:
return false return false
} }
@ -115,8 +113,7 @@ func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq
return false return false
} }
return conn.IsEstablished() && return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) && ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
conn.ID == id && conn.ID == id &&
conn.Sequence == seq conn.Sequence == seq
@ -141,6 +138,8 @@ func (t *ICMPTracker) cleanup() {
t.ipPool.Put(conn.SourceIP) t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP) t.ipPool.Put(conn.DestIP)
delete(t.connections, key) delete(t.connections, key)
t.logger.Debug("Removed ICMP connection %v (timeout)", key)
} }
} }
} }

View File

@ -7,7 +7,7 @@ import (
func BenchmarkICMPTracker(b *testing.B) { func BenchmarkICMPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout) tracker := NewICMPTracker(DefaultICMPTimeout, logger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")
@ -20,7 +20,7 @@ func BenchmarkICMPTracker(b *testing.B) {
}) })
b.Run("IsValidInbound", func(b *testing.B) { b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout) tracker := NewICMPTracker(DefaultICMPTimeout, logger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")

View File

@ -5,7 +5,10 @@ package conntrack
import ( import (
"net" "net"
"sync" "sync"
"sync/atomic"
"time" "time"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
) )
const ( const (
@ -61,12 +64,24 @@ type TCPConnKey struct {
// TCPConnTrack represents a TCP connection state // TCPConnTrack represents a TCP connection state
type TCPConnTrack struct { type TCPConnTrack struct {
BaseConnTrack BaseConnTrack
State TCPState State TCPState
established atomic.Bool
sync.RWMutex sync.RWMutex
} }
// IsEstablished safely checks if connection is established
func (t *TCPConnTrack) IsEstablished() bool {
return t.established.Load()
}
// SetEstablished safely sets the established state
func (t *TCPConnTrack) SetEstablished(state bool) {
t.established.Store(state)
}
// TCPTracker manages TCP connection states // TCPTracker manages TCP connection states
type TCPTracker struct { type TCPTracker struct {
logger *nblog.Logger
connections map[ConnKey]*TCPConnTrack connections map[ConnKey]*TCPConnTrack
mutex sync.RWMutex mutex sync.RWMutex
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
@ -76,8 +91,9 @@ type TCPTracker struct {
} }
// NewTCPTracker creates a new TCP connection tracker // NewTCPTracker creates a new TCP connection tracker
func NewTCPTracker(timeout time.Duration) *TCPTracker { func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker {
tracker := &TCPTracker{ tracker := &TCPTracker{
logger: logger,
connections: make(map[ConnKey]*TCPConnTrack), connections: make(map[ConnKey]*TCPConnTrack),
cleanupTicker: time.NewTicker(TCPCleanupInterval), cleanupTicker: time.NewTicker(TCPCleanupInterval),
done: make(chan struct{}), done: make(chan struct{}),
@ -93,7 +109,6 @@ func NewTCPTracker(timeout time.Duration) *TCPTracker {
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) { func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
// Create key before lock // Create key before lock
key := makeConnKey(srcIP, dstIP, srcPort, dstPort) key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
now := time.Now().UnixNano()
t.mutex.Lock() t.mutex.Lock()
conn, exists := t.connections[key] conn, exists := t.connections[key]
@ -113,9 +128,11 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
}, },
State: TCPStateNew, State: TCPStateNew,
} }
conn.lastSeen.Store(now) conn.UpdateLastSeen()
conn.established.Store(false) conn.established.Store(false)
t.connections[key] = conn t.connections[key] = conn
t.logger.Trace("New TCP connection: %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort)
} }
t.mutex.Unlock() t.mutex.Unlock()
@ -123,7 +140,7 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
conn.Lock() conn.Lock()
t.updateState(conn, flags, true) t.updateState(conn, flags, true)
conn.Unlock() conn.Unlock()
conn.lastSeen.Store(now) conn.UpdateLastSeen()
} }
// IsValidInbound checks if an inbound TCP packet matches a tracked connection // IsValidInbound checks if an inbound TCP packet matches a tracked connection
@ -171,6 +188,9 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
if flags&TCPRst != 0 { if flags&TCPRst != 0 {
conn.State = TCPStateClosed conn.State = TCPStateClosed
conn.SetEstablished(false) conn.SetEstablished(false)
t.logger.Trace("TCP connection reset: %s:%d -> %s:%d",
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
return return
} }
@ -227,6 +247,9 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
if flags&TCPAck != 0 { if flags&TCPAck != 0 {
conn.State = TCPStateTimeWait conn.State = TCPStateTimeWait
// Keep established = false from previous state // Keep established = false from previous state
t.logger.Trace("TCP connection closed (simultaneous) - %s:%d -> %s:%d",
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
} }
case TCPStateCloseWait: case TCPStateCloseWait:
@ -237,11 +260,17 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
case TCPStateLastAck: case TCPStateLastAck:
if flags&TCPAck != 0 { if flags&TCPAck != 0 {
conn.State = TCPStateClosed conn.State = TCPStateClosed
t.logger.Trace("TCP connection gracefully closed: %s:%d -> %s:%d",
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
} }
case TCPStateTimeWait: case TCPStateTimeWait:
// Stay in TIME-WAIT for 2MSL before transitioning to closed // Stay in TIME-WAIT for 2MSL before transitioning to closed
// This is handled by the cleanup routine // This is handled by the cleanup routine
t.logger.Trace("TCP connection completed - %s:%d -> %s:%d",
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
} }
} }
@ -318,6 +347,8 @@ func (t *TCPTracker) cleanup() {
t.ipPool.Put(conn.SourceIP) t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP) t.ipPool.Put(conn.DestIP)
delete(t.connections, key) delete(t.connections, key)
t.logger.Trace("Cleaned up TCP connection: %s:%d -> %s:%d", conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
} }
} }
} }

View File

@ -9,7 +9,7 @@ import (
) )
func TestTCPStateMachine(t *testing.T) { func TestTCPStateMachine(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout) tracker := NewTCPTracker(DefaultTCPTimeout, logger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1") srcIP := net.ParseIP("100.64.0.1")
@ -154,7 +154,7 @@ func TestTCPStateMachine(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Helper() t.Helper()
tracker = NewTCPTracker(DefaultTCPTimeout) tracker = NewTCPTracker(DefaultTCPTimeout, logger)
tt.test(t) tt.test(t)
}) })
} }
@ -162,7 +162,7 @@ func TestTCPStateMachine(t *testing.T) {
} }
func TestRSTHandling(t *testing.T) { func TestRSTHandling(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout) tracker := NewTCPTracker(DefaultTCPTimeout, logger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1") srcIP := net.ParseIP("100.64.0.1")
@ -233,7 +233,7 @@ func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP,
func BenchmarkTCPTracker(b *testing.B) { func BenchmarkTCPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout) tracker := NewTCPTracker(DefaultTCPTimeout, logger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")
@ -246,7 +246,7 @@ func BenchmarkTCPTracker(b *testing.B) {
}) })
b.Run("IsValidInbound", func(b *testing.B) { b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout) tracker := NewTCPTracker(DefaultTCPTimeout, logger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")
@ -264,7 +264,7 @@ func BenchmarkTCPTracker(b *testing.B) {
}) })
b.Run("ConcurrentAccess", func(b *testing.B) { b.Run("ConcurrentAccess", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout) tracker := NewTCPTracker(DefaultTCPTimeout, logger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")
@ -287,7 +287,7 @@ func BenchmarkTCPTracker(b *testing.B) {
// Benchmark connection cleanup // Benchmark connection cleanup
func BenchmarkCleanup(b *testing.B) { func BenchmarkCleanup(b *testing.B) {
b.Run("TCPCleanup", func(b *testing.B) { b.Run("TCPCleanup", func(b *testing.B) {
tracker := NewTCPTracker(100 * time.Millisecond) // Short timeout for testing tracker := NewTCPTracker(100*time.Millisecond, logger) // Short timeout for testing
defer tracker.Close() defer tracker.Close()
// Pre-populate with expired connections // Pre-populate with expired connections

View File

@ -4,6 +4,8 @@ import (
"net" "net"
"sync" "sync"
"time" "time"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
) )
const ( const (
@ -20,6 +22,7 @@ type UDPConnTrack struct {
// UDPTracker manages UDP connection states // UDPTracker manages UDP connection states
type UDPTracker struct { type UDPTracker struct {
logger *nblog.Logger
connections map[ConnKey]*UDPConnTrack connections map[ConnKey]*UDPConnTrack
timeout time.Duration timeout time.Duration
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
@ -29,12 +32,13 @@ type UDPTracker struct {
} }
// NewUDPTracker creates a new UDP connection tracker // NewUDPTracker creates a new UDP connection tracker
func NewUDPTracker(timeout time.Duration) *UDPTracker { func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
if timeout == 0 { if timeout == 0 {
timeout = DefaultUDPTimeout timeout = DefaultUDPTimeout
} }
tracker := &UDPTracker{ tracker := &UDPTracker{
logger: logger,
connections: make(map[ConnKey]*UDPConnTrack), connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout, timeout: timeout,
cleanupTicker: time.NewTicker(UDPCleanupInterval), cleanupTicker: time.NewTicker(UDPCleanupInterval),
@ -49,7 +53,6 @@ func NewUDPTracker(timeout time.Duration) *UDPTracker {
// TrackOutbound records an outbound UDP connection // TrackOutbound records an outbound UDP connection
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) { func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
key := makeConnKey(srcIP, dstIP, srcPort, dstPort) key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
now := time.Now().UnixNano()
t.mutex.Lock() t.mutex.Lock()
conn, exists := t.connections[key] conn, exists := t.connections[key]
@ -67,13 +70,14 @@ func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
DestPort: dstPort, DestPort: dstPort,
}, },
} }
conn.lastSeen.Store(now) conn.UpdateLastSeen()
conn.established.Store(true)
t.connections[key] = conn t.connections[key] = conn
t.logger.Trace("New UDP connection: %v", conn)
} }
t.mutex.Unlock() t.mutex.Unlock()
conn.lastSeen.Store(now) conn.UpdateLastSeen()
} }
// IsValidInbound checks if an inbound packet matches a tracked connection // IsValidInbound checks if an inbound packet matches a tracked connection
@ -92,8 +96,7 @@ func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
return false return false
} }
return conn.IsEstablished() && return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) && ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
conn.DestPort == srcPort && conn.DestPort == srcPort &&
conn.SourcePort == dstPort conn.SourcePort == dstPort
@ -120,6 +123,8 @@ func (t *UDPTracker) cleanup() {
t.ipPool.Put(conn.SourceIP) t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP) t.ipPool.Put(conn.DestIP)
delete(t.connections, key) delete(t.connections, key)
t.logger.Trace("Removed UDP connection %v (timeout)", conn)
} }
} }
} }

View File

@ -29,7 +29,7 @@ func TestNewUDPTracker(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
tracker := NewUDPTracker(tt.timeout) tracker := NewUDPTracker(tt.timeout, logger)
assert.NotNil(t, tracker) assert.NotNil(t, tracker)
assert.Equal(t, tt.wantTimeout, tracker.timeout) assert.Equal(t, tt.wantTimeout, tracker.timeout)
assert.NotNil(t, tracker.connections) assert.NotNil(t, tracker.connections)
@ -40,7 +40,7 @@ func TestNewUDPTracker(t *testing.T) {
} }
func TestUDPTracker_TrackOutbound(t *testing.T) { func TestUDPTracker_TrackOutbound(t *testing.T) {
tracker := NewUDPTracker(DefaultUDPTimeout) tracker := NewUDPTracker(DefaultUDPTimeout, logger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.2") srcIP := net.ParseIP("192.168.1.2")
@ -58,12 +58,11 @@ func TestUDPTracker_TrackOutbound(t *testing.T) {
assert.True(t, conn.DestIP.Equal(dstIP)) assert.True(t, conn.DestIP.Equal(dstIP))
assert.Equal(t, srcPort, conn.SourcePort) assert.Equal(t, srcPort, conn.SourcePort)
assert.Equal(t, dstPort, conn.DestPort) assert.Equal(t, dstPort, conn.DestPort)
assert.True(t, conn.IsEstablished())
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second) assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
} }
func TestUDPTracker_IsValidInbound(t *testing.T) { func TestUDPTracker_IsValidInbound(t *testing.T) {
tracker := NewUDPTracker(1 * time.Second) tracker := NewUDPTracker(1*time.Second, logger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.2") srcIP := net.ParseIP("192.168.1.2")
@ -162,6 +161,7 @@ func TestUDPTracker_Cleanup(t *testing.T) {
cleanupTicker: time.NewTicker(cleanupInterval), cleanupTicker: time.NewTicker(cleanupInterval),
done: make(chan struct{}), done: make(chan struct{}),
ipPool: NewPreallocatedIPs(), ipPool: NewPreallocatedIPs(),
logger: logger,
} }
// Start cleanup routine // Start cleanup routine
@ -211,7 +211,7 @@ func TestUDPTracker_Cleanup(t *testing.T) {
func BenchmarkUDPTracker(b *testing.B) { func BenchmarkUDPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout) tracker := NewUDPTracker(DefaultUDPTimeout, logger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")
@ -224,7 +224,7 @@ func BenchmarkUDPTracker(b *testing.B) {
}) })
b.Run("IsValidInbound", func(b *testing.B) { b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout) tracker := NewUDPTracker(DefaultUDPTimeout, logger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")

View File

@ -0,0 +1,81 @@
package forwarder
import (
wgdevice "golang.zx2c4.com/wireguard/device"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
)
// endpoint implements stack.LinkEndpoint and handles integration with the wireguard device
type endpoint struct {
logger *nblog.Logger
dispatcher stack.NetworkDispatcher
device *wgdevice.Device
mtu uint32
}
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.dispatcher = dispatcher
}
func (e *endpoint) IsAttached() bool {
return e.dispatcher != nil
}
func (e *endpoint) MTU() uint32 {
return e.mtu
}
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
return stack.CapabilityNone
}
func (e *endpoint) MaxHeaderLength() uint16 {
return 0
}
func (e *endpoint) LinkAddress() tcpip.LinkAddress {
return ""
}
func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
var written int
for _, pkt := range pkts.AsSlice() {
netHeader := header.IPv4(pkt.NetworkHeader().View().AsSlice())
data := stack.PayloadSince(pkt.NetworkHeader())
if data == nil {
continue
}
// Send the packet through WireGuard
address := netHeader.DestinationAddress()
err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice())
if err != nil {
e.logger.Error("CreateOutboundPacket: %v", err)
continue
}
written++
}
return written, nil
}
func (e *endpoint) Wait() {
// not required
}
func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareNone
}
func (e *endpoint) AddHeader(*stack.PacketBuffer) {
// not required
}
func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
return true
}

View File

@ -0,0 +1,166 @@
package forwarder
import (
"context"
"fmt"
"net"
"runtime"
log "github.com/sirupsen/logrus"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
)
const (
defaultReceiveWindow = 32768
defaultMaxInFlight = 1024
iosReceiveWindow = 16384
iosMaxInFlight = 256
)
type Forwarder struct {
logger *nblog.Logger
stack *stack.Stack
endpoint *endpoint
udpForwarder *udpForwarder
ctx context.Context
cancel context.CancelFunc
ip net.IP
netstack bool
}
func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwarder, error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{
tcp.NewProtocol,
udp.NewProtocol,
icmp.NewProtocol4,
},
HandleLocal: false,
})
mtu, err := iface.GetDevice().MTU()
if err != nil {
return nil, fmt.Errorf("get MTU: %w", err)
}
nicID := tcpip.NICID(1)
endpoint := &endpoint{
logger: logger,
device: iface.GetWGDevice(),
mtu: uint32(mtu),
}
if err := s.CreateNIC(nicID, endpoint); err != nil {
return nil, fmt.Errorf("failed to create NIC: %v", err)
}
ones, _ := iface.Address().Network.Mask.Size()
protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFromSlice(iface.Address().IP.To4()),
PrefixLen: ones,
},
}
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
return nil, fmt.Errorf("failed to add protocol address: %s", err)
}
defaultSubnet, err := tcpip.NewSubnet(
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
)
if err != nil {
return nil, fmt.Errorf("creating default subnet: %w", err)
}
if err := s.SetPromiscuousMode(nicID, true); err != nil {
return nil, fmt.Errorf("set promiscuous mode: %s", err)
}
if err := s.SetSpoofing(nicID, true); err != nil {
return nil, fmt.Errorf("set spoofing: %s", err)
}
s.SetRouteTable([]tcpip.Route{
{
Destination: defaultSubnet,
NIC: nicID,
},
})
ctx, cancel := context.WithCancel(context.Background())
f := &Forwarder{
logger: logger,
stack: s,
endpoint: endpoint,
udpForwarder: newUDPForwarder(mtu, logger),
ctx: ctx,
cancel: cancel,
netstack: netstack,
ip: iface.Address().IP,
}
receiveWindow := defaultReceiveWindow
maxInFlight := defaultMaxInFlight
if runtime.GOOS == "ios" {
receiveWindow = iosReceiveWindow
maxInFlight = iosMaxInFlight
}
tcpForwarder := tcp.NewForwarder(s, receiveWindow, maxInFlight, f.handleTCP)
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
udpForwarder := udp.NewForwarder(s, f.handleUDP)
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP)
log.Debugf("forwarder: Initialization complete with NIC %d", nicID)
return f, nil
}
func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
if len(payload) < header.IPv4MinimumSize {
return fmt.Errorf("packet too small: %d bytes", len(payload))
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(payload),
})
defer pkt.DecRef()
if f.endpoint.dispatcher != nil {
f.endpoint.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
}
return nil
}
// Stop gracefully shuts down the forwarder
func (f *Forwarder) Stop() {
f.cancel()
if f.udpForwarder != nil {
f.udpForwarder.Stop()
}
f.stack.Close()
f.stack.Wait()
}
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
if f.netstack && f.ip.Equal(addr.AsSlice()) {
return net.IPv4(127, 0, 0, 1)
}
return addr.AsSlice()
}

View File

@ -0,0 +1,109 @@
package forwarder
import (
"context"
"net"
"time"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
// handleICMP handles ICMP packets from the network stack
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
defer cancel()
lc := net.ListenConfig{}
// TODO: support non-root
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
if err != nil {
f.logger.Error("Failed to create ICMP socket for %v: %v", id, err)
// This will make netstack reply on behalf of the original destination, that's ok for now
return false
}
defer func() {
if err := conn.Close(); err != nil {
f.logger.Debug("Failed to close ICMP socket: %v", err)
}
}()
dstIP := f.determineDialAddr(id.LocalAddress)
dst := &net.IPAddr{IP: dstIP}
// Get the complete ICMP message (header + data)
fullPacket := stack.PayloadSince(pkt.TransportHeader())
payload := fullPacket.AsSlice()
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
// For Echo Requests, send and handle response
switch icmpHdr.Type() {
case header.ICMPv4Echo:
return f.handleEchoResponse(icmpHdr, payload, dst, conn, id)
case header.ICMPv4EchoReply:
// dont process our own replies
return true
default:
}
// For other ICMP types (Time Exceeded, Destination Unreachable, etc)
_, err = conn.WriteTo(payload, dst)
if err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
return true
}
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
id, icmpHdr.Type(), icmpHdr.Code())
return true
}
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, dst *net.IPAddr, conn net.PacketConn, id stack.TransportEndpointID) bool {
if _, err := conn.WriteTo(payload, dst); err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
return true
}
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
id, icmpHdr.Type(), icmpHdr.Code())
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
return true
}
response := make([]byte, f.endpoint.mtu)
n, _, err := conn.ReadFrom(response)
if err != nil {
if !isTimeout(err) {
f.logger.Error("Failed to read ICMP response: %v", err)
}
return true
}
ipHdr := make([]byte, header.IPv4MinimumSize)
ip := header.IPv4(ipHdr)
ip.Encode(&header.IPv4Fields{
TotalLength: uint16(header.IPv4MinimumSize + n),
TTL: 64,
Protocol: uint8(header.ICMPv4ProtocolNumber),
SrcAddr: id.LocalAddress,
DstAddr: id.RemoteAddress,
})
ip.SetChecksum(^ip.CalculateChecksum())
fullPacket := make([]byte, 0, len(ipHdr)+n)
fullPacket = append(fullPacket, ipHdr...)
fullPacket = append(fullPacket, response[:n]...)
if err := f.InjectIncomingPacket(fullPacket); err != nil {
f.logger.Error("Failed to inject ICMP response: %v", err)
return true
}
f.logger.Trace("Forwarded ICMP echo reply for %v", id)
return true
}

View File

@ -0,0 +1,90 @@
package forwarder
import (
"context"
"fmt"
"io"
"net"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter"
)
// handleTCP is called by the TCP forwarder for new connections.
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
id := r.ID()
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
if err != nil {
r.Complete(true)
f.logger.Trace("forwarder: dial error for %v: %v", id, err)
return
}
// Create wait queue for blocking syscalls
wq := waiter.Queue{}
ep, epErr := r.CreateEndpoint(&wq)
if epErr != nil {
f.logger.Error("forwarder: failed to create TCP endpoint: %v", epErr)
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: outConn close error: %v", err)
}
r.Complete(true)
return
}
// Complete the handshake
r.Complete(false)
inConn := gonet.NewTCPConn(&wq, ep)
f.logger.Trace("forwarder: established TCP connection %v", id)
go f.proxyTCP(id, inConn, outConn, ep)
}
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint) {
defer func() {
if err := inConn.Close(); err != nil {
f.logger.Debug("forwarder: inConn close error: %v", err)
}
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: outConn close error: %v", err)
}
ep.Close()
}()
// Create context for managing the proxy goroutines
ctx, cancel := context.WithCancel(f.ctx)
defer cancel()
errChan := make(chan error, 2)
go func() {
_, err := io.Copy(outConn, inConn)
errChan <- err
}()
go func() {
_, err := io.Copy(inConn, outConn)
errChan <- err
}()
select {
case <-ctx.Done():
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", id)
return
case err := <-errChan:
if err != nil && !isClosedError(err) {
f.logger.Error("proxyTCP: copy error: %v", err)
}
f.logger.Trace("forwarder: tearing down TCP connection %v", id)
return
}
}

View File

@ -0,0 +1,284 @@
package forwarder
import (
"context"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
)
const (
udpTimeout = 30 * time.Second
)
type udpPacketConn struct {
conn *gonet.UDPConn
outConn net.Conn
lastSeen atomic.Int64
cancel context.CancelFunc
ep tcpip.Endpoint
}
type udpForwarder struct {
sync.RWMutex
logger *nblog.Logger
conns map[stack.TransportEndpointID]*udpPacketConn
bufPool sync.Pool
ctx context.Context
cancel context.CancelFunc
}
type idleConn struct {
id stack.TransportEndpointID
conn *udpPacketConn
}
func newUDPForwarder(mtu int, logger *nblog.Logger) *udpForwarder {
ctx, cancel := context.WithCancel(context.Background())
f := &udpForwarder{
logger: logger,
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
ctx: ctx,
cancel: cancel,
bufPool: sync.Pool{
New: func() any {
b := make([]byte, mtu)
return &b
},
},
}
go f.cleanup()
return f
}
// Stop stops the UDP forwarder and all active connections
func (f *udpForwarder) Stop() {
f.cancel()
f.Lock()
defer f.Unlock()
for id, conn := range f.conns {
conn.cancel()
if err := conn.conn.Close(); err != nil {
f.logger.Debug("forwarder: UDP conn close error for %v: %v", id, err)
}
if err := conn.outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
}
conn.ep.Close()
delete(f.conns, id)
}
}
// cleanup periodically removes idle UDP connections
func (f *udpForwarder) cleanup() {
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for {
select {
case <-f.ctx.Done():
return
case <-ticker.C:
var idleConns []idleConn
f.RLock()
for id, conn := range f.conns {
if conn.getIdleDuration() > udpTimeout {
idleConns = append(idleConns, idleConn{id, conn})
}
}
f.RUnlock()
for _, idle := range idleConns {
idle.conn.cancel()
if err := idle.conn.conn.Close(); err != nil {
f.logger.Debug("forwarder: UDP conn close error for %v: %v", idle.id, err)
}
if err := idle.conn.outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", idle.id, err)
}
idle.conn.ep.Close()
f.Lock()
delete(f.conns, idle.id)
f.Unlock()
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", idle.id)
}
}
}
}
// handleUDP is called by the UDP forwarder for new packets
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
if f.ctx.Err() != nil {
f.logger.Trace("forwarder: context done, dropping UDP packet")
return
}
id := r.ID()
f.udpForwarder.RLock()
_, exists := f.udpForwarder.conns[id]
f.udpForwarder.RUnlock()
if exists {
f.logger.Trace("forwarder: existing UDP connection for %v", id)
return
}
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
if err != nil {
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)
// TODO: Send ICMP error message
return
}
// Create wait queue for blocking syscalls
wq := waiter.Queue{}
ep, epErr := r.CreateEndpoint(&wq)
if epErr != nil {
f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr)
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
}
return
}
inConn := gonet.NewUDPConn(f.stack, &wq, ep)
connCtx, connCancel := context.WithCancel(f.ctx)
pConn := &udpPacketConn{
conn: inConn,
outConn: outConn,
cancel: connCancel,
ep: ep,
}
pConn.updateLastSeen()
f.udpForwarder.Lock()
// Double-check no connection was created while we were setting up
if _, exists := f.udpForwarder.conns[id]; exists {
f.udpForwarder.Unlock()
pConn.cancel()
if err := inConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
}
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
}
return
}
f.udpForwarder.conns[id] = pConn
f.udpForwarder.Unlock()
f.logger.Trace("forwarder: established UDP connection to %v", id)
go f.proxyUDP(connCtx, pConn, id, ep)
}
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
defer func() {
pConn.cancel()
if err := pConn.conn.Close(); err != nil {
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
}
if err := pConn.outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
}
ep.Close()
f.udpForwarder.Lock()
delete(f.udpForwarder.conns, id)
f.udpForwarder.Unlock()
}()
errChan := make(chan error, 2)
go func() {
errChan <- pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
}()
go func() {
errChan <- pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
}()
select {
case <-ctx.Done():
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", id)
return
case err := <-errChan:
if err != nil && !isClosedError(err) {
f.logger.Error("proxyUDP: copy error: %v", err)
}
f.logger.Trace("forwarder: tearing down UDP connection %v", id)
return
}
}
func (c *udpPacketConn) updateLastSeen() {
c.lastSeen.Store(time.Now().UnixNano())
}
func (c *udpPacketConn) getIdleDuration() time.Duration {
lastSeen := time.Unix(0, c.lastSeen.Load())
return time.Since(lastSeen)
}
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) error {
bufp := bufPool.Get().(*[]byte)
defer bufPool.Put(bufp)
buffer := *bufp
for {
if ctx.Err() != nil {
return ctx.Err()
}
if err := src.SetDeadline(time.Now().Add(udpTimeout)); err != nil {
return fmt.Errorf("set read deadline: %w", err)
}
n, err := src.Read(buffer)
if err != nil {
if isTimeout(err) {
continue
}
return fmt.Errorf("read from %s: %w", direction, err)
}
_, err = dst.Write(buffer[:n])
if err != nil {
return fmt.Errorf("write to %s: %w", direction, err)
}
c.updateLastSeen()
}
}
func isClosedError(err error) bool {
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled)
}
func isTimeout(err error) bool {
var netErr net.Error
if errors.As(err, &netErr) {
return netErr.Timeout()
}
return false
}

View File

@ -0,0 +1,134 @@
package uspfilter
import (
"fmt"
"net"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
)
type localIPManager struct {
mu sync.RWMutex
// Use bitmap for IPv4 (32 bits * 2^16 = 256KB memory)
ipv4Bitmap [1 << 16]uint32
}
func newLocalIPManager() *localIPManager {
return &localIPManager{}
}
func (m *localIPManager) setBitmapBit(ip net.IP) {
ipv4 := ip.To4()
if ipv4 == nil {
return
}
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
m.ipv4Bitmap[high] |= 1 << (low % 32)
}
func (m *localIPManager) checkBitmapBit(ip net.IP) bool {
ipv4 := ip.To4()
if ipv4 == nil {
return false
}
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
}
func (m *localIPManager) processIP(ip net.IP, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
if ipv4 := ip.To4(); ipv4 != nil {
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
if int(high) >= len(*newIPv4Bitmap) {
return fmt.Errorf("invalid IPv4 address: %s", ip)
}
ipStr := ip.String()
if _, exists := ipv4Set[ipStr]; !exists {
ipv4Set[ipStr] = struct{}{}
*ipv4Addresses = append(*ipv4Addresses, ipStr)
newIPv4Bitmap[high] |= 1 << (low % 32)
}
}
return nil
}
func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
addrs, err := iface.Addrs()
if err != nil {
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
return
}
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
default:
continue
}
if err := m.processIP(ip, newIPv4Bitmap, ipv4Set, ipv4Addresses); err != nil {
log.Debugf("process IP failed: %v", err)
}
}
}
func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic: %v", r)
}
}()
var newIPv4Bitmap [1 << 16]uint32
ipv4Set := make(map[string]struct{})
var ipv4Addresses []string
// 127.0.0.0/8
high := uint16(127) << 8
for i := uint16(0); i < 256; i++ {
newIPv4Bitmap[high|i] = 0xffffffff
}
if iface != nil {
if err := m.processIP(iface.Address().IP, &newIPv4Bitmap, ipv4Set, &ipv4Addresses); err != nil {
return err
}
}
interfaces, err := net.Interfaces()
if err != nil {
log.Warnf("failed to get interfaces: %v", err)
} else {
for _, intf := range interfaces {
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
}
}
m.mu.Lock()
m.ipv4Bitmap = newIPv4Bitmap
m.mu.Unlock()
log.Debugf("Local IPv4 addresses: %v", ipv4Addresses)
return nil
}
func (m *localIPManager) IsLocalIP(ip net.IP) bool {
m.mu.RLock()
defer m.mu.RUnlock()
if ipv4 := ip.To4(); ipv4 != nil {
return m.checkBitmapBit(ipv4)
}
return false
}

View File

@ -0,0 +1,270 @@
package uspfilter
import (
"net"
"testing"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface"
)
func TestLocalIPManager(t *testing.T) {
tests := []struct {
name string
setupAddr iface.WGAddress
testIP net.IP
expected bool
}{
{
name: "Localhost range",
setupAddr: iface.WGAddress{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: net.ParseIP("127.0.0.2"),
expected: true,
},
{
name: "Localhost standard address",
setupAddr: iface.WGAddress{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: net.ParseIP("127.0.0.1"),
expected: true,
},
{
name: "Localhost range edge",
setupAddr: iface.WGAddress{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: net.ParseIP("127.255.255.255"),
expected: true,
},
{
name: "Local IP matches",
setupAddr: iface.WGAddress{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: net.ParseIP("192.168.1.1"),
expected: true,
},
{
name: "Local IP doesn't match",
setupAddr: iface.WGAddress{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: net.ParseIP("192.168.1.2"),
expected: false,
},
{
name: "IPv6 address",
setupAddr: iface.WGAddress{
IP: net.ParseIP("fe80::1"),
Network: &net.IPNet{
IP: net.ParseIP("fe80::"),
Mask: net.CIDRMask(64, 128),
},
},
testIP: net.ParseIP("fe80::1"),
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
manager := newLocalIPManager()
mock := &IFaceMock{
AddressFunc: func() iface.WGAddress {
return tt.setupAddr
},
}
err := manager.UpdateLocalIPs(mock)
require.NoError(t, err)
result := manager.IsLocalIP(tt.testIP)
require.Equal(t, tt.expected, result)
})
}
}
func TestLocalIPManager_AllInterfaces(t *testing.T) {
manager := newLocalIPManager()
mock := &IFaceMock{}
// Get actual local interfaces
interfaces, err := net.Interfaces()
require.NoError(t, err)
var tests []struct {
ip string
expected bool
}
// Add all local interface IPs to test cases
for _, iface := range interfaces {
addrs, err := iface.Addrs()
require.NoError(t, err)
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
default:
continue
}
if ip4 := ip.To4(); ip4 != nil {
tests = append(tests, struct {
ip string
expected bool
}{
ip: ip4.String(),
expected: true,
})
}
}
}
// Add some external IPs as negative test cases
externalIPs := []string{
"8.8.8.8",
"1.1.1.1",
"208.67.222.222",
}
for _, ip := range externalIPs {
tests = append(tests, struct {
ip string
expected bool
}{
ip: ip,
expected: false,
})
}
require.NotEmpty(t, tests, "No test cases generated")
err = manager.UpdateLocalIPs(mock)
require.NoError(t, err)
t.Logf("Testing %d IPs", len(tests))
for _, tt := range tests {
t.Run(tt.ip, func(t *testing.T) {
result := manager.IsLocalIP(net.ParseIP(tt.ip))
require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
})
}
}
// MapImplementation is a version using map[string]struct{}
type MapImplementation struct {
localIPs map[string]struct{}
}
func BenchmarkIPChecks(b *testing.B) {
interfaces := make([]net.IP, 16)
for i := range interfaces {
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
}
// Setup bitmap version
bitmapManager := &localIPManager{
ipv4Bitmap: [1 << 16]uint32{},
}
for _, ip := range interfaces[:8] { // Add half of IPs
bitmapManager.setBitmapBit(ip)
}
// Setup map version
mapManager := &MapImplementation{
localIPs: make(map[string]struct{}),
}
for _, ip := range interfaces[:8] {
mapManager.localIPs[ip.String()] = struct{}{}
}
b.Run("Bitmap_Hit", func(b *testing.B) {
ip := interfaces[4]
b.ResetTimer()
for i := 0; i < b.N; i++ {
bitmapManager.checkBitmapBit(ip)
}
})
b.Run("Bitmap_Miss", func(b *testing.B) {
ip := interfaces[12]
b.ResetTimer()
for i := 0; i < b.N; i++ {
bitmapManager.checkBitmapBit(ip)
}
})
b.Run("Map_Hit", func(b *testing.B) {
ip := interfaces[4]
b.ResetTimer()
for i := 0; i < b.N; i++ {
// nolint:gosimple
_, _ = mapManager.localIPs[ip.String()]
}
})
b.Run("Map_Miss", func(b *testing.B) {
ip := interfaces[12]
b.ResetTimer()
for i := 0; i < b.N; i++ {
// nolint:gosimple
_, _ = mapManager.localIPs[ip.String()]
}
})
}
func BenchmarkWGPosition(b *testing.B) {
wgIP := net.ParseIP("10.10.0.1")
// Create two managers - one checks WG IP first, other checks it last
b.Run("WG_First", func(b *testing.B) {
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
bm.setBitmapBit(wgIP)
b.ResetTimer()
for i := 0; i < b.N; i++ {
bm.checkBitmapBit(wgIP)
}
})
b.Run("WG_Last", func(b *testing.B) {
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
// Fill with other IPs first
for i := 0; i < 15; i++ {
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))
}
bm.setBitmapBit(wgIP) // Add WG IP last
b.ResetTimer()
for i := 0; i < b.N; i++ {
bm.checkBitmapBit(wgIP)
}
})
}

View File

@ -0,0 +1,196 @@
// Package logger provides a high-performance, non-blocking logger for userspace networking
package log
import (
"context"
"fmt"
"io"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
)
const (
maxBatchSize = 1024 * 16 // 16KB max batch size
maxMessageSize = 1024 * 2 // 2KB per message
bufferSize = 1024 * 256 // 256KB ring buffer
defaultFlushInterval = 2 * time.Second
)
// Level represents log severity
type Level uint32
const (
LevelPanic Level = iota
LevelFatal
LevelError
LevelWarn
LevelInfo
LevelDebug
LevelTrace
)
var levelStrings = map[Level]string{
LevelPanic: "PANC",
LevelFatal: "FATL",
LevelError: "ERRO",
LevelWarn: "WARN",
LevelInfo: "INFO",
LevelDebug: "DEBG",
LevelTrace: "TRAC",
}
// Logger is a high-performance, non-blocking logger
type Logger struct {
output io.Writer
level atomic.Uint32
buffer *ringBuffer
shutdown chan struct{}
closeOnce sync.Once
wg sync.WaitGroup
// Reusable buffer pool for formatting messages
bufPool sync.Pool
}
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
l := &Logger{
output: logrusLogger.Out,
buffer: newRingBuffer(bufferSize),
shutdown: make(chan struct{}),
bufPool: sync.Pool{
New: func() interface{} {
// Pre-allocate buffer for message formatting
b := make([]byte, 0, maxMessageSize)
return &b
},
},
}
logrusLevel := logrusLogger.GetLevel()
l.level.Store(uint32(logrusLevel))
level := levelStrings[Level(logrusLevel)]
log.Debugf("New uspfilter logger created with loglevel %v", level)
l.wg.Add(1)
go l.worker()
return l
}
func (l *Logger) SetLevel(level Level) {
l.level.Store(uint32(level))
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
}
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...interface{}) {
*buf = (*buf)[:0]
// Timestamp
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
*buf = append(*buf, ' ')
// Level
*buf = append(*buf, levelStrings[level]...)
*buf = append(*buf, ' ')
// Message
if len(args) > 0 {
*buf = append(*buf, fmt.Sprintf(format, args...)...)
} else {
*buf = append(*buf, format...)
}
*buf = append(*buf, '\n')
}
func (l *Logger) log(level Level, format string, args ...interface{}) {
bufp := l.bufPool.Get().(*[]byte)
l.formatMessage(bufp, level, format, args...)
if len(*bufp) > maxMessageSize {
*bufp = (*bufp)[:maxMessageSize]
}
_, _ = l.buffer.Write(*bufp)
l.bufPool.Put(bufp)
}
func (l *Logger) Error(format string, args ...interface{}) {
if l.level.Load() >= uint32(LevelError) {
l.log(LevelError, format, args...)
}
}
func (l *Logger) Warn(format string, args ...interface{}) {
if l.level.Load() >= uint32(LevelWarn) {
l.log(LevelWarn, format, args...)
}
}
func (l *Logger) Info(format string, args ...interface{}) {
if l.level.Load() >= uint32(LevelInfo) {
l.log(LevelInfo, format, args...)
}
}
func (l *Logger) Debug(format string, args ...interface{}) {
if l.level.Load() >= uint32(LevelDebug) {
l.log(LevelDebug, format, args...)
}
}
func (l *Logger) Trace(format string, args ...interface{}) {
if l.level.Load() >= uint32(LevelTrace) {
l.log(LevelTrace, format, args...)
}
}
// worker periodically flushes the buffer
func (l *Logger) worker() {
defer l.wg.Done()
ticker := time.NewTicker(defaultFlushInterval)
defer ticker.Stop()
buf := make([]byte, 0, maxBatchSize)
for {
select {
case <-l.shutdown:
return
case <-ticker.C:
// Read accumulated messages
n, _ := l.buffer.Read(buf[:cap(buf)])
if n == 0 {
continue
}
// Write batch
_, _ = l.output.Write(buf[:n])
}
}
}
// Stop gracefully shuts down the logger
func (l *Logger) Stop(ctx context.Context) error {
done := make(chan struct{})
l.closeOnce.Do(func() {
close(l.shutdown)
})
go func() {
l.wg.Wait()
close(done)
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-done:
return nil
}
}

View File

@ -0,0 +1,85 @@
package log
import "sync"
// ringBuffer is a simple ring buffer implementation
type ringBuffer struct {
buf []byte
size int
r, w int64 // Read and write positions
mu sync.Mutex
}
func newRingBuffer(size int) *ringBuffer {
return &ringBuffer{
buf: make([]byte, size),
size: size,
}
}
func (r *ringBuffer) Write(p []byte) (n int, err error) {
if len(p) == 0 {
return 0, nil
}
r.mu.Lock()
defer r.mu.Unlock()
if len(p) > r.size {
p = p[:r.size]
}
n = len(p)
// Write data, handling wrap-around
pos := int(r.w % int64(r.size))
writeLen := min(len(p), r.size-pos)
copy(r.buf[pos:], p[:writeLen])
// If we have more data and need to wrap around
if writeLen < len(p) {
copy(r.buf, p[writeLen:])
}
// Update write position
r.w += int64(n)
return n, nil
}
func (r *ringBuffer) Read(p []byte) (n int, err error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.w == r.r {
return 0, nil
}
// Calculate available data accounting for wraparound
available := int(r.w - r.r)
if available < 0 {
available += r.size
}
available = min(available, r.size)
// Limit read to buffer size
toRead := min(available, len(p))
if toRead == 0 {
return 0, nil
}
// Read data, handling wrap-around
pos := int(r.r % int64(r.size))
readLen := min(toRead, r.size-pos)
n = copy(p, r.buf[pos:pos+readLen])
// If we need more data and need to wrap around
if readLen < toRead {
n += copy(p[readLen:toRead], r.buf[:toRead-readLen])
}
// Update read position
r.r += int64(n)
return n, nil
}

View File

@ -2,22 +2,22 @@ package uspfilter
import ( import (
"net" "net"
"net/netip"
"github.com/google/gopacket" "github.com/google/gopacket"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
) )
// Rule to handle management of rules // PeerRule to handle management of rules
type Rule struct { type PeerRule struct {
id string id string
ip net.IP ip net.IP
ipLayer gopacket.LayerType ipLayer gopacket.LayerType
matchByIP bool matchByIP bool
protoLayer gopacket.LayerType protoLayer gopacket.LayerType
direction firewall.RuleDirection sPort *firewall.Port
sPort uint16 dPort *firewall.Port
dPort uint16
drop bool drop bool
comment string comment string
@ -25,6 +25,21 @@ type Rule struct {
} }
// GetRuleID returns the rule id // GetRuleID returns the rule id
func (r *Rule) GetRuleID() string { func (r *PeerRule) GetRuleID() string {
return r.id
}
type RouteRule struct {
id string
sources []netip.Prefix
destination netip.Prefix
proto firewall.Protocol
srcPort *firewall.Port
dstPort *firewall.Port
action firewall.Action
}
// GetRuleID returns the rule id
func (r *RouteRule) GetRuleID() string {
return r.id return r.id
} }

View File

@ -0,0 +1,390 @@
package uspfilter
import (
"fmt"
"net"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
)
type PacketStage int
const (
StageReceived PacketStage = iota
StageConntrack
StagePeerACL
StageRouting
StageRouteACL
StageForwarding
StageCompleted
)
const msgProcessingCompleted = "Processing completed"
func (s PacketStage) String() string {
return map[PacketStage]string{
StageReceived: "Received",
StageConntrack: "Connection Tracking",
StagePeerACL: "Peer ACL",
StageRouting: "Routing",
StageRouteACL: "Route ACL",
StageForwarding: "Forwarding",
StageCompleted: "Completed",
}[s]
}
type ForwarderAction struct {
Action string
RemoteAddr string
Error error
}
type TraceResult struct {
Timestamp time.Time
Stage PacketStage
Message string
Allowed bool
ForwarderAction *ForwarderAction
}
type PacketTrace struct {
SourceIP net.IP
DestinationIP net.IP
Protocol string
SourcePort uint16
DestinationPort uint16
Direction fw.RuleDirection
Results []TraceResult
}
type TCPState struct {
SYN bool
ACK bool
FIN bool
RST bool
PSH bool
URG bool
}
type PacketBuilder struct {
SrcIP net.IP
DstIP net.IP
Protocol fw.Protocol
SrcPort uint16
DstPort uint16
ICMPType uint8
ICMPCode uint8
Direction fw.RuleDirection
PayloadSize int
TCPState *TCPState
}
func (t *PacketTrace) AddResult(stage PacketStage, message string, allowed bool) {
t.Results = append(t.Results, TraceResult{
Timestamp: time.Now(),
Stage: stage,
Message: message,
Allowed: allowed,
})
}
func (t *PacketTrace) AddResultWithForwarder(stage PacketStage, message string, allowed bool, action *ForwarderAction) {
t.Results = append(t.Results, TraceResult{
Timestamp: time.Now(),
Stage: stage,
Message: message,
Allowed: allowed,
ForwarderAction: action,
})
}
func (p *PacketBuilder) Build() ([]byte, error) {
ip := p.buildIPLayer()
pktLayers := []gopacket.SerializableLayer{ip}
transportLayer, err := p.buildTransportLayer(ip)
if err != nil {
return nil, err
}
pktLayers = append(pktLayers, transportLayer...)
if p.PayloadSize > 0 {
payload := make([]byte, p.PayloadSize)
pktLayers = append(pktLayers, gopacket.Payload(payload))
}
return serializePacket(pktLayers)
}
func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
return &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
SrcIP: p.SrcIP,
DstIP: p.DstIP,
}
}
func (p *PacketBuilder) buildTransportLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
switch p.Protocol {
case "tcp":
return p.buildTCPLayer(ip)
case "udp":
return p.buildUDPLayer(ip)
case "icmp":
return p.buildICMPLayer()
default:
return nil, fmt.Errorf("unsupported protocol: %s", p.Protocol)
}
}
func (p *PacketBuilder) buildTCPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
tcp := &layers.TCP{
SrcPort: layers.TCPPort(p.SrcPort),
DstPort: layers.TCPPort(p.DstPort),
Window: 65535,
SYN: p.TCPState != nil && p.TCPState.SYN,
ACK: p.TCPState != nil && p.TCPState.ACK,
FIN: p.TCPState != nil && p.TCPState.FIN,
RST: p.TCPState != nil && p.TCPState.RST,
PSH: p.TCPState != nil && p.TCPState.PSH,
URG: p.TCPState != nil && p.TCPState.URG,
}
if err := tcp.SetNetworkLayerForChecksum(ip); err != nil {
return nil, fmt.Errorf("set network layer for TCP checksum: %w", err)
}
return []gopacket.SerializableLayer{tcp}, nil
}
func (p *PacketBuilder) buildUDPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
udp := &layers.UDP{
SrcPort: layers.UDPPort(p.SrcPort),
DstPort: layers.UDPPort(p.DstPort),
}
if err := udp.SetNetworkLayerForChecksum(ip); err != nil {
return nil, fmt.Errorf("set network layer for UDP checksum: %w", err)
}
return []gopacket.SerializableLayer{udp}, nil
}
func (p *PacketBuilder) buildICMPLayer() ([]gopacket.SerializableLayer, error) {
icmp := &layers.ICMPv4{
TypeCode: layers.CreateICMPv4TypeCode(p.ICMPType, p.ICMPCode),
}
if p.ICMPType == layers.ICMPv4TypeEchoRequest || p.ICMPType == layers.ICMPv4TypeEchoReply {
icmp.Id = uint16(1)
icmp.Seq = uint16(1)
}
return []gopacket.SerializableLayer{icmp}, nil
}
func serializePacket(layers []gopacket.SerializableLayer) ([]byte, error) {
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
if err := gopacket.SerializeLayers(buf, opts, layers...); err != nil {
return nil, fmt.Errorf("serialize packet: %w", err)
}
return buf.Bytes(), nil
}
func getIPProtocolNumber(protocol fw.Protocol) int {
switch protocol {
case fw.ProtocolTCP:
return int(layers.IPProtocolTCP)
case fw.ProtocolUDP:
return int(layers.IPProtocolUDP)
case fw.ProtocolICMP:
return int(layers.IPProtocolICMPv4)
default:
return 0
}
}
func (m *Manager) TracePacketFromBuilder(builder *PacketBuilder) (*PacketTrace, error) {
packetData, err := builder.Build()
if err != nil {
return nil, fmt.Errorf("build packet: %w", err)
}
return m.TracePacket(packetData, builder.Direction), nil
}
func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *PacketTrace {
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
trace := &PacketTrace{Direction: direction}
// Initial packet decoding
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
trace.AddResult(StageReceived, fmt.Sprintf("Failed to decode packet: %v", err), false)
return trace
}
// Extract base packet info
srcIP, dstIP := m.extractIPs(d)
trace.SourceIP = srcIP
trace.DestinationIP = dstIP
// Determine protocol and ports
switch d.decoded[1] {
case layers.LayerTypeTCP:
trace.Protocol = "TCP"
trace.SourcePort = uint16(d.tcp.SrcPort)
trace.DestinationPort = uint16(d.tcp.DstPort)
case layers.LayerTypeUDP:
trace.Protocol = "UDP"
trace.SourcePort = uint16(d.udp.SrcPort)
trace.DestinationPort = uint16(d.udp.DstPort)
case layers.LayerTypeICMPv4:
trace.Protocol = "ICMP"
}
trace.AddResult(StageReceived, fmt.Sprintf("Received %s packet: %s:%d -> %s:%d",
trace.Protocol, srcIP, trace.SourcePort, dstIP, trace.DestinationPort), true)
if direction == fw.RuleDirectionOUT {
return m.traceOutbound(packetData, trace)
}
return m.traceInbound(packetData, trace, d, srcIP, dstIP)
}
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP net.IP, dstIP net.IP) *PacketTrace {
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
return trace
}
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
return trace
}
if !m.handleRouting(trace) {
return trace
}
if m.nativeRouter {
return m.handleNativeRouter(trace)
}
return m.handleRouteACLs(trace, d, srcIP, dstIP)
}
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) bool {
allowed := m.isValidTrackedConnection(d, srcIP, dstIP)
msg := "No existing connection found"
if allowed {
msg = m.buildConntrackStateMessage(d)
trace.AddResult(StageConntrack, msg, true)
trace.AddResult(StageCompleted, "Packet allowed by connection tracking", true)
return true
}
trace.AddResult(StageConntrack, msg, false)
return false
}
func (m *Manager) buildConntrackStateMessage(d *decoder) string {
msg := "Matched existing connection state"
switch d.decoded[1] {
case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp)
msg += fmt.Sprintf(" (TCP Flags: SYN=%v ACK=%v RST=%v FIN=%v)",
flags&conntrack.TCPSyn != 0,
flags&conntrack.TCPAck != 0,
flags&conntrack.TCPRst != 0,
flags&conntrack.TCPFin != 0)
case layers.LayerTypeICMPv4:
msg += fmt.Sprintf(" (ICMP ID=%d, Seq=%d)", d.icmp4.Id, d.icmp4.Seq)
}
return msg
}
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP net.IP) bool {
if !m.localForwarding {
trace.AddResult(StageRouting, "Local forwarding disabled", false)
trace.AddResult(StageCompleted, "Packet dropped - local forwarding disabled", false)
return true
}
trace.AddResult(StageRouting, "Packet destined for local delivery", true)
blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
msg := "Allowed by peer ACL rules"
if blocked {
msg = "Blocked by peer ACL rules"
}
trace.AddResult(StagePeerACL, msg, !blocked)
if m.netstack {
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", !blocked)
}
trace.AddResult(StageCompleted, msgProcessingCompleted, !blocked)
return true
}
func (m *Manager) handleRouting(trace *PacketTrace) bool {
if !m.routingEnabled {
trace.AddResult(StageRouting, "Routing disabled", false)
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
return false
}
trace.AddResult(StageRouting, "Routing enabled, checking ACLs", true)
return true
}
func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
trace.AddResult(StageRouteACL, "Using native router, skipping ACL checks", true)
trace.AddResult(StageForwarding, "Forwarding via native router", true)
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
return trace
}
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) *PacketTrace {
proto := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
msg := "Allowed by route ACLs"
if !allowed {
msg = "Blocked by route ACLs"
}
trace.AddResult(StageRouteACL, msg, allowed)
if allowed && m.forwarder != nil {
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
}
trace.AddResult(StageCompleted, msgProcessingCompleted, allowed)
return trace
}
func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr string, allowed bool) {
fwdAction := &ForwarderAction{
Action: action,
RemoteAddr: remoteAddr,
}
trace.AddResultWithForwarder(StageForwarding,
fmt.Sprintf("Forwarding to %s", fwdAction.Action), allowed, fwdAction)
}
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
// will create or update the connection state
dropped := m.processOutgoingHooks(packetData)
if dropped {
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
} else {
trace.AddResult(StageCompleted, "Packet allowed (outgoing)", true)
}
return trace
}

View File

@ -1,11 +1,14 @@
package uspfilter package uspfilter
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"os" "os"
"slices"
"strconv" "strconv"
"strings"
"sync" "sync"
"github.com/google/gopacket" "github.com/google/gopacket"
@ -14,44 +17,83 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
"github.com/netbirdio/netbird/client/iface/device" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
const layerTypeAll = 0 const layerTypeAll = 0
const EnvDisableConntrack = "NB_DISABLE_CONNTRACK" const (
// EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed.
EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
var ( // EnvDisableUserspaceRouting disables userspace routing, to-be-routed packets will be dropped.
errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall") EnvDisableUserspaceRouting = "NB_DISABLE_USERSPACE_ROUTING"
// EnvForceUserspaceRouter forces userspace routing even if native routing is available.
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
// EnvEnableNetstackLocalForwarding enables forwarding of local traffic to the native stack when running netstack
// Leaving this on by default introduces a security risk as sockets on listening on localhost only will be accessible
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
) )
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
SetFilter(device.PacketFilter) error
Address() iface.WGAddress
}
// RuleSet is a set of rules grouped by a string key // RuleSet is a set of rules grouped by a string key
type RuleSet map[string]Rule type RuleSet map[string]PeerRule
type RouteRules []RouteRule
func (r RouteRules) Sort() {
slices.SortStableFunc(r, func(a, b RouteRule) int {
// Deny rules come first
if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop {
return -1
}
if a.action != firewall.ActionDrop && b.action == firewall.ActionDrop {
return 1
}
return strings.Compare(a.id, b.id)
})
}
// Manager userspace firewall manager // Manager userspace firewall manager
type Manager struct { type Manager struct {
outgoingRules map[string]RuleSet // outgoingRules is used for hooks only
outgoingRules map[string]RuleSet
// incomingRules is used for filtering and hooks
incomingRules map[string]RuleSet incomingRules map[string]RuleSet
routeRules RouteRules
wgNetwork *net.IPNet wgNetwork *net.IPNet
decoders sync.Pool decoders sync.Pool
wgIface IFaceMapper wgIface common.IFaceMapper
nativeFirewall firewall.Manager nativeFirewall firewall.Manager
mutex sync.RWMutex mutex sync.RWMutex
stateful bool // indicates whether server routes are disabled
disableServerRoutes bool
// indicates whether we forward packets not destined for ourselves
routingEnabled bool
// indicates whether we leave forwarding and filtering to the native firewall
nativeRouter bool
// indicates whether we track outbound connections
stateful bool
// indicates whether wireguards runs in netstack mode
netstack bool
// indicates whether we forward local traffic to the native stack
localForwarding bool
localipmanager *localIPManager
udpTracker *conntrack.UDPTracker udpTracker *conntrack.UDPTracker
icmpTracker *conntrack.ICMPTracker icmpTracker *conntrack.ICMPTracker
tcpTracker *conntrack.TCPTracker tcpTracker *conntrack.TCPTracker
forwarder *forwarder.Forwarder
logger *nblog.Logger
} }
// decoder for packages // decoder for packages
@ -68,22 +110,44 @@ type decoder struct {
} }
// Create userspace firewall manager constructor // Create userspace firewall manager constructor
func Create(iface IFaceMapper) (*Manager, error) { func Create(iface common.IFaceMapper, disableServerRoutes bool) (*Manager, error) {
return create(iface) return create(iface, nil, disableServerRoutes)
} }
func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager) (*Manager, error) { func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
mgr, err := create(iface) if nativeFirewall == nil {
return nil, errors.New("native firewall is nil")
}
mgr, err := create(iface, nativeFirewall, disableServerRoutes)
if err != nil { if err != nil {
return nil, err return nil, err
} }
mgr.nativeFirewall = nativeFirewall
return mgr, nil return mgr, nil
} }
func create(iface IFaceMapper) (*Manager, error) { func parseCreateEnv() (bool, bool) {
disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack)) var disableConntrack, enableLocalForwarding bool
var err error
if val := os.Getenv(EnvDisableConntrack); val != "" {
disableConntrack, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvDisableConntrack, err)
}
}
if val := os.Getenv(EnvEnableNetstackLocalForwarding); val != "" {
enableLocalForwarding, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
}
}
return disableConntrack, enableLocalForwarding
}
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
disableConntrack, enableLocalForwarding := parseCreateEnv()
m := &Manager{ m := &Manager{
decoders: sync.Pool{ decoders: sync.Pool{
@ -99,52 +163,182 @@ func create(iface IFaceMapper) (*Manager, error) {
return d return d
}, },
}, },
outgoingRules: make(map[string]RuleSet), nativeFirewall: nativeFirewall,
incomingRules: make(map[string]RuleSet), outgoingRules: make(map[string]RuleSet),
wgIface: iface, incomingRules: make(map[string]RuleSet),
stateful: !disableConntrack, wgIface: iface,
localipmanager: newLocalIPManager(),
disableServerRoutes: disableServerRoutes,
routingEnabled: false,
stateful: !disableConntrack,
logger: nblog.NewFromLogrus(log.StandardLogger()),
netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding,
}
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
return nil, fmt.Errorf("update local IPs: %w", err)
} }
// Only initialize trackers if stateful mode is enabled
if disableConntrack { if disableConntrack {
log.Info("conntrack is disabled") log.Info("conntrack is disabled")
} else { } else {
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
}
// netstack needs the forwarder for local traffic
if m.netstack && m.localForwarding {
if err := m.initForwarder(); err != nil {
log.Errorf("failed to initialize forwarder: %v", err)
}
}
if err := m.blockInvalidRouted(iface); err != nil {
log.Errorf("failed to block invalid routed traffic: %v", err)
} }
if err := iface.SetFilter(m); err != nil { if err := iface.SetFilter(m); err != nil {
return nil, err return nil, fmt.Errorf("set filter: %w", err)
} }
return m, nil return m, nil
} }
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
if m.forwarder == nil {
return nil
}
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
if err != nil {
return fmt.Errorf("parse wireguard network: %w", err)
}
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
if _, err := m.AddRouteFiltering(
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
wgPrefix,
firewall.ProtocolALL,
nil,
nil,
firewall.ActionDrop,
); err != nil {
return fmt.Errorf("block wg nte : %w", err)
}
// TODO: Block networks that we're a client of
return nil
}
func (m *Manager) determineRouting() error {
var disableUspRouting, forceUserspaceRouter bool
var err error
if val := os.Getenv(EnvDisableUserspaceRouting); val != "" {
disableUspRouting, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvDisableUserspaceRouting, err)
}
}
if val := os.Getenv(EnvForceUserspaceRouter); val != "" {
forceUserspaceRouter, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvForceUserspaceRouter, err)
}
}
switch {
case disableUspRouting:
m.routingEnabled = false
m.nativeRouter = false
log.Info("userspace routing is disabled")
case m.disableServerRoutes:
// if server routes are disabled we will let packets pass to the native stack
m.routingEnabled = true
m.nativeRouter = true
log.Info("server routes are disabled")
case forceUserspaceRouter:
m.routingEnabled = true
m.nativeRouter = false
log.Info("userspace routing is forced")
case !m.netstack && m.nativeFirewall != nil && m.nativeFirewall.IsServerRouteSupported():
// if the OS supports routing natively, then we don't need to filter/route ourselves
// netstack mode won't support native routing as there is no interface
m.routingEnabled = true
m.nativeRouter = true
log.Info("native routing is enabled")
default:
m.routingEnabled = true
m.nativeRouter = false
log.Info("userspace routing enabled by default")
}
if m.routingEnabled && !m.nativeRouter {
return m.initForwarder()
}
return nil
}
// initForwarder initializes the forwarder, it disables routing on errors
func (m *Manager) initForwarder() error {
if m.forwarder != nil {
return nil
}
// Only supported in userspace mode as we need to inject packets back into wireguard directly
intf := m.wgIface.GetWGDevice()
if intf == nil {
m.routingEnabled = false
return errors.New("forwarding not supported")
}
forwarder, err := forwarder.New(m.wgIface, m.logger, m.netstack)
if err != nil {
m.routingEnabled = false
return fmt.Errorf("create forwarder: %w", err)
}
m.forwarder = forwarder
log.Debug("forwarder initialized")
return nil
}
func (m *Manager) Init(*statemanager.Manager) error { func (m *Manager) Init(*statemanager.Manager) error {
return nil return nil
} }
func (m *Manager) IsServerRouteSupported() bool { func (m *Manager) IsServerRouteSupported() bool {
if m.nativeFirewall == nil { return true
return false
} else {
return true
}
} }
func (m *Manager) AddNatRule(pair firewall.RouterPair) error { func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
if m.nativeFirewall == nil { if m.nativeRouter && m.nativeFirewall != nil {
return errRouteNotSupported return m.nativeFirewall.AddNatRule(pair)
} }
return m.nativeFirewall.AddNatRule(pair)
// userspace routed packets are always SNATed to the inbound direction
// TODO: implement outbound SNAT
return nil
} }
// RemoveNatRule removes a routing firewall rule // RemoveNatRule removes a routing firewall rule
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
if m.nativeFirewall == nil { if m.nativeRouter && m.nativeFirewall != nil {
return errRouteNotSupported return m.nativeFirewall.RemoveNatRule(pair)
} }
return m.nativeFirewall.RemoveNatRule(pair) return nil
} }
// AddPeerFiltering rule to the firewall // AddPeerFiltering rule to the firewall
@ -156,17 +350,15 @@ func (m *Manager) AddPeerFiltering(
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action, action firewall.Action,
ipsetName string, _ string,
comment string, comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
r := Rule{ r := PeerRule{
id: uuid.New().String(), id: uuid.New().String(),
ip: ip, ip: ip,
ipLayer: layers.LayerTypeIPv6, ipLayer: layers.LayerTypeIPv6,
matchByIP: true, matchByIP: true,
direction: direction,
drop: action == firewall.ActionDrop, drop: action == firewall.ActionDrop,
comment: comment, comment: comment,
} }
@ -179,13 +371,8 @@ func (m *Manager) AddPeerFiltering(
r.matchByIP = false r.matchByIP = false
} }
if sPort != nil && len(sPort.Values) == 1 { r.sPort = sPort
r.sPort = uint16(sPort.Values[0]) r.dPort = dPort
}
if dPort != nil && len(dPort.Values) == 1 {
r.dPort = uint16(dPort.Values[0])
}
switch proto { switch proto {
case firewall.ProtocolTCP: case firewall.ProtocolTCP:
@ -202,33 +389,64 @@ func (m *Manager) AddPeerFiltering(
} }
m.mutex.Lock() m.mutex.Lock()
if direction == firewall.RuleDirectionIN { if _, ok := m.incomingRules[r.ip.String()]; !ok {
if _, ok := m.incomingRules[r.ip.String()]; !ok { m.incomingRules[r.ip.String()] = make(RuleSet)
m.incomingRules[r.ip.String()] = make(RuleSet)
}
m.incomingRules[r.ip.String()][r.id] = r
} else {
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
m.outgoingRules[r.ip.String()] = make(RuleSet)
}
m.outgoingRules[r.ip.String()][r.id] = r
} }
m.incomingRules[r.ip.String()][r.id] = r
m.mutex.Unlock() m.mutex.Unlock()
return []firewall.Rule{&r}, nil return []firewall.Rule{&r}, nil
} }
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) { func (m *Manager) AddRouteFiltering(
if m.nativeFirewall == nil { sources []netip.Prefix,
return nil, errRouteNotSupported destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
if m.nativeRouter && m.nativeFirewall != nil {
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
} }
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
m.mutex.Lock()
defer m.mutex.Unlock()
ruleID := uuid.New().String()
rule := RouteRule{
id: ruleID,
sources: sources,
destination: destination,
proto: proto,
srcPort: sPort,
dstPort: dPort,
action: action,
}
m.routeRules = append(m.routeRules, rule)
m.routeRules.Sort()
return &rule, nil
} }
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
if m.nativeFirewall == nil { if m.nativeRouter && m.nativeFirewall != nil {
return errRouteNotSupported return m.nativeFirewall.DeleteRouteRule(rule)
} }
return m.nativeFirewall.DeleteRouteRule(rule)
m.mutex.Lock()
defer m.mutex.Unlock()
ruleID := rule.GetRuleID()
idx := slices.IndexFunc(m.routeRules, func(r RouteRule) bool {
return r.id == ruleID
})
if idx < 0 {
return fmt.Errorf("route rule not found: %s", ruleID)
}
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
return nil
} }
// DeletePeerRule from the firewall by rule definition // DeletePeerRule from the firewall by rule definition
@ -236,24 +454,15 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
r, ok := rule.(*Rule) r, ok := rule.(*PeerRule)
if !ok { if !ok {
return fmt.Errorf("delete rule: invalid rule type: %T", rule) return fmt.Errorf("delete rule: invalid rule type: %T", rule)
} }
if r.direction == firewall.RuleDirectionIN { if _, ok := m.incomingRules[r.ip.String()][r.id]; !ok {
_, ok := m.incomingRules[r.ip.String()][r.id] return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
if !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(m.incomingRules[r.ip.String()], r.id)
} else {
_, ok := m.outgoingRules[r.ip.String()][r.id]
if !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(m.outgoingRules[r.ip.String()], r.id)
} }
delete(m.incomingRules[r.ip.String()], r.id)
return nil return nil
} }
@ -276,10 +485,14 @@ func (m *Manager) DropOutgoing(packetData []byte) bool {
// DropIncoming filter incoming packets // DropIncoming filter incoming packets
func (m *Manager) DropIncoming(packetData []byte) bool { func (m *Manager) DropIncoming(packetData []byte) bool {
return m.dropFilter(packetData, m.incomingRules) return m.dropFilter(packetData)
}
// UpdateLocalIPs updates the list of local IPs
func (m *Manager) UpdateLocalIPs() error {
return m.localipmanager.UpdateLocalIPs(m.wgIface)
} }
// processOutgoingHooks processes UDP hooks for outgoing packets and tracks TCP/UDP/ICMP
func (m *Manager) processOutgoingHooks(packetData []byte) bool { func (m *Manager) processOutgoingHooks(packetData []byte) bool {
m.mutex.RLock() m.mutex.RLock()
defer m.mutex.RUnlock() defer m.mutex.RUnlock()
@ -300,18 +513,11 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
return false return false
} }
// Always process UDP hooks // Track all protocols if stateful mode is enabled
if d.decoded[1] == layers.LayerTypeUDP {
// Track UDP state only if enabled
if m.stateful {
m.trackUDPOutbound(d, srcIP, dstIP)
}
return m.checkUDPHooks(d, dstIP, packetData)
}
// Track other protocols only if stateful mode is enabled
if m.stateful { if m.stateful {
switch d.decoded[1] { switch d.decoded[1] {
case layers.LayerTypeUDP:
m.trackUDPOutbound(d, srcIP, dstIP)
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
m.trackTCPOutbound(d, srcIP, dstIP) m.trackTCPOutbound(d, srcIP, dstIP)
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
@ -319,6 +525,11 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
} }
} }
// Process UDP hooks even if stateful mode is disabled
if d.decoded[1] == layers.LayerTypeUDP {
return m.checkUDPHooks(d, dstIP, packetData)
}
return false return false
} }
@ -380,7 +591,7 @@ func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) boo
for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} { for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} {
if rules, exists := m.outgoingRules[ipKey]; exists { if rules, exists := m.outgoingRules[ipKey]; exists {
for _, rule := range rules { for _, rule := range rules {
if rule.udpHook != nil && (rule.dPort == 0 || rule.dPort == uint16(d.udp.DstPort)) { if rule.udpHook != nil && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
return rule.udpHook(packetData) return rule.udpHook(packetData)
} }
} }
@ -400,10 +611,9 @@ func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) {
} }
} }
// dropFilter implements filtering logic for incoming packets // dropFilter implements filtering logic for incoming packets.
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool { // If it returns true, the packet should be dropped.
// TODO: Disable router if --disable-server-router is set func (m *Manager) dropFilter(packetData []byte) bool {
m.mutex.RLock() m.mutex.RLock()
defer m.mutex.RUnlock() defer m.mutex.RUnlock()
@ -416,39 +626,129 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
srcIP, dstIP := m.extractIPs(d) srcIP, dstIP := m.extractIPs(d)
if srcIP == nil { if srcIP == nil {
log.Errorf("unknown layer: %v", d.decoded[0]) m.logger.Error("Unknown network layer: %v", d.decoded[0])
return true return true
} }
if !m.isWireguardTraffic(srcIP, dstIP) { // For all inbound traffic, first check if it matches a tracked connection.
return false // This must happen before any other filtering because the packets are statefully tracked.
}
// Check connection state only if enabled
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) { if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) {
return false return false
} }
return m.applyRules(srcIP, packetData, rules, d) if m.localipmanager.IsLocalIP(dstIP) {
return m.handleLocalTraffic(d, srcIP, dstIP, packetData)
}
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData)
}
// handleLocalTraffic handles local traffic.
// If it returns true, the packet should be dropped.
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
if m.peerACLsBlock(srcIP, packetData, m.incomingRules, d) {
m.logger.Trace("Dropping local packet (ACL denied): src=%s dst=%s",
srcIP, dstIP)
return true
}
// if running in netstack mode we need to pass this to the forwarder
if m.netstack {
return m.handleNetstackLocalTraffic(packetData)
}
return false
}
func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
if !m.localForwarding {
// pass to virtual tcp/ip stack to be picked up by listeners
return false
}
if m.forwarder == nil {
m.logger.Trace("Dropping local packet (forwarder not initialized)")
return true
}
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
m.logger.Error("Failed to inject local packet: %v", err)
}
// don't process this packet further
return true
}
// handleRoutedTraffic handles routed traffic.
// If it returns true, the packet should be dropped.
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
// Drop if routing is disabled
if !m.routingEnabled {
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
srcIP, dstIP)
return true
}
// Pass to native stack if native router is enabled or forced
if m.nativeRouter {
return false
}
proto := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
if !m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) {
m.logger.Trace("Dropping routed packet (ACL denied): src=%s:%d dst=%s:%d proto=%v",
srcIP, srcPort, dstIP, dstPort, proto)
return true
}
// Let forwarder handle the packet if it passed route ACLs
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
m.logger.Error("Failed to inject incoming packet: %v", err)
}
// Forwarded packets shouldn't reach the native stack, hence they won't be visible in a packet capture
return true
}
func getProtocolFromPacket(d *decoder) firewall.Protocol {
switch d.decoded[1] {
case layers.LayerTypeTCP:
return firewall.ProtocolTCP
case layers.LayerTypeUDP:
return firewall.ProtocolUDP
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return firewall.ProtocolICMP
default:
return firewall.ProtocolALL
}
}
func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
switch d.decoded[1] {
case layers.LayerTypeTCP:
return uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort)
case layers.LayerTypeUDP:
return uint16(d.udp.SrcPort), uint16(d.udp.DstPort)
default:
return 0, 0
}
} }
func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool { func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
log.Tracef("couldn't decode layer, err: %s", err) m.logger.Trace("couldn't decode packet, err: %s", err)
return false return false
} }
if len(d.decoded) < 2 { if len(d.decoded) < 2 {
log.Tracef("not enough levels in network packet") m.logger.Trace("packet doesn't have network and transport layers")
return false return false
} }
return true return true
} }
func (m *Manager) isWireguardTraffic(srcIP, dstIP net.IP) bool {
return m.wgNetwork.Contains(srcIP) && m.wgNetwork.Contains(dstIP)
}
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool { func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool {
switch d.decoded[1] { switch d.decoded[1] {
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
@ -483,7 +783,22 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
return false return false
} }
func (m *Manager) applyRules(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool { // isSpecialICMP returns true if the packet is a special ICMP packet that should be allowed
func (m *Manager) isSpecialICMP(d *decoder) bool {
if d.decoded[1] != layers.LayerTypeICMPv4 {
return false
}
icmpType := d.icmp4.TypeCode.Type()
return icmpType == layers.ICMPv4TypeDestinationUnreachable ||
icmpType == layers.ICMPv4TypeTimeExceeded
}
func (m *Manager) peerACLsBlock(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool {
if m.isSpecialICMP(d) {
return false
}
if filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok { if filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok {
return filter return filter
} }
@ -500,7 +815,24 @@ func (m *Manager) applyRules(srcIP net.IP, packetData []byte, rules map[string]R
return true return true
} }
func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decoder) (bool, bool) { func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
if rulePort == nil {
return true
}
if rulePort.IsRange {
return packetPort >= rulePort.Values[0] && packetPort <= rulePort.Values[1]
}
for _, p := range rulePort.Values {
if p == packetPort {
return true
}
}
return false
}
func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *decoder) (bool, bool) {
payloadLayer := d.decoded[1] payloadLayer := d.decoded[1]
for _, rule := range rules { for _, rule := range rules {
if rule.matchByIP && !ip.Equal(rule.ip) { if rule.matchByIP && !ip.Equal(rule.ip) {
@ -517,13 +849,7 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
switch payloadLayer { switch payloadLayer {
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
if rule.sPort == 0 && rule.dPort == 0 { if portsMatch(rule.sPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dPort, uint16(d.tcp.DstPort)) {
return rule.drop, true
}
if rule.sPort != 0 && rule.sPort == uint16(d.tcp.SrcPort) {
return rule.drop, true
}
if rule.dPort != 0 && rule.dPort == uint16(d.tcp.DstPort) {
return rule.drop, true return rule.drop, true
} }
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
@ -533,13 +859,7 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
return rule.udpHook(packetData), true return rule.udpHook(packetData), true
} }
if rule.sPort == 0 && rule.dPort == 0 { if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
return rule.drop, true
}
if rule.sPort != 0 && rule.sPort == uint16(d.udp.SrcPort) {
return rule.drop, true
}
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
return rule.drop, true return rule.drop, true
} }
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6: case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
@ -549,6 +869,51 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
return false, false return false, false
} }
// routeACLsPass returns treu if the packet is allowed by the route ACLs
func (m *Manager) routeACLsPass(srcIP, dstIP net.IP, proto firewall.Protocol, srcPort, dstPort uint16) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
srcAddr := netip.AddrFrom4([4]byte(srcIP.To4()))
dstAddr := netip.AddrFrom4([4]byte(dstIP.To4()))
for _, rule := range m.routeRules {
if m.ruleMatches(rule, srcAddr, dstAddr, proto, srcPort, dstPort) {
return rule.action == firewall.ActionAccept
}
}
return false
}
func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
if !rule.destination.Contains(dstAddr) {
return false
}
sourceMatched := false
for _, src := range rule.sources {
if src.Contains(srcAddr) {
sourceMatched = true
break
}
}
if !sourceMatched {
return false
}
if rule.proto != firewall.ProtocolALL && rule.proto != proto {
return false
}
if proto == firewall.ProtocolTCP || proto == firewall.ProtocolUDP {
if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) {
return false
}
}
return true
}
// SetNetwork of the wireguard interface to which filtering applied // SetNetwork of the wireguard interface to which filtering applied
func (m *Manager) SetNetwork(network *net.IPNet) { func (m *Manager) SetNetwork(network *net.IPNet) {
m.wgNetwork = network m.wgNetwork = network
@ -560,13 +925,12 @@ func (m *Manager) SetNetwork(network *net.IPNet) {
func (m *Manager) AddUDPPacketHook( func (m *Manager) AddUDPPacketHook(
in bool, ip net.IP, dPort uint16, hook func([]byte) bool, in bool, ip net.IP, dPort uint16, hook func([]byte) bool,
) string { ) string {
r := Rule{ r := PeerRule{
id: uuid.New().String(), id: uuid.New().String(),
ip: ip, ip: ip,
protoLayer: layers.LayerTypeUDP, protoLayer: layers.LayerTypeUDP,
dPort: dPort, dPort: &firewall.Port{Values: []uint16{dPort}},
ipLayer: layers.LayerTypeIPv6, ipLayer: layers.LayerTypeIPv6,
direction: firewall.RuleDirectionOUT,
comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort), comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort),
udpHook: hook, udpHook: hook,
} }
@ -577,14 +941,13 @@ func (m *Manager) AddUDPPacketHook(
m.mutex.Lock() m.mutex.Lock()
if in { if in {
r.direction = firewall.RuleDirectionIN
if _, ok := m.incomingRules[r.ip.String()]; !ok { if _, ok := m.incomingRules[r.ip.String()]; !ok {
m.incomingRules[r.ip.String()] = make(map[string]Rule) m.incomingRules[r.ip.String()] = make(map[string]PeerRule)
} }
m.incomingRules[r.ip.String()][r.id] = r m.incomingRules[r.ip.String()][r.id] = r
} else { } else {
if _, ok := m.outgoingRules[r.ip.String()]; !ok { if _, ok := m.outgoingRules[r.ip.String()]; !ok {
m.outgoingRules[r.ip.String()] = make(map[string]Rule) m.outgoingRules[r.ip.String()] = make(map[string]PeerRule)
} }
m.outgoingRules[r.ip.String()][r.id] = r m.outgoingRules[r.ip.String()][r.id] = r
} }
@ -596,21 +959,62 @@ func (m *Manager) AddUDPPacketHook(
// RemovePacketHook removes packet hook by given ID // RemovePacketHook removes packet hook by given ID
func (m *Manager) RemovePacketHook(hookID string) error { func (m *Manager) RemovePacketHook(hookID string) error {
m.mutex.Lock()
defer m.mutex.Unlock()
for _, arr := range m.incomingRules { for _, arr := range m.incomingRules {
for _, r := range arr { for _, r := range arr {
if r.id == hookID { if r.id == hookID {
rule := r delete(arr, r.id)
return m.DeletePeerRule(&rule) return nil
} }
} }
} }
for _, arr := range m.outgoingRules { for _, arr := range m.outgoingRules {
for _, r := range arr { for _, r := range arr {
if r.id == hookID { if r.id == hookID {
rule := r delete(arr, r.id)
return m.DeletePeerRule(&rule) return nil
} }
} }
} }
return fmt.Errorf("hook with given id not found") return fmt.Errorf("hook with given id not found")
} }
// SetLogLevel sets the log level for the firewall manager
func (m *Manager) SetLogLevel(level log.Level) {
if m.logger != nil {
m.logger.SetLevel(nblog.Level(level))
}
}
func (m *Manager) EnableRouting() error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.determineRouting()
}
func (m *Manager) DisableRouting() error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.forwarder == nil {
return nil
}
m.routingEnabled = false
m.nativeRouter = false
// don't stop forwarder if in use by netstack
if m.netstack && m.localForwarding {
return nil
}
m.forwarder.Stop()
m.forwarder = nil
log.Debug("forwarder stopped")
return nil
}

View File

@ -1,9 +1,12 @@
//go:build uspbench
package uspfilter package uspfilter
import ( import (
"fmt" "fmt"
"math/rand" "math/rand"
"net" "net"
"net/netip"
"os" "os"
"strings" "strings"
"testing" "testing"
@ -91,7 +94,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
// Single rule allowing all traffic // Single rule allowing all traffic
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, _, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil,
fw.RuleDirectionIN, fw.ActionAccept, "", "allow all") fw.ActionAccept, "", "allow all")
require.NoError(b, err) require.NoError(b, err)
}, },
desc: "Baseline: Single 'allow all' rule without connection tracking", desc: "Baseline: Single 'allow all' rule without connection tracking",
@ -112,9 +115,9 @@ func BenchmarkCoreFiltering(b *testing.B) {
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
ip := generateRandomIPs(1)[0] ip := generateRandomIPs(1)[0]
_, err := m.AddPeerFiltering(ip, fw.ProtocolTCP, _, err := m.AddPeerFiltering(ip, fw.ProtocolTCP,
&fw.Port{Values: []int{1024 + i}}, &fw.Port{Values: []uint16{uint16(1024 + i)}},
&fw.Port{Values: []int{80}}, &fw.Port{Values: []uint16{80}},
fw.RuleDirectionIN, fw.ActionAccept, "", "explicit return") fw.ActionAccept, "", "explicit return")
require.NoError(b, err) require.NoError(b, err)
} }
}, },
@ -126,7 +129,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
// Add some basic rules but rely on state for established connections // Add some basic rules but rely on state for established connections
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil, _, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil,
fw.RuleDirectionIN, fw.ActionDrop, "", "default drop") fw.ActionDrop, "", "default drop")
require.NoError(b, err) require.NoError(b, err)
}, },
desc: "Connection tracking with established connections", desc: "Connection tracking with established connections",
@ -155,7 +158,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
// Create manager and basic setup // Create manager and basic setup
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}) }, false)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Reset(nil))
}) })
@ -185,7 +188,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
// Measure inbound packet processing // Measure inbound packet processing
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, manager.incomingRules) manager.dropFilter(inbound)
} }
}) })
} }
@ -200,7 +203,7 @@ func BenchmarkStateScaling(b *testing.B) {
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) { b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}) }, false)
b.Cleanup(func() { b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Reset(nil))
}) })
@ -228,7 +231,7 @@ func BenchmarkStateScaling(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(testIn, manager.incomingRules) manager.dropFilter(testIn)
} }
}) })
} }
@ -248,7 +251,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
b.Run(sc.name, func(b *testing.B) { b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}) }, false)
b.Cleanup(func() { b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Reset(nil))
}) })
@ -269,7 +272,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, manager.incomingRules) manager.dropFilter(inbound)
} }
}) })
} }
@ -447,7 +450,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
b.Run(sc.name, func(b *testing.B) { b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}) }, false)
b.Cleanup(func() { b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Reset(nil))
}) })
@ -472,7 +475,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
manager.processOutgoingHooks(syn) manager.processOutgoingHooks(syn)
// SYN-ACK // SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck)) synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, manager.incomingRules) manager.dropFilter(synack)
// ACK // ACK
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)) ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack) manager.processOutgoingHooks(ack)
@ -481,7 +484,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, manager.incomingRules) manager.dropFilter(inbound)
} }
}) })
} }
@ -574,7 +577,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}) }, false)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Reset(nil))
}) })
@ -588,9 +591,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
if sc.rules { if sc.rules {
// Single rule to allow all return traffic from port 80 // Single rule to allow all return traffic from port 80
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}}, &fw.Port{Values: []uint16{80}},
nil, nil,
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") fw.ActionAccept, "", "return traffic")
require.NoError(b, err) require.NoError(b, err)
} }
@ -618,7 +621,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// SYN-ACK // SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, manager.incomingRules) manager.dropFilter(synack)
// ACK // ACK
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
@ -646,7 +649,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// First outbound data // First outbound data
manager.processOutgoingHooks(outPackets[connIdx]) manager.processOutgoingHooks(outPackets[connIdx])
// Then inbound response - this is what we're actually measuring // Then inbound response - this is what we're actually measuring
manager.dropFilter(inPackets[connIdx], manager.incomingRules) manager.dropFilter(inPackets[connIdx])
} }
}) })
} }
@ -665,7 +668,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}) }, false)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Reset(nil))
}) })
@ -679,9 +682,9 @@ func BenchmarkShortLivedConnections(b *testing.B) {
if sc.rules { if sc.rules {
// Single rule to allow all return traffic from port 80 // Single rule to allow all return traffic from port 80
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}}, &fw.Port{Values: []uint16{80}},
nil, nil,
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") fw.ActionAccept, "", "return traffic")
require.NoError(b, err) require.NoError(b, err)
} }
@ -754,17 +757,17 @@ func BenchmarkShortLivedConnections(b *testing.B) {
// Connection establishment // Connection establishment
manager.processOutgoingHooks(p.syn) manager.processOutgoingHooks(p.syn)
manager.dropFilter(p.synAck, manager.incomingRules) manager.dropFilter(p.synAck)
manager.processOutgoingHooks(p.ack) manager.processOutgoingHooks(p.ack)
// Data transfer // Data transfer
manager.processOutgoingHooks(p.request) manager.processOutgoingHooks(p.request)
manager.dropFilter(p.response, manager.incomingRules) manager.dropFilter(p.response)
// Connection teardown // Connection teardown
manager.processOutgoingHooks(p.finClient) manager.processOutgoingHooks(p.finClient)
manager.dropFilter(p.ackServer, manager.incomingRules) manager.dropFilter(p.ackServer)
manager.dropFilter(p.finServer, manager.incomingRules) manager.dropFilter(p.finServer)
manager.processOutgoingHooks(p.ackClient) manager.processOutgoingHooks(p.ackClient)
} }
}) })
@ -784,7 +787,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}) }, false)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Reset(nil))
}) })
@ -797,9 +800,9 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
// Setup initial state based on scenario // Setup initial state based on scenario
if sc.rules { if sc.rules {
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}}, &fw.Port{Values: []uint16{80}},
nil, nil,
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") fw.ActionAccept, "", "return traffic")
require.NoError(b, err) require.NoError(b, err)
} }
@ -825,7 +828,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, manager.incomingRules) manager.dropFilter(synack)
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck)) uint16(1024+i), 80, uint16(conntrack.TCPAck))
@ -852,7 +855,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
// Simulate bidirectional traffic // Simulate bidirectional traffic
manager.processOutgoingHooks(outPackets[connIdx]) manager.processOutgoingHooks(outPackets[connIdx])
manager.dropFilter(inPackets[connIdx], manager.incomingRules) manager.dropFilter(inPackets[connIdx])
} }
}) })
}) })
@ -872,7 +875,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}) }, false)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Reset(nil))
}) })
@ -884,9 +887,9 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
if sc.rules { if sc.rules {
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}}, &fw.Port{Values: []uint16{80}},
nil, nil,
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") fw.ActionAccept, "", "return traffic")
require.NoError(b, err) require.NoError(b, err)
} }
@ -949,15 +952,15 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
// Full connection lifecycle // Full connection lifecycle
manager.processOutgoingHooks(p.syn) manager.processOutgoingHooks(p.syn)
manager.dropFilter(p.synAck, manager.incomingRules) manager.dropFilter(p.synAck)
manager.processOutgoingHooks(p.ack) manager.processOutgoingHooks(p.ack)
manager.processOutgoingHooks(p.request) manager.processOutgoingHooks(p.request)
manager.dropFilter(p.response, manager.incomingRules) manager.dropFilter(p.response)
manager.processOutgoingHooks(p.finClient) manager.processOutgoingHooks(p.finClient)
manager.dropFilter(p.ackServer, manager.incomingRules) manager.dropFilter(p.ackServer)
manager.dropFilter(p.finServer, manager.incomingRules) manager.dropFilter(p.finServer)
manager.processOutgoingHooks(p.ackClient) manager.processOutgoingHooks(p.ackClient)
} }
}) })
@ -996,3 +999,72 @@ func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstP
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test"))) require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test")))
return buf.Bytes() return buf.Bytes()
} }
func BenchmarkRouteACLs(b *testing.B) {
manager := setupRoutedManager(b, "10.10.0.100/16")
// Add several route rules to simulate real-world scenario
rules := []struct {
sources []netip.Prefix
dest netip.Prefix
proto fw.Protocol
port *fw.Port
}{
{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP,
port: &fw.Port{Values: []uint16{80, 443}},
},
{
sources: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/12"),
netip.MustParsePrefix("10.0.0.0/8"),
},
dest: netip.MustParsePrefix("0.0.0.0/0"),
proto: fw.ProtocolICMP,
},
{
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
dest: netip.MustParsePrefix("192.168.0.0/16"),
proto: fw.ProtocolUDP,
port: &fw.Port{Values: []uint16{53}},
},
}
for _, r := range rules {
_, err := manager.AddRouteFiltering(
r.sources,
r.dest,
r.proto,
nil,
r.port,
fw.ActionAccept,
)
if err != nil {
b.Fatal(err)
}
}
// Test cases that exercise different matching scenarios
cases := []struct {
srcIP string
dstIP string
proto fw.Protocol
dstPort uint16
}{
{"100.10.0.1", "192.168.1.100", fw.ProtocolTCP, 443}, // Match first rule
{"172.16.0.1", "8.8.8.8", fw.ProtocolICMP, 0}, // Match second rule
{"1.1.1.1", "192.168.1.53", fw.ProtocolUDP, 53}, // Match third rule
{"192.168.1.1", "10.0.0.1", fw.ProtocolTCP, 8080}, // No match
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
for _, tc := range cases {
srcIP := net.ParseIP(tc.srcIP)
dstIP := net.ParseIP(tc.dstIP)
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
}
}
}

File diff suppressed because it is too large Load Diff

View File

@ -9,17 +9,38 @@ import (
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
wgdevice "golang.zx2c4.com/wireguard/device"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
) )
var logger = log.NewFromLogrus(logrus.StandardLogger())
type IFaceMock struct { type IFaceMock struct {
SetFilterFunc func(device.PacketFilter) error SetFilterFunc func(device.PacketFilter) error
AddressFunc func() iface.WGAddress AddressFunc func() iface.WGAddress
GetWGDeviceFunc func() *wgdevice.Device
GetDeviceFunc func() *device.FilteredDevice
}
func (i *IFaceMock) GetWGDevice() *wgdevice.Device {
if i.GetWGDeviceFunc == nil {
return nil
}
return i.GetWGDeviceFunc()
}
func (i *IFaceMock) GetDevice() *device.FilteredDevice {
if i.GetDeviceFunc == nil {
return nil
}
return i.GetDeviceFunc()
} }
func (i *IFaceMock) SetFilter(iface device.PacketFilter) error { func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
@ -41,7 +62,7 @@ func TestManagerCreate(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock) m, err := Create(ifaceMock, false)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@ -61,7 +82,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
}, },
} }
m, err := Create(ifaceMock) m, err := Create(ifaceMock, false)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@ -69,12 +90,11 @@ func TestManagerAddPeerFiltering(t *testing.T) {
ip := net.ParseIP("192.168.1.1") ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP proto := fw.ProtocolTCP
port := &fw.Port{Values: []int{80}} port := &fw.Port{Values: []uint16{80}}
direction := fw.RuleDirectionOUT
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule" comment := "Test rule"
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) rule, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@ -96,7 +116,7 @@ func TestManagerDeleteRule(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock) m, err := Create(ifaceMock, false)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@ -104,38 +124,16 @@ func TestManagerDeleteRule(t *testing.T) {
ip := net.ParseIP("192.168.1.1") ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP proto := fw.ProtocolTCP
port := &fw.Port{Values: []int{80}} port := &fw.Port{Values: []uint16{80}}
direction := fw.RuleDirectionOUT
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule" comment := "Test rule 2"
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) rule2, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
} }
ip = net.ParseIP("192.168.1.1")
proto = fw.ProtocolTCP
port = &fw.Port{Values: []int{80}}
direction = fw.RuleDirectionIN
action = fw.ActionDrop
comment = "Test rule 2"
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
}
for _, r := range rule {
err = m.DeletePeerRule(r)
if err != nil {
t.Errorf("failed to delete rule: %v", err)
return
}
}
for _, r := range rule2 { for _, r := range rule2 {
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok { if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok {
t.Errorf("rule2 is not in the incomingRules") t.Errorf("rule2 is not in the incomingRules")
@ -189,12 +187,12 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}) }, false)
require.NoError(t, err) require.NoError(t, err)
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
var addedRule Rule var addedRule PeerRule
if tt.in { if tt.in {
if len(manager.incomingRules[tt.ip.String()]) != 1 { if len(manager.incomingRules[tt.ip.String()]) != 1 {
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules)) t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
@ -217,18 +215,14 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip) t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
return return
} }
if tt.dPort != addedRule.dPort { if tt.dPort != addedRule.dPort.Values[0] {
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort) t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort.Values[0])
return return
} }
if layers.LayerTypeUDP != addedRule.protoLayer { if layers.LayerTypeUDP != addedRule.protoLayer {
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer) t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
return return
} }
if tt.expDir != addedRule.direction {
t.Errorf("expected direction %d, got %d", tt.expDir, addedRule.direction)
return
}
if addedRule.udpHook == nil { if addedRule.udpHook == nil {
t.Errorf("expected udpHook to be set") t.Errorf("expected udpHook to be set")
return return
@ -242,7 +236,7 @@ func TestManagerReset(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock) m, err := Create(ifaceMock, false)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@ -250,12 +244,11 @@ func TestManagerReset(t *testing.T) {
ip := net.ParseIP("192.168.1.1") ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP proto := fw.ProtocolTCP
port := &fw.Port{Values: []int{80}} port := &fw.Port{Values: []uint16{80}}
direction := fw.RuleDirectionOUT
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule" comment := "Test rule"
_, err = m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) _, err = m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@ -275,9 +268,18 @@ func TestManagerReset(t *testing.T) {
func TestNotMatchByIP(t *testing.T) { func TestNotMatchByIP(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
}
},
} }
m, err := Create(ifaceMock) m, err := Create(ifaceMock, false)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@ -289,11 +291,10 @@ func TestNotMatchByIP(t *testing.T) {
ip := net.ParseIP("0.0.0.0") ip := net.ParseIP("0.0.0.0")
proto := fw.ProtocolUDP proto := fw.ProtocolUDP
direction := fw.RuleDirectionOUT
action := fw.ActionAccept action := fw.ActionAccept
comment := "Test rule" comment := "Test rule"
_, err = m.AddPeerFiltering(ip, proto, nil, nil, direction, action, "", comment) _, err = m.AddPeerFiltering(ip, proto, nil, nil, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@ -327,7 +328,7 @@ func TestNotMatchByIP(t *testing.T) {
return return
} }
if m.dropFilter(buf.Bytes(), m.outgoingRules) { if m.dropFilter(buf.Bytes()) {
t.Errorf("expected packet to be accepted") t.Errorf("expected packet to be accepted")
return return
} }
@ -346,7 +347,7 @@ func TestRemovePacketHook(t *testing.T) {
} }
// creating manager instance // creating manager instance
manager, err := Create(iface) manager, err := Create(iface, false)
if err != nil { if err != nil {
t.Fatalf("Failed to create Manager: %s", err) t.Fatalf("Failed to create Manager: %s", err)
} }
@ -392,7 +393,7 @@ func TestRemovePacketHook(t *testing.T) {
func TestProcessOutgoingHooks(t *testing.T) { func TestProcessOutgoingHooks(t *testing.T) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}) }, false)
require.NoError(t, err) require.NoError(t, err)
manager.wgNetwork = &net.IPNet{ manager.wgNetwork = &net.IPNet{
@ -400,7 +401,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
Mask: net.CIDRMask(16, 32), Mask: net.CIDRMask(16, 32),
} }
manager.udpTracker.Close() manager.udpTracker.Close()
manager.udpTracker = conntrack.NewUDPTracker(100 * time.Millisecond) manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger)
defer func() { defer func() {
require.NoError(t, manager.Reset(nil)) require.NoError(t, manager.Reset(nil))
}() }()
@ -478,7 +479,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
manager, err := Create(ifaceMock) manager, err := Create(ifaceMock, false)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second) time.Sleep(time.Second)
@ -492,12 +493,8 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
ip := net.ParseIP("10.20.0.100") ip := net.ParseIP("10.20.0.100")
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}} port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
if i%2 == 0 { _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
}
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
} }
@ -509,7 +506,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
func TestStatefulFirewall_UDPTracking(t *testing.T) { func TestStatefulFirewall_UDPTracking(t *testing.T) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}) }, false)
require.NoError(t, err) require.NoError(t, err)
manager.wgNetwork = &net.IPNet{ manager.wgNetwork = &net.IPNet{
@ -518,7 +515,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
} }
manager.udpTracker.Close() // Close the existing tracker manager.udpTracker.Close() // Close the existing tracker
manager.udpTracker = conntrack.NewUDPTracker(200 * time.Millisecond) manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger)
manager.decoders = sync.Pool{ manager.decoders = sync.Pool{
New: func() any { New: func() any {
d := &decoder{ d := &decoder{
@ -639,7 +636,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
for _, cp := range checkPoints { for _, cp := range checkPoints {
time.Sleep(cp.sleep) time.Sleep(cp.sleep)
drop = manager.dropFilter(inboundBuf.Bytes(), manager.incomingRules) drop = manager.dropFilter(inboundBuf.Bytes())
require.Equal(t, cp.shouldAllow, !drop, cp.description) require.Equal(t, cp.shouldAllow, !drop, cp.description)
// If the connection should still be valid, verify it exists // If the connection should still be valid, verify it exists
@ -710,7 +707,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Verify the invalid packet is dropped // Verify the invalid packet is dropped
drop = manager.dropFilter(testBuf.Bytes(), manager.incomingRules) drop = manager.dropFilter(testBuf.Bytes())
require.True(t, drop, tc.description) require.True(t, drop, tc.description)
}) })
} }

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"slices"
"strings" "strings"
"sync" "sync"
@ -152,46 +153,7 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice") params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
} }
var localAddrsForUnspecified []net.Addr mux := &UDPMuxDefault{
if addr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", params.UDPConn.LocalAddr())
} else if ok && addr.IP.IsUnspecified() {
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
// it will break the applications that are already using unspecified UDP connection
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
var networks []ice.NetworkType
switch {
case addr.IP.To16() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
case addr.IP.To4() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
default:
params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr())
}
if len(networks) > 0 {
if params.Net == nil {
var err error
if params.Net, err = stdnet.NewNet(); err != nil {
params.Logger.Errorf("failed to get create network: %v", err)
}
}
ips, err := localInterfaces(params.Net, params.InterfaceFilter, nil, networks, true)
if err == nil {
for _, ip := range ips {
localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})
}
} else {
params.Logger.Errorf("failed to get local interfaces for unspecified addr: %v", err)
}
}
}
return &UDPMuxDefault{
addressMap: map[string][]*udpMuxedConn{}, addressMap: map[string][]*udpMuxedConn{},
params: params, params: params,
connsIPv4: make(map[string]*udpMuxedConn), connsIPv4: make(map[string]*udpMuxedConn),
@ -203,8 +165,55 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
return newBufferHolder(receiveMTU + maxAddrSize) return newBufferHolder(receiveMTU + maxAddrSize)
}, },
}, },
localAddrsForUnspecified: localAddrsForUnspecified,
} }
mux.updateLocalAddresses()
return mux
}
func (m *UDPMuxDefault) updateLocalAddresses() {
var localAddrsForUnspecified []net.Addr
if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr())
} else if ok && addr.IP.IsUnspecified() {
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
// it will break the applications that are already using unspecified UDP connection
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
m.params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
var networks []ice.NetworkType
switch {
case addr.IP.To16() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
case addr.IP.To4() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
default:
m.params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", m.params.UDPConn.LocalAddr())
}
if len(networks) > 0 {
if m.params.Net == nil {
var err error
if m.params.Net, err = stdnet.NewNet(); err != nil {
m.params.Logger.Errorf("failed to get create network: %v", err)
}
}
ips, err := localInterfaces(m.params.Net, m.params.InterfaceFilter, nil, networks, true)
if err == nil {
for _, ip := range ips {
localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})
}
} else {
m.params.Logger.Errorf("failed to get local interfaces for unspecified addr: %v", err)
}
}
}
m.mu.Lock()
m.localAddrsForUnspecified = localAddrsForUnspecified
m.mu.Unlock()
} }
// LocalAddr returns the listening address of this UDPMuxDefault // LocalAddr returns the listening address of this UDPMuxDefault
@ -214,8 +223,12 @@ func (m *UDPMuxDefault) LocalAddr() net.Addr {
// GetListenAddresses returns the list of addresses that this mux is listening on // GetListenAddresses returns the list of addresses that this mux is listening on
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr { func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
m.updateLocalAddresses()
m.mu.Lock()
defer m.mu.Unlock()
if len(m.localAddrsForUnspecified) > 0 { if len(m.localAddrsForUnspecified) > 0 {
return m.localAddrsForUnspecified return slices.Clone(m.localAddrsForUnspecified)
} }
return []net.Addr{m.LocalAddr()} return []net.Addr{m.LocalAddr()}
@ -225,7 +238,10 @@ func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
// creates the connection if an existing one can't be found // creates the connection if an existing one can't be found
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) { func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
// don't check addr for mux using unspecified address // don't check addr for mux using unspecified address
if len(m.localAddrsForUnspecified) == 0 && m.params.UDPConn.LocalAddr().String() != addr.String() { m.mu.Lock()
lenLocalAddrs := len(m.localAddrsForUnspecified)
m.mu.Unlock()
if lenLocalAddrs == 0 && m.params.UDPConn.LocalAddr().String() != addr.String() {
return nil, fmt.Errorf("invalid address %s", addr.String()) return nil, fmt.Errorf("invalid address %s", addr.String())
} }

View File

@ -2,5 +2,5 @@
package configurer package configurer
// WgInterfaceDefault is a default interface name of Wiretrustee // WgInterfaceDefault is a default interface name of Netbird
const WgInterfaceDefault = "wt0" const WgInterfaceDefault = "wt0"

View File

@ -2,5 +2,5 @@
package configurer package configurer
// WgInterfaceDefault is a default interface name of Wiretrustee // WgInterfaceDefault is a default interface name of Netbird
const WgInterfaceDefault = "utun100" const WgInterfaceDefault = "utun100"

View File

@ -362,7 +362,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
} }
func getFwmark() int { func getFwmark() int {
if runtime.GOOS == "linux" && !nbnet.CustomRoutingDisabled() { if nbnet.AdvancedRouting() {
return nbnet.NetbirdFwmark return nbnet.NetbirdFwmark
} }
return 0 return 0

View File

@ -3,6 +3,10 @@
package iface package iface
import ( import (
"golang.zx2c4.com/wireguard/tun/netstack"
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
) )
@ -15,4 +19,6 @@ type WGTunDevice interface {
DeviceName() string DeviceName() string
Close() error Close() error
FilteredDevice() *device.FilteredDevice FilteredDevice() *device.FilteredDevice
Device() *wgdevice.Device
GetNet() *netstack.Net
} }

View File

@ -9,6 +9,7 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
@ -63,7 +64,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
t.filteredDevice = newDeviceFilter(tunDevice) t.filteredDevice = newDeviceFilter(tunDevice)
log.Debugf("attaching to interface %v", name) log.Debugf("attaching to interface %v", name)
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] ")) t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "))
// without this property mobile devices can discover remote endpoints if the configured one was wrong. // without this property mobile devices can discover remote endpoints if the configured one was wrong.
// this helps with support for the older NetBird clients that had a hardcoded direct mode // this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics() // t.device.DisableSomeRoamingForBrokenMobileSemantics()
@ -130,6 +131,10 @@ func (t *WGTunDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice return t.filteredDevice
} }
func (t *WGTunDevice) GetNet() *netstack.Net {
return nil
}
func routesToString(routes []string) string { func routesToString(routes []string) string {
return strings.Join(routes, ";") return strings.Join(routes, ";")
} }

View File

@ -9,6 +9,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
@ -117,6 +118,11 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice return t.filteredDevice
} }
// Device returns the wireguard device
func (t *TunDevice) Device() *device.Device {
return t.device
}
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided // assignAddr Adds IP address to the tunnel interface and network route based on the range provided
func (t *TunDevice) assignAddr() error { func (t *TunDevice) assignAddr() error {
cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String()) cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String())
@ -138,3 +144,7 @@ func (t *TunDevice) assignAddr() error {
} }
return nil return nil
} }
func (t *TunDevice) GetNet() *netstack.Net {
return nil
}

View File

@ -10,6 +10,7 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
@ -64,7 +65,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
t.filteredDevice = newDeviceFilter(tunDevice) t.filteredDevice = newDeviceFilter(tunDevice)
log.Debug("Attaching to interface") log.Debug("Attaching to interface")
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] ")) t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "))
// without this property mobile devices can discover remote endpoints if the configured one was wrong. // without this property mobile devices can discover remote endpoints if the configured one was wrong.
// this helps with support for the older NetBird clients that had a hardcoded direct mode // this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics() // t.device.DisableSomeRoamingForBrokenMobileSemantics()
@ -131,3 +132,7 @@ func (t *TunDevice) UpdateAddr(addr WGAddress) error {
func (t *TunDevice) FilteredDevice() *FilteredDevice { func (t *TunDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice return t.filteredDevice
} }
func (t *TunDevice) GetNet() *netstack.Net {
return nil
}

View File

@ -9,6 +9,8 @@ import (
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
@ -33,8 +35,6 @@ type TunKernelDevice struct {
} }
func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice { func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
checkUser()
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
return &TunKernelDevice{ return &TunKernelDevice{
ctx: ctx, ctx: ctx,
@ -153,6 +153,11 @@ func (t *TunKernelDevice) DeviceName() string {
return t.name return t.name
} }
// Device returns the wireguard device, not applicable for kernel devices
func (t *TunKernelDevice) Device() *device.Device {
return nil
}
func (t *TunKernelDevice) FilteredDevice() *FilteredDevice { func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
return nil return nil
} }
@ -161,3 +166,7 @@ func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
func (t *TunKernelDevice) assignAddr() error { func (t *TunKernelDevice) assignAddr() error {
return t.link.assignAddr(t.address) return t.link.assignAddr(t.address)
} }
func (t *TunKernelDevice) GetNet() *netstack.Net {
return nil
}

View File

@ -8,10 +8,12 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/netstack" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
nbnet "github.com/netbirdio/netbird/util/net"
) )
type TunNetstackDevice struct { type TunNetstackDevice struct {
@ -25,9 +27,11 @@ type TunNetstackDevice struct {
device *device.Device device *device.Device
filteredDevice *FilteredDevice filteredDevice *FilteredDevice
nsTun *netstack.NetStackTun nsTun *nbnetstack.NetStackTun
udpMux *bind.UniversalUDPMuxDefault udpMux *bind.UniversalUDPMuxDefault
configurer WGConfigurer configurer WGConfigurer
net *netstack.Net
} }
func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice { func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
@ -43,13 +47,19 @@ func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, m
} }
func (t *TunNetstackDevice) Create() (WGConfigurer, error) { func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
log.Info("create netstack tun interface") log.Info("create nbnetstack tun interface")
t.nsTun = netstack.NewNetStackTun(t.listenAddress, t.address.IP.String(), t.mtu)
tunIface, err := t.nsTun.Create() // TODO: get from service listener runtime IP
dnsAddr := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
log.Debugf("netstack using address: %s", t.address.IP)
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu)
log.Debugf("netstack using dns address: %s", dnsAddr)
tunIface, net, err := t.nsTun.Create()
if err != nil { if err != nil {
return nil, fmt.Errorf("error creating tun device: %s", err) return nil, fmt.Errorf("error creating tun device: %s", err)
} }
t.filteredDevice = newDeviceFilter(tunIface) t.filteredDevice = newDeviceFilter(tunIface)
t.net = net
t.device = device.NewDevice( t.device = device.NewDevice(
t.filteredDevice, t.filteredDevice,
@ -117,3 +127,12 @@ func (t *TunNetstackDevice) DeviceName() string {
func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice { func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice return t.filteredDevice
} }
// Device returns the wireguard device
func (t *TunNetstackDevice) Device() *device.Device {
return t.device
}
func (t *TunNetstackDevice) GetNet() *netstack.Net {
return t.net
}

View File

@ -4,12 +4,11 @@ package device
import ( import (
"fmt" "fmt"
"os"
"runtime"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
@ -32,8 +31,6 @@ type USPDevice struct {
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice { func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
log.Infof("using userspace bind mode") log.Infof("using userspace bind mode")
checkUser()
return &USPDevice{ return &USPDevice{
name: name, name: name,
address: address, address: address,
@ -128,6 +125,11 @@ func (t *USPDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice return t.filteredDevice
} }
// Device returns the wireguard device
func (t *USPDevice) Device() *device.Device {
return t.device
}
// assignAddr Adds IP address to the tunnel interface // assignAddr Adds IP address to the tunnel interface
func (t *USPDevice) assignAddr() error { func (t *USPDevice) assignAddr() error {
link := newWGLink(t.name) link := newWGLink(t.name)
@ -135,11 +137,6 @@ func (t *USPDevice) assignAddr() error {
return link.assignAddr(t.address) return link.assignAddr(t.address)
} }
func checkUser() { func (t *USPDevice) GetNet() *netstack.Net {
if runtime.GOOS == "freebsd" { return nil
euid := os.Geteuid()
if euid != 0 {
log.Warn("newTunUSPDevice: on netbird must run as root to be able to assign address to the tun interface with ifconfig")
}
}
} }

View File

@ -8,6 +8,7 @@ import (
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
@ -150,6 +151,11 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice return t.filteredDevice
} }
// Device returns the wireguard device
func (t *TunDevice) Device() *device.Device {
return t.device
}
func (t *TunDevice) GetInterfaceGUIDString() (string, error) { func (t *TunDevice) GetInterfaceGUIDString() (string, error) {
if t.nativeTunDevice == nil { if t.nativeTunDevice == nil {
return "", fmt.Errorf("interface has not been initialized yet") return "", fmt.Errorf("interface has not been initialized yet")
@ -169,3 +175,7 @@ func (t *TunDevice) assignAddr() error {
log.Debugf("adding address %s to interface: %s", t.address.IP, t.name) log.Debugf("adding address %s to interface: %s", t.address.IP, t.name)
return luid.SetIPAddresses([]netip.Prefix{netip.MustParsePrefix(t.address.String())}) return luid.SetIPAddresses([]netip.Prefix{netip.MustParsePrefix(t.address.String())})
} }
func (t *TunDevice) GetNet() *netstack.Net {
return nil
}

View File

@ -1,6 +1,10 @@
package iface package iface
import ( import (
wgdevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
) )
@ -13,4 +17,6 @@ type WGTunDevice interface {
DeviceName() string DeviceName() string
Close() error Close() error
FilteredDevice() *device.FilteredDevice FilteredDevice() *device.FilteredDevice
Device() *wgdevice.Device
GetNet() *netstack.Net
} }

View File

@ -203,6 +203,11 @@ func (l *Link) setAddr(ip, netmask string) error {
return fmt.Errorf("set interface addr: %w", err) return fmt.Errorf("set interface addr: %w", err)
} }
cmd = exec.Command("ifconfig", l.name, "inet6", "fe80::/64")
if out, err := cmd.CombinedOutput(); err != nil {
log.Debugf("adding address command '%v' failed with output: %s", cmd.String(), out)
}
return nil return nil
} }

View File

@ -9,8 +9,11 @@ import (
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
@ -203,6 +206,11 @@ func (w *WGIface) GetDevice() *device.FilteredDevice {
return w.tun.FilteredDevice() return w.tun.FilteredDevice()
} }
// GetWGDevice returns the WireGuard device
func (w *WGIface) GetWGDevice() *wgdevice.Device {
return w.tun.Device()
}
// GetStats returns the last handshake time, rx and tx bytes for the given peer // GetStats returns the last handshake time, rx and tx bytes for the given peer
func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) { func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) {
return w.configurer.GetStats(peerKey) return w.configurer.GetStats(peerKey)
@ -234,3 +242,11 @@ func (w *WGIface) waitUntilRemoved() error {
} }
} }
} }
// GetNet returns the netstack.Net for the netstack device
func (w *WGIface) GetNet() *netstack.Net {
w.mu.Lock()
defer w.mu.Unlock()
return w.tun.GetNet()
}

View File

@ -1,112 +0,0 @@
package iface
import (
"net"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
type MockWGIface struct {
CreateFunc func() error
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
IsUserspaceBindFunc func() bool
NameFunc func() string
AddressFunc func() device.WGAddress
ToInterfaceFunc func() *net.Interface
UpFunc func() (*bind.UniversalUDPMuxDefault, error)
UpdateAddrFunc func(newAddr string) error
UpdatePeerFunc func(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeerFunc func(peerKey string) error
AddAllowedIPFunc func(peerKey string, allowedIP string) error
RemoveAllowedIPFunc func(peerKey string, allowedIP string) error
CloseFunc func() error
SetFilterFunc func(filter device.PacketFilter) error
GetFilterFunc func() device.PacketFilter
GetDeviceFunc func() *device.FilteredDevice
GetStatsFunc func(peerKey string) (configurer.WGStats, error)
GetInterfaceGUIDStringFunc func() (string, error)
GetProxyFunc func() wgproxy.Proxy
}
func (m *MockWGIface) GetInterfaceGUIDString() (string, error) {
return m.GetInterfaceGUIDStringFunc()
}
func (m *MockWGIface) Create() error {
return m.CreateFunc()
}
func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains []string) error {
return m.CreateOnAndroidFunc(routeRange, ip, domains)
}
func (m *MockWGIface) IsUserspaceBind() bool {
return m.IsUserspaceBindFunc()
}
func (m *MockWGIface) Name() string {
return m.NameFunc()
}
func (m *MockWGIface) Address() device.WGAddress {
return m.AddressFunc()
}
func (m *MockWGIface) ToInterface() *net.Interface {
return m.ToInterfaceFunc()
}
func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
return m.UpFunc()
}
func (m *MockWGIface) UpdateAddr(newAddr string) error {
return m.UpdateAddrFunc(newAddr)
}
func (m *MockWGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
return m.UpdatePeerFunc(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
}
func (m *MockWGIface) RemovePeer(peerKey string) error {
return m.RemovePeerFunc(peerKey)
}
func (m *MockWGIface) AddAllowedIP(peerKey string, allowedIP string) error {
return m.AddAllowedIPFunc(peerKey, allowedIP)
}
func (m *MockWGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
return m.RemoveAllowedIPFunc(peerKey, allowedIP)
}
func (m *MockWGIface) Close() error {
return m.CloseFunc()
}
func (m *MockWGIface) SetFilter(filter device.PacketFilter) error {
return m.SetFilterFunc(filter)
}
func (m *MockWGIface) GetFilter() device.PacketFilter {
return m.GetFilterFunc()
}
func (m *MockWGIface) GetDevice() *device.FilteredDevice {
return m.GetDeviceFunc()
}
func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
return m.GetStatsFunc(peerKey)
}
func (m *MockWGIface) GetProxy() wgproxy.Proxy {
//TODO implement me
panic("implement me")
}

View File

@ -1,35 +0,0 @@
package iface
import (
"net"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
type IWGIface interface {
Create() error
CreateOnAndroid(routeRange []string, ip string, domains []string) error
IsUserspaceBind() bool
Name() string
Address() device.WGAddress
ToInterface() *net.Interface
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(newAddr string) error
GetProxy() wgproxy.Proxy
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP string) error
Close() error
SetFilter(filter device.PacketFilter) error
GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice
GetStats(peerKey string) (configurer.WGStats, error)
GetInterfaceGUIDString() (string, error)
}

View File

@ -8,9 +8,11 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
const EnvUseNetstackMode = "NB_USE_NETSTACK_MODE"
// IsEnabled todo: move these function to cmd layer // IsEnabled todo: move these function to cmd layer
func IsEnabled() bool { func IsEnabled() bool {
return os.Getenv("NB_USE_NETSTACK_MODE") == "true" return os.Getenv(EnvUseNetstackMode) == "true"
} }
func ListenAddr() string { func ListenAddr() string {

View File

@ -1,15 +1,22 @@
package netstack package netstack
import ( import (
"fmt"
"net"
"net/netip" "net/netip"
"os"
"strconv"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
) )
const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY"
type NetStackTun struct { //nolint:revive type NetStackTun struct { //nolint:revive
address string address net.IP
dnsAddress net.IP
mtu int mtu int
listenAddress string listenAddress string
@ -17,29 +24,48 @@ type NetStackTun struct { //nolint:revive
tundev tun.Device tundev tun.Device
} }
func NewNetStackTun(listenAddress string, address string, mtu int) *NetStackTun { func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu int) *NetStackTun {
return &NetStackTun{ return &NetStackTun{
address: address, address: address,
dnsAddress: dnsAddress,
mtu: mtu, mtu: mtu,
listenAddress: listenAddress, listenAddress: listenAddress,
} }
} }
func (t *NetStackTun) Create() (tun.Device, error) { func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
addr, ok := netip.AddrFromSlice(t.address)
if !ok {
return nil, nil, fmt.Errorf("convert address to netip.Addr: %v", t.address)
}
dnsAddr, ok := netip.AddrFromSlice(t.dnsAddress)
if !ok {
return nil, nil, fmt.Errorf("convert dns address to netip.Addr: %v", t.dnsAddress)
}
nsTunDev, tunNet, err := netstack.CreateNetTUN( nsTunDev, tunNet, err := netstack.CreateNetTUN(
[]netip.Addr{netip.MustParseAddr(t.address)}, []netip.Addr{addr.Unmap()},
[]netip.Addr{}, []netip.Addr{dnsAddr.Unmap()},
t.mtu) t.mtu)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
t.tundev = nsTunDev t.tundev = nsTunDev
skipProxy, err := strconv.ParseBool(os.Getenv(EnvSkipProxy))
if err != nil {
log.Errorf("failed to parse NB_ETSTACK_SKIP_PROXY: %s", err)
}
if skipProxy {
return nsTunDev, tunNet, nil
}
dialer := NewNSDialer(tunNet) dialer := NewNSDialer(tunNet)
t.proxy, err = NewSocks5(dialer) t.proxy, err = NewSocks5(dialer)
if err != nil { if err != nil {
_ = t.tundev.Close() _ = t.tundev.Close()
return nil, err return nil, nil, err
} }
go func() { go func() {
@ -49,7 +75,7 @@ func (t *NetStackTun) Create() (tun.Device, error) {
} }
}() }()
return nsTunDev, nil return nsTunDev, tunNet, nil
} }
func (t *NetStackTun) Close() error { func (t *NetStackTun) Close() error {

View File

@ -151,7 +151,7 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
d.rollBack(newRulePairs) d.rollBack(newRulePairs)
break break
} }
if len(rules) > 0 { if len(rulePair) > 0 {
d.peerRulesPairs[pairID] = rulePair d.peerRulesPairs[pairID] = rulePair
newRulePairs[pairID] = rulePair newRulePairs[pairID] = rulePair
} }
@ -268,13 +268,16 @@ func (d *DefaultManager) protoRuleToFirewallRule(
} }
var port *firewall.Port var port *firewall.Port
if r.Port != "" { if !portInfoEmpty(r.PortInfo) {
port = convertPortInfo(r.PortInfo)
} else if r.Port != "" {
// old version of management, single port
value, err := strconv.Atoi(r.Port) value, err := strconv.Atoi(r.Port)
if err != nil { if err != nil {
return "", nil, fmt.Errorf("invalid port, skipping firewall rule") return "", nil, fmt.Errorf("invalid port: %w", err)
} }
port = &firewall.Port{ port = &firewall.Port{
Values: []int{value}, Values: []uint16{uint16(value)},
} }
} }
@ -288,6 +291,8 @@ func (d *DefaultManager) protoRuleToFirewallRule(
case mgmProto.RuleDirection_IN: case mgmProto.RuleDirection_IN:
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "") rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
case mgmProto.RuleDirection_OUT: case mgmProto.RuleDirection_OUT:
// TODO: Remove this soon. Outbound rules are obsolete.
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "") rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
default: default:
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule") return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
@ -300,6 +305,22 @@ func (d *DefaultManager) protoRuleToFirewallRule(
return ruleID, rules, nil return ruleID, rules, nil
} }
func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
if portInfo == nil {
return true
}
switch portInfo.GetPortSelection().(type) {
case *mgmProto.PortInfo_Port:
return portInfo.GetPort() == 0
case *mgmProto.PortInfo_Range_:
r := portInfo.GetRange()
return r == nil || r.Start == 0 || r.End == 0
default:
return true
}
}
func (d *DefaultManager) addInRules( func (d *DefaultManager) addInRules(
ip net.IP, ip net.IP,
protocol firewall.Protocol, protocol firewall.Protocol,
@ -308,25 +329,12 @@ func (d *DefaultManager) addInRules(
ipsetName string, ipsetName string,
comment string, comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
var rules []firewall.Rule rule, err := d.firewall.AddPeerFiltering(ip, protocol, nil, port, action, ipsetName, comment)
rule, err := d.firewall.AddPeerFiltering(
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err) return nil, fmt.Errorf("add firewall rule: %w", err)
}
rules = append(rules, rule...)
if shouldSkipInvertedRule(protocol, port) {
return rules, nil
} }
rule, err = d.firewall.AddPeerFiltering( return rule, nil
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
return append(rules, rule...), nil
} }
func (d *DefaultManager) addOutRules( func (d *DefaultManager) addOutRules(
@ -337,25 +345,16 @@ func (d *DefaultManager) addOutRules(
ipsetName string, ipsetName string,
comment string, comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
var rules []firewall.Rule
rule, err := d.firewall.AddPeerFiltering(
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
rules = append(rules, rule...)
if shouldSkipInvertedRule(protocol, port) { if shouldSkipInvertedRule(protocol, port) {
return rules, nil return nil, nil
} }
rule, err = d.firewall.AddPeerFiltering( rule, err := d.firewall.AddPeerFiltering(ip, protocol, port, nil, action, ipsetName, comment)
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err) return nil, fmt.Errorf("add firewall rule: %w", err)
} }
return append(rules, rule...), nil return rule, nil
} }
// getPeerRuleID() returns unique ID for the rule based on its parameters. // getPeerRuleID() returns unique ID for the rule based on its parameters.
@ -508,7 +507,7 @@ func (d *DefaultManager) squashAcceptRules(
// getRuleGroupingSelector takes all rule properties except IP address to build selector // getRuleGroupingSelector takes all rule properties except IP address to build selector
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string { func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port) return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)
} }
func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) { func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) {
@ -559,14 +558,14 @@ func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port {
if portInfo.GetPort() != 0 { if portInfo.GetPort() != 0 {
return &firewall.Port{ return &firewall.Port{
Values: []int{int(portInfo.GetPort())}, Values: []uint16{uint16(int(portInfo.GetPort()))},
} }
} }
if portInfo.GetRange() != nil { if portInfo.GetRange() != nil {
return &firewall.Port{ return &firewall.Port{
IsRange: true, IsRange: true,
Values: []int{int(portInfo.GetRange().Start), int(portInfo.GetRange().End)}, Values: []uint16{uint16(portInfo.GetRange().Start), uint16(portInfo.GetRange().End)},
} }
} }

View File

@ -49,9 +49,10 @@ func TestDefaultManager(t *testing.T) {
IP: ip, IP: ip,
Network: network, Network: network,
}).AnyTimes() }).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it // we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil) fw, err := firewall.NewFirewall(ifaceMock, nil, false)
if err != nil { if err != nil {
t.Errorf("create firewall: %v", err) t.Errorf("create firewall: %v", err)
return return
@ -119,8 +120,8 @@ func TestDefaultManager(t *testing.T) {
networkMap.FirewallRulesIsEmpty = false networkMap.FirewallRulesIsEmpty = false
acl.ApplyFiltering(networkMap) acl.ApplyFiltering(networkMap)
if len(acl.peerRulesPairs) != 2 { if len(acl.peerRulesPairs) != 1 {
t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs)) t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
return return
} }
}) })
@ -342,9 +343,10 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
IP: ip, IP: ip,
Network: network, Network: network,
}).AnyTimes() }).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it // we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil) fw, err := firewall.NewFirewall(ifaceMock, nil, false)
if err != nil { if err != nil {
t.Errorf("create firewall: %v", err) t.Errorf("create firewall: %v", err)
return return
@ -356,8 +358,8 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
acl.ApplyFiltering(networkMap) acl.ApplyFiltering(networkMap)
if len(acl.peerRulesPairs) != 4 { if len(acl.peerRulesPairs) != 3 {
t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.peerRulesPairs)) t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
return return
} }
} }

View File

@ -8,6 +8,8 @@ import (
reflect "reflect" reflect "reflect"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
wgdevice "golang.zx2c4.com/wireguard/device"
iface "github.com/netbirdio/netbird/client/iface" iface "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
) )
@ -90,3 +92,31 @@ func (mr *MockIFaceMapperMockRecorder) SetFilter(arg0 interface{}) *gomock.Call
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetFilter", reflect.TypeOf((*MockIFaceMapper)(nil).SetFilter), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetFilter", reflect.TypeOf((*MockIFaceMapper)(nil).SetFilter), arg0)
} }
// GetDevice mocks base method.
func (m *MockIFaceMapper) GetDevice() *device.FilteredDevice {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetDevice")
ret0, _ := ret[0].(*device.FilteredDevice)
return ret0
}
// GetDevice indicates an expected call of GetDevice.
func (mr *MockIFaceMapperMockRecorder) GetDevice() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDevice", reflect.TypeOf((*MockIFaceMapper)(nil).GetDevice))
}
// GetWGDevice mocks base method.
func (m *MockIFaceMapper) GetWGDevice() *wgdevice.Device {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetWGDevice")
ret0, _ := ret[0].(*wgdevice.Device)
return ret0
}
// GetWGDevice indicates an expected call of GetWGDevice.
func (mr *MockIFaceMapperMockRecorder) GetWGDevice() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWGDevice", reflect.TypeOf((*MockIFaceMapper)(nil).GetWGDevice))
}

View File

@ -2,6 +2,8 @@ package auth
import ( import (
"context" "context"
"crypto/tls"
"crypto/x509"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -11,7 +13,10 @@ import (
"strings" "strings"
"time" "time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/util/embeddedroots"
) )
// HostedGrantType grant type for device flow on Hosted // HostedGrantType grant type for device flow on Hosted
@ -56,6 +61,18 @@ func NewDeviceAuthorizationFlow(config internal.DeviceAuthProviderConfig) (*Devi
httpTransport := http.DefaultTransport.(*http.Transport).Clone() httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5 httpTransport.MaxIdleConns = 5
certPool, err := x509.SystemCertPool()
if err != nil || certPool == nil {
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
certPool = embeddedroots.Get()
} else {
log.Debug("Using system certificate pool.")
}
httpTransport.TLSClientConfig = &tls.Config{
RootCAs: certPool,
}
httpClient := &http.Client{ httpClient := &http.Client{
Timeout: 10 * time.Second, Timeout: 10 * time.Second,
Transport: httpTransport, Transport: httpTransport,

View File

@ -8,6 +8,7 @@ import (
"os" "os"
"reflect" "reflect"
"runtime" "runtime"
"slices"
"strings" "strings"
"time" "time"
@ -20,6 +21,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic" "github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh"
mgm "github.com/netbirdio/netbird/management/client" mgm "github.com/netbirdio/netbird/management/client"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@ -66,6 +68,12 @@ type ConfigInput struct {
DisableServerRoutes *bool DisableServerRoutes *bool
DisableDNS *bool DisableDNS *bool
DisableFirewall *bool DisableFirewall *bool
BlockLANAccess *bool
DisableNotifications *bool
DNSLabels domain.List
} }
// Config Configuration type // Config Configuration type
@ -89,6 +97,12 @@ type Config struct {
DisableDNS bool DisableDNS bool
DisableFirewall bool DisableFirewall bool
BlockLANAccess bool
DisableNotifications *bool
DNSLabels domain.List
// SSHKey is a private SSH key in a PEM format // SSHKey is a private SSH key in a PEM format
SSHKey string SSHKey string
@ -455,6 +469,33 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true updated = true
} }
if input.BlockLANAccess != nil && *input.BlockLANAccess != config.BlockLANAccess {
if *input.BlockLANAccess {
log.Infof("blocking LAN access")
} else {
log.Infof("allowing LAN access")
}
config.BlockLANAccess = *input.BlockLANAccess
updated = true
}
if input.DisableNotifications != nil && input.DisableNotifications != config.DisableNotifications {
if *input.DisableNotifications {
log.Infof("disabling notifications")
} else {
log.Infof("enabling notifications")
}
config.DisableNotifications = input.DisableNotifications
updated = true
}
if config.DisableNotifications == nil {
disabled := true
config.DisableNotifications = &disabled
log.Infof("setting notifications to disabled by default")
updated = true
}
if input.ClientCertKeyPath != "" { if input.ClientCertKeyPath != "" {
config.ClientCertKeyPath = input.ClientCertKeyPath config.ClientCertKeyPath = input.ClientCertKeyPath
updated = true updated = true
@ -475,6 +516,14 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
} }
} }
if input.DNSLabels != nil && !slices.Equal(config.DNSLabels, input.DNSLabels) {
log.Infof("updating DNS labels [ %s ] (old value: [ %s ])",
input.DNSLabels.SafeString(),
config.DNSLabels.SafeString())
config.DNSLabels = input.DNSLabels
updated = true
}
return updated, nil return updated, nil
} }

View File

@ -23,6 +23,7 @@ import (
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
mgm "github.com/netbirdio/netbird/management/client" mgm "github.com/netbirdio/netbird/management/client"
@ -31,6 +32,7 @@ import (
relayClient "github.com/netbirdio/netbird/relay/client" relayClient "github.com/netbirdio/netbird/relay/client"
signal "github.com/netbirdio/netbird/signal/client" signal "github.com/netbirdio/netbird/signal/client"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
nbnet "github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
@ -59,13 +61,8 @@ func NewConnectClient(
} }
// Run with main logic. // Run with main logic.
func (c *ConnectClient) Run() error { func (c *ConnectClient) Run(runningChan chan error) error {
return c.run(MobileDependency{}, nil, nil) return c.run(MobileDependency{}, runningChan)
}
// RunWithProbes runs the client's main logic with probes attached
func (c *ConnectClient) RunWithProbes(probes *ProbeHolder, runningChan chan error) error {
return c.run(MobileDependency{}, probes, runningChan)
} }
// RunOnAndroid with main logic on mobile system // RunOnAndroid with main logic on mobile system
@ -84,7 +81,7 @@ func (c *ConnectClient) RunOnAndroid(
HostDNSAddresses: dnsAddresses, HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener, DnsReadyListener: dnsReadyListener,
} }
return c.run(mobileDependency, nil, nil) return c.run(mobileDependency, nil)
} }
func (c *ConnectClient) RunOniOS( func (c *ConnectClient) RunOniOS(
@ -102,18 +99,30 @@ func (c *ConnectClient) RunOniOS(
DnsManager: dnsManager, DnsManager: dnsManager,
StateFilePath: stateFilePath, StateFilePath: stateFilePath,
} }
return c.run(mobileDependency, nil, nil) return c.run(mobileDependency, nil)
} }
func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHolder, runningChan chan error) error { func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan error) error {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
rec := c.statusRecorder
if rec != nil {
rec.PublishEvent(
cProto.SystemEvent_CRITICAL, cProto.SystemEvent_SYSTEM,
"panic occurred",
"The Netbird service panicked. Please restart the service and submit a bug report with the client logs.",
nil,
)
}
log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack())) log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
} }
}() }()
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH) log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
nbnet.Init()
backOff := &backoff.ExponentialBackOff{ backOff := &backoff.ExponentialBackOff{
InitialInterval: time.Second, InitialInterval: time.Second,
RandomizationFactor: 1, RandomizationFactor: 1,
@ -182,8 +191,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
} }
}() }()
// connect (just a connection, no stream yet) and login to Management Service to get an initial global Wiretrustee config // connect (just a connection, no stream yet) and login to Management Service to get an initial global Netbird config
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey) loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, c.config)
if err != nil { if err != nil {
log.Debug(err) log.Debug(err)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) { if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
@ -204,8 +213,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
c.statusRecorder.UpdateLocalPeerState(localPeerState) c.statusRecorder.UpdateLocalPeerState(localPeerState)
signalURL := fmt.Sprintf("%s://%s", signalURL := fmt.Sprintf("%s://%s",
strings.ToLower(loginResp.GetWiretrusteeConfig().GetSignal().GetProtocol().String()), strings.ToLower(loginResp.GetNetbirdConfig().GetSignal().GetProtocol().String()),
loginResp.GetWiretrusteeConfig().GetSignal().GetUri(), loginResp.GetNetbirdConfig().GetSignal().GetUri(),
) )
c.statusRecorder.UpdateSignalAddress(signalURL) c.statusRecorder.UpdateSignalAddress(signalURL)
@ -216,8 +225,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
c.statusRecorder.MarkSignalDisconnected(err) c.statusRecorder.MarkSignalDisconnected(err)
}() }()
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal // with the global Netbird config in hand connect (just a connection, no stream yet) Signal
signalClient, err := connectToSignal(engineCtx, loginResp.GetWiretrusteeConfig(), myPrivateKey) signalClient, err := connectToSignal(engineCtx, loginResp.GetNetbirdConfig(), myPrivateKey)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return wrapErr(err) return wrapErr(err)
@ -261,7 +270,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
checks := loginResp.GetChecks() checks := loginResp.GetChecks()
c.engineMutex.Lock() c.engineMutex.Lock()
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks) c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
c.engine.SetNetworkMapPersistence(c.persistNetworkMap) c.engine.SetNetworkMapPersistence(c.persistNetworkMap)
c.engineMutex.Unlock() c.engineMutex.Unlock()
@ -316,7 +325,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
} }
func parseRelayInfo(loginResp *mgmProto.LoginResponse) ([]string, *hmac.Token) { func parseRelayInfo(loginResp *mgmProto.LoginResponse) ([]string, *hmac.Token) {
relayCfg := loginResp.GetWiretrusteeConfig().GetRelay() relayCfg := loginResp.GetNetbirdConfig().GetRelay()
if relayCfg == nil { if relayCfg == nil {
return nil, nil return nil, nil
} }
@ -420,6 +429,8 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
DisableServerRoutes: config.DisableServerRoutes, DisableServerRoutes: config.DisableServerRoutes,
DisableDNS: config.DisableDNS, DisableDNS: config.DisableDNS,
DisableFirewall: config.DisableFirewall, DisableFirewall: config.DisableFirewall,
BlockLANAccess: config.BlockLANAccess,
} }
if config.PreSharedKey != "" { if config.PreSharedKey != "" {
@ -443,7 +454,7 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
} }
// connectToSignal creates Signal Service client and established a connection // connectToSignal creates Signal Service client and established a connection
func connectToSignal(ctx context.Context, wtConfig *mgmProto.WiretrusteeConfig, ourPrivateKey wgtypes.Key) (*signal.GrpcClient, error) { func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourPrivateKey wgtypes.Key) (*signal.GrpcClient, error) {
var sigTLSEnabled bool var sigTLSEnabled bool
if wtConfig.Signal.Protocol == mgmProto.HostConfig_HTTPS { if wtConfig.Signal.Protocol == mgmProto.HostConfig_HTTPS {
sigTLSEnabled = true sigTLSEnabled = true
@ -460,8 +471,8 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.WiretrusteeConfig,
return signalClient, nil return signalClient, nil
} }
// loginToManagement creates Management Services client, establishes a connection, logs-in and gets a global Wiretrustee config (signal, turn, stun hosts, etc) // loginToManagement creates Management Services client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte) (*mgmProto.LoginResponse, error) { func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *Config) (*mgmProto.LoginResponse, error) {
serverPublicKey, err := client.GetServerPublicKey() serverPublicKey, err := client.GetServerPublicKey()
if err != nil { if err != nil {
@ -469,7 +480,16 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte)
} }
sysInfo := system.GetInfo(ctx) sysInfo := system.GetInfo(ctx)
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey) sysInfo.SetFlags(
config.RosenpassEnabled,
config.RosenpassPermissive,
config.ServerSSHAllowed,
config.DisableClientRoutes,
config.DisableServerRoutes,
config.DisableDNS,
config.DisableFirewall,
)
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
if err != nil { if err != nil {
return nil, err return nil, err
} }

111
client/internal/dns.go Normal file
View File

@ -0,0 +1,111 @@
package internal
import (
"fmt"
"net"
"slices"
"strings"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
)
func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.SimpleRecord, bool) {
ip := net.ParseIP(aRecord.RData)
if ip == nil || ip.To4() == nil {
return nbdns.SimpleRecord{}, false
}
if !ipNet.Contains(ip) {
return nbdns.SimpleRecord{}, false
}
ipOctets := strings.Split(ip.String(), ".")
slices.Reverse(ipOctets)
rdnsName := dns.Fqdn(strings.Join(ipOctets, ".") + ".in-addr.arpa")
return nbdns.SimpleRecord{
Name: rdnsName,
Type: int(dns.TypePTR),
Class: aRecord.Class,
TTL: aRecord.TTL,
RData: dns.Fqdn(aRecord.Name),
}, true
}
// generateReverseZoneName creates the reverse DNS zone name for a given network
func generateReverseZoneName(ipNet *net.IPNet) (string, error) {
networkIP := ipNet.IP.Mask(ipNet.Mask)
maskOnes, _ := ipNet.Mask.Size()
// round up to nearest byte
octetsToUse := (maskOnes + 7) / 8
octets := strings.Split(networkIP.String(), ".")
if octetsToUse > len(octets) {
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", maskOnes)
}
reverseOctets := make([]string, octetsToUse)
for i := 0; i < octetsToUse; i++ {
reverseOctets[octetsToUse-1-i] = octets[i]
}
return dns.Fqdn(strings.Join(reverseOctets, ".") + ".in-addr.arpa"), nil
}
// zoneExists checks if a zone with the given name already exists in the configuration
func zoneExists(config *nbdns.Config, zoneName string) bool {
for _, zone := range config.CustomZones {
if zone.Domain == zoneName {
log.Debugf("reverse DNS zone %s already exists", zoneName)
return true
}
}
return false
}
// collectPTRRecords gathers all PTR records for the given network from A records
func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRecord {
var records []nbdns.SimpleRecord
for _, zone := range config.CustomZones {
for _, record := range zone.Records {
if record.Type != int(dns.TypeA) {
continue
}
if ptrRecord, ok := createPTRRecord(record, ipNet); ok {
records = append(records, ptrRecord)
}
}
}
return records
}
// addReverseZone adds a reverse DNS zone to the configuration for the given network
func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) {
zoneName, err := generateReverseZoneName(ipNet)
if err != nil {
log.Warn(err)
return
}
if zoneExists(config, zoneName) {
log.Debugf("reverse DNS zone %s already exists", zoneName)
return
}
records := collectPTRRecords(config, ipNet)
reverseZone := nbdns.CustomZone{
Domain: zoneName,
Records: records,
}
config.CustomZones = append(config.CustomZones, reverseZone)
log.Debugf("added reverse DNS zone: %s with %d records", zoneName, len(records))
}

View File

@ -58,7 +58,7 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *st
return fmt.Errorf("restoring the original resolv.conf file return err: %w", err) return fmt.Errorf("restoring the original resolv.conf file return err: %w", err)
} }
} }
return fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured") return ErrRouteAllWithoutNameserverGroup
} }
if !backupFileExist { if !backupFileExist {
@ -121,6 +121,10 @@ func (f *fileConfigurator) restoreHostDNS() error {
return f.restore() return f.restore()
} }
func (f *fileConfigurator) string() string {
return "file"
}
func (f *fileConfigurator) backup() error { func (f *fileConfigurator) backup() error {
stats, err := os.Stat(defaultResolvConfPath) stats, err := os.Stat(defaultResolvConfPath)
if err != nil { if err != nil {

View File

@ -12,7 +12,7 @@ import (
const ( const (
PriorityDNSRoute = 100 PriorityDNSRoute = 100
PriorityMatchDomain = 50 PriorityMatchDomain = 50
PriorityDefault = 0 PriorityDefault = 1
) )
type SubdomainMatcher interface { type SubdomainMatcher interface {
@ -26,7 +26,6 @@ type HandlerEntry struct {
Pattern string Pattern string
OrigPattern string OrigPattern string
IsWildcard bool IsWildcard bool
StopHandler handlerWithStop
MatchSubdomains bool MatchSubdomains bool
} }
@ -64,7 +63,7 @@ func (w *ResponseWriterChain) GetOrigPattern() string {
} }
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority // AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) { func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@ -78,9 +77,6 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
// First remove any existing handler with same pattern (case-insensitive) and priority // First remove any existing handler with same pattern (case-insensitive) and priority
for i := len(c.handlers) - 1; i >= 0; i-- { for i := len(c.handlers) - 1; i >= 0; i-- {
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority { if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
if c.handlers[i].StopHandler != nil {
c.handlers[i].StopHandler.stop()
}
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...) c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
break break
} }
@ -101,21 +97,33 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
Pattern: pattern, Pattern: pattern,
OrigPattern: origPattern, OrigPattern: origPattern,
IsWildcard: isWildcard, IsWildcard: isWildcard,
StopHandler: stopHandler,
MatchSubdomains: matchSubdomains, MatchSubdomains: matchSubdomains,
} }
// Insert handler in priority order pos := c.findHandlerPosition(entry)
pos := 0 c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
}
// findHandlerPosition determines where to insert a new handler based on priority and specificity
func (c *HandlerChain) findHandlerPosition(newEntry HandlerEntry) int {
for i, h := range c.handlers { for i, h := range c.handlers {
if h.Priority < priority { // prio first
pos = i if h.Priority < newEntry.Priority {
break return i
}
// domain specificity next
if h.Priority == newEntry.Priority {
newDots := strings.Count(newEntry.Pattern, ".")
existingDots := strings.Count(h.Pattern, ".")
if newDots > existingDots {
return i
}
} }
pos = i + 1
} }
c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...) // add at end
return len(c.handlers)
} }
// RemoveHandler removes a handler for the given pattern and priority // RemoveHandler removes a handler for the given pattern and priority
@ -129,9 +137,6 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
for i := len(c.handlers) - 1; i >= 0; i-- { for i := len(c.handlers) - 1; i >= 0; i-- {
entry := c.handlers[i] entry := c.handlers[i]
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority { if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
if entry.StopHandler != nil {
entry.StopHandler.stop()
}
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...) c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
return return
} }
@ -167,8 +172,8 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if log.IsLevelEnabled(log.TraceLevel) { if log.IsLevelEnabled(log.TraceLevel) {
log.Tracef("current handlers (%d):", len(handlers)) log.Tracef("current handlers (%d):", len(handlers))
for _, h := range handlers { for _, h := range handlers {
log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v priority=%d", log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d",
h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority) h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority)
} }
} }
@ -193,13 +198,13 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
} }
if !matched { if !matched {
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v matched=false", log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d matched=false",
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard) qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard, entry.Priority)
continue continue
} }
log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v", log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d",
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains) qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
chainWriter := &ResponseWriterChain{ chainWriter := &ResponseWriterChain{
ResponseWriter: w, ResponseWriter: w,

View File

@ -21,9 +21,9 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
dnsRouteHandler := &nbdns.MockHandler{} dnsRouteHandler := &nbdns.MockHandler{}
// Setup handlers with different priorities // Setup handlers with different priorities
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil) chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault)
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain, nil) chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain)
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute, nil) chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute)
// Create test request // Create test request
r := new(dns.Msg) r := new(dns.Msg)
@ -138,7 +138,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
pattern = "*." + tt.handlerDomain[2:] pattern = "*." + tt.handlerDomain[2:]
} }
chain.AddHandler(pattern, handler, nbdns.PriorityDefault, nil) chain.AddHandler(pattern, handler, nbdns.PriorityDefault)
r := new(dns.Msg) r := new(dns.Msg)
r.SetQuestion(tt.queryDomain, dns.TypeA) r.SetQuestion(tt.queryDomain, dns.TypeA)
@ -253,7 +253,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe() handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe()
} }
chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority, nil) chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority)
} }
// Create and execute request // Create and execute request
@ -280,9 +280,9 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
handler3 := &nbdns.MockHandler{} handler3 := &nbdns.MockHandler{}
// Add handlers in priority order // Add handlers in priority order
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil) chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute)
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain, nil) chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain)
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault, nil) chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault)
// Create test request // Create test request
r := new(dns.Msg) r := new(dns.Msg)
@ -416,7 +416,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
if op.action == "add" { if op.action == "add" {
handler := &nbdns.MockHandler{} handler := &nbdns.MockHandler{}
handlers[op.priority] = handler handlers[op.priority] = handler
chain.AddHandler(op.pattern, handler, op.priority, nil) chain.AddHandler(op.pattern, handler, op.priority)
} else { } else {
chain.RemoveHandler(op.pattern, op.priority) chain.RemoveHandler(op.pattern, op.priority)
} }
@ -471,9 +471,9 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
r.SetQuestion(testQuery, dns.TypeA) r.SetQuestion(testQuery, dns.TypeA)
// Add handlers in mixed order // Add handlers in mixed order
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault, nil) chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute, nil) chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain, nil) chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
// Test 1: Initial state with all three handlers // Test 1: Initial state with all three handlers
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
@ -653,7 +653,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
handler = mockHandler handler = mockHandler
} }
chain.AddHandler(pattern, handler, h.priority, nil) chain.AddHandler(pattern, handler, h.priority)
} }
// Execute request // Execute request
@ -677,3 +677,156 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
}) })
} }
} }
func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
tests := []struct {
name string
scenario string
ops []struct {
action string
pattern string
priority int
subdomain bool
}
query string
expectedMatch string
}{
{
name: "more specific domain matches first",
scenario: "sub.example.com should match before example.com",
ops: []struct {
action string
pattern string
priority int
subdomain bool
}{
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
},
query: "sub.example.com.",
expectedMatch: "sub.example.com.",
},
{
name: "more specific domain matches first, both match subdomains",
scenario: "sub.example.com should match before example.com",
ops: []struct {
action string
pattern string
priority int
subdomain bool
}{
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, true},
},
query: "sub.example.com.",
expectedMatch: "sub.example.com.",
},
{
name: "maintain specificity order after removal",
scenario: "after removing most specific, should fall back to less specific",
ops: []struct {
action string
pattern string
priority int
subdomain bool
}{
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, true},
{"add", "test.sub.example.com.", nbdns.PriorityMatchDomain, false},
{"remove", "test.sub.example.com.", nbdns.PriorityMatchDomain, false},
},
query: "test.sub.example.com.",
expectedMatch: "sub.example.com.",
},
{
name: "priority overrides specificity",
scenario: "less specific domain with higher priority should match first",
ops: []struct {
action string
pattern string
priority int
subdomain bool
}{
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
{"add", "example.com.", nbdns.PriorityDNSRoute, true},
},
query: "sub.example.com.",
expectedMatch: "example.com.",
},
{
name: "equal priority respects specificity",
scenario: "with equal priority, more specific domain should match",
ops: []struct {
action string
pattern string
priority int
subdomain bool
}{
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
{"add", "other.example.com.", nbdns.PriorityMatchDomain, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
},
query: "sub.example.com.",
expectedMatch: "sub.example.com.",
},
{
name: "specific matches before wildcard",
scenario: "specific domain should match before wildcard at same priority",
ops: []struct {
action string
pattern string
priority int
subdomain bool
}{
{"add", "*.example.com.", nbdns.PriorityDNSRoute, false},
{"add", "sub.example.com.", nbdns.PriorityDNSRoute, false},
},
query: "sub.example.com.",
expectedMatch: "sub.example.com.",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chain := nbdns.NewHandlerChain()
handlers := make(map[string]*nbdns.MockSubdomainHandler)
for _, op := range tt.ops {
if op.action == "add" {
handler := &nbdns.MockSubdomainHandler{Subdomains: op.subdomain}
handlers[op.pattern] = handler
chain.AddHandler(op.pattern, handler, op.priority)
} else {
chain.RemoveHandler(op.pattern, op.priority)
}
}
r := new(dns.Msg)
r.SetQuestion(tt.query, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
// Setup handler expectations
for pattern, handler := range handlers {
if pattern == tt.expectedMatch {
handler.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
w := args.Get(0).(dns.ResponseWriter)
r := args.Get(1).(*dns.Msg)
resp := new(dns.Msg)
resp.SetReply(r)
assert.NoError(t, w.WriteMsg(resp))
}).Once()
}
}
chain.ServeDNS(w, r)
for pattern, handler := range handlers {
if pattern == tt.expectedMatch {
handler.AssertNumberOfCalls(t, "ServeDNS", 1)
} else {
handler.AssertNumberOfCalls(t, "ServeDNS", 0)
}
}
})
}
}

Some files were not shown because too many files have changed in this diff Show More