mirror of
https://github.com/netbirdio/netbird.git
synced 2025-03-02 17:01:20 +01:00
Merge branch 'main' into routes-get-account-refactoring
# Conflicts: # .github/workflows/golang-test-linux.yml # go.mod # go.sum # management/server/account.go # management/server/account_test.go # management/server/http/handler.go # management/server/http/handlers/peers/peers_handler.go # management/server/http/middleware/auth_middleware.go # management/server/http/middleware/auth_middleware_test.go # management/server/http/testing/benchmarks/peers_handler_benchmark_test.go # management/server/http/testing/benchmarks/users_handler_benchmark_test.go # management/server/integrated_validator.go # management/server/mock_server/account_mock.go # management/server/peer.go # management/server/status/error.go # management/server/store/sql_store.go # management/server/store/sql_store_test.go # management/server/user.go
This commit is contained in:
commit
e326cc7412
9
.github/workflows/golang-test-darwin.yml
vendored
9
.github/workflows/golang-test-darwin.yml
vendored
@ -1,4 +1,4 @@
|
||||
name: Test Code Darwin
|
||||
name: "Darwin"
|
||||
|
||||
on:
|
||||
push:
|
||||
@ -12,9 +12,7 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
strategy:
|
||||
matrix:
|
||||
store: ['sqlite']
|
||||
name: "Client / Unit"
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Install Go
|
||||
@ -44,4 +42,5 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)
|
||||
|
||||
|
8
.github/workflows/golang-test-freebsd.yml
vendored
8
.github/workflows/golang-test-freebsd.yml
vendored
@ -1,5 +1,4 @@
|
||||
|
||||
name: Test Code FreeBSD
|
||||
name: "FreeBSD"
|
||||
|
||||
on:
|
||||
push:
|
||||
@ -13,6 +12,7 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: "Client / Unit"
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
@ -24,7 +24,7 @@ jobs:
|
||||
copyback: false
|
||||
release: "14.1"
|
||||
prepare: |
|
||||
pkg install -y go
|
||||
pkg install -y go pkgconf xorg
|
||||
|
||||
# -x - to print all executed commands
|
||||
# -e - to faile on first error
|
||||
@ -33,7 +33,7 @@ jobs:
|
||||
time go build -o netbird client/main.go
|
||||
# check all component except management, since we do not support management server on freebsd
|
||||
time go test -timeout 1m -failfast ./base62/...
|
||||
# NOTE: without -p1 `client/internal/dns` will fail becasue of `listen udp4 :33100: bind: address already in use`
|
||||
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
|
||||
time go test -timeout 8m -failfast -p 1 ./client/...
|
||||
time go test -timeout 1m -failfast ./dns/...
|
||||
time go test -timeout 1m -failfast ./encryption/...
|
||||
|
162
.github/workflows/golang-test-linux.yml
vendored
162
.github/workflows/golang-test-linux.yml
vendored
@ -1,4 +1,4 @@
|
||||
name: Test Code Linux
|
||||
name: Linux
|
||||
|
||||
on:
|
||||
push:
|
||||
@ -12,11 +12,21 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
build-cache:
|
||||
name: "Build Cache"
|
||||
runs-on: ubuntu-22.04
|
||||
outputs:
|
||||
management: ${{ steps.filter.outputs.management }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- uses: dorny/paths-filter@v3
|
||||
id: filter
|
||||
with:
|
||||
filters: |
|
||||
management:
|
||||
- 'management/**'
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
@ -38,7 +48,6 @@ jobs:
|
||||
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
@ -89,6 +98,7 @@ jobs:
|
||||
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
|
||||
|
||||
test:
|
||||
name: "Client / Unit"
|
||||
needs: [build-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
@ -134,9 +144,116 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management)
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
|
||||
|
||||
test_relay:
|
||||
name: "Relay / Unit"
|
||||
needs: [build-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
${{ env.modcache }}
|
||||
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
go test \
|
||||
-exec 'sudo' \
|
||||
-timeout 10m ./signal/...
|
||||
|
||||
test_signal:
|
||||
name: "Signal / Unit"
|
||||
needs: [build-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
${{ env.modcache }}
|
||||
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
go test \
|
||||
-exec 'sudo' \
|
||||
-timeout 10m ./signal/...
|
||||
|
||||
test_management:
|
||||
name: "Management / Unit"
|
||||
needs: [ build-cache ]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
@ -194,10 +311,17 @@ jobs:
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management)
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
go test -tags=devcert \
|
||||
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
|
||||
-timeout 10m ./management/...
|
||||
|
||||
benchmark:
|
||||
name: "Management / Benchmark"
|
||||
needs: [ build-cache ]
|
||||
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
@ -254,10 +378,17 @@ jobs:
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m ./...
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
||||
go test -tags devcert -run=^$ -bench=. \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||
-timeout 20m ./...
|
||||
|
||||
api_benchmark:
|
||||
name: "Management / Benchmark (API)"
|
||||
needs: [ build-cache ]
|
||||
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
@ -312,12 +443,21 @@ jobs:
|
||||
- name: download mysql image
|
||||
if: matrix.store == 'mysql'
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -tags=benchmark -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 30m $(go list -tags=benchmark ./... | grep /management)
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
||||
go test -tags=benchmark \
|
||||
-run=^$ \
|
||||
-bench=. \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||
-timeout 20m ./management/...
|
||||
|
||||
api_integration_test:
|
||||
name: "Management / Integration"
|
||||
needs: [ build-cache ]
|
||||
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
@ -363,9 +503,15 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=integration -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 30m $(go list -tags=integration ./... | grep /management)
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
||||
go test -tags=integration \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||
-timeout 10m ./management/...
|
||||
|
||||
test_client_on_docker:
|
||||
name: "Client (Docker) / Unit"
|
||||
needs: [ build-cache ]
|
||||
runs-on: ubuntu-20.04
|
||||
steps:
|
||||
|
5
.github/workflows/golang-test-windows.yml
vendored
5
.github/workflows/golang-test-windows.yml
vendored
@ -1,4 +1,4 @@
|
||||
name: Test Code Windows
|
||||
name: "Windows"
|
||||
|
||||
on:
|
||||
push:
|
||||
@ -14,6 +14,7 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: "Client / Unit"
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
@ -65,7 +66,7 @@ jobs:
|
||||
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV
|
||||
|
||||
- name: test
|
||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
|
||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
|
||||
- name: test output
|
||||
if: ${{ always() }}
|
||||
run: Get-Content test-out.txt
|
||||
|
11
.github/workflows/golangci-lint.yml
vendored
11
.github/workflows/golangci-lint.yml
vendored
@ -1,4 +1,4 @@
|
||||
name: golangci-lint
|
||||
name: Lint
|
||||
on: [pull_request]
|
||||
|
||||
permissions:
|
||||
@ -27,7 +27,14 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [macos-latest, windows-latest, ubuntu-latest]
|
||||
name: lint
|
||||
include:
|
||||
- os: macos-latest
|
||||
display_name: Darwin
|
||||
- os: windows-latest
|
||||
display_name: Windows
|
||||
- os: ubuntu-latest
|
||||
display_name: Linux
|
||||
name: ${{ matrix.display_name }}
|
||||
runs-on: ${{ matrix.os }}
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
|
@ -1,4 +1,4 @@
|
||||
name: Mobile build validation
|
||||
name: Mobile
|
||||
|
||||
on:
|
||||
push:
|
||||
@ -12,6 +12,7 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
android_build:
|
||||
name: "Android / Build"
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@ -47,6 +48,7 @@ jobs:
|
||||
CGO_ENABLED: 0
|
||||
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620
|
||||
ios_build:
|
||||
name: "iOS / Build"
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
|
4
.github/workflows/release.yml
vendored
4
.github/workflows/release.yml
vendored
@ -9,10 +9,10 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.17"
|
||||
SIGN_PIPE_VER: "v0.0.18"
|
||||
GORELEASER_VER: "v2.3.2"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
|
||||
COPYRIGHT: "NetBird GmbH"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
|
1
.gitignore
vendored
1
.gitignore
vendored
@ -29,3 +29,4 @@ infrastructure_files/setup.env
|
||||
infrastructure_files/setup-*.env
|
||||
.vscode
|
||||
.DS_Store
|
||||
vendor/
|
||||
|
@ -103,7 +103,7 @@ linters:
|
||||
- predeclared # predeclared finds code that shadows one of Go's predeclared identifiers
|
||||
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
|
||||
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
|
||||
- thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
|
||||
# - thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
|
||||
- wastedassign # wastedassign finds wasted assignment statements
|
||||
issues:
|
||||
# Maximum count of issues with the same text.
|
||||
|
@ -50,10 +50,12 @@ nfpms:
|
||||
- netbird-ui
|
||||
formats:
|
||||
- deb
|
||||
scripts:
|
||||
postinstall: "release_files/ui-post-install.sh"
|
||||
contents:
|
||||
- src: client/ui/netbird.desktop
|
||||
dst: /usr/share/applications/netbird.desktop
|
||||
- src: client/ui/netbird-systemtray-connected.png
|
||||
- src: client/ui/netbird.png
|
||||
dst: /usr/share/pixmaps/netbird.png
|
||||
dependencies:
|
||||
- netbird
|
||||
@ -67,10 +69,12 @@ nfpms:
|
||||
- netbird-ui
|
||||
formats:
|
||||
- rpm
|
||||
scripts:
|
||||
postinstall: "release_files/ui-post-install.sh"
|
||||
contents:
|
||||
- src: client/ui/netbird.desktop
|
||||
dst: /usr/share/applications/netbird.desktop
|
||||
- src: client/ui/netbird-systemtray-connected.png
|
||||
- src: client/ui/netbird.png
|
||||
dst: /usr/share/pixmaps/netbird.png
|
||||
dependencies:
|
||||
- netbird
|
||||
|
2
AUTHORS
2
AUTHORS
@ -1,3 +1,3 @@
|
||||
Mikhail Bragin (https://github.com/braginini)
|
||||
Maycon Santos (https://github.com/mlsmaycon)
|
||||
Wiretrustee UG (haftungsbeschränkt)
|
||||
NetBird GmbH
|
||||
|
@ -3,10 +3,10 @@
|
||||
We are incredibly thankful for the contributions we receive from the community.
|
||||
We require our external contributors to sign a Contributor License Agreement ("CLA") in
|
||||
order to ensure that our projects remain licensed under Free and Open Source licenses such
|
||||
as BSD-3 while allowing Wiretrustee to build a sustainable business.
|
||||
as BSD-3 while allowing NetBird to build a sustainable business.
|
||||
|
||||
Wiretrustee is committed to having a true Open Source Software ("OSS") license for
|
||||
our software. A CLA enables Wiretrustee to safely commercialize our products
|
||||
NetBird is committed to having a true Open Source Software ("OSS") license for
|
||||
our software. A CLA enables NetBird to safely commercialize our products
|
||||
while keeping a standard OSS license with all the rights that license grants to users: the
|
||||
ability to use the project in their own projects or businesses, to republish modified
|
||||
source, or to completely fork the project.
|
||||
@ -20,11 +20,11 @@ This is a human-readable summary of (and not a substitute for) the full agreemen
|
||||
This highlights only some of key terms of the CLA. It has no legal value and you should
|
||||
carefully review all the terms of the actual CLA before agreeing.
|
||||
|
||||
<li>Grant of copyright license. You give Wiretrustee permission to use your copyrighted work
|
||||
<li>Grant of copyright license. You give NetBird permission to use your copyrighted work
|
||||
in commercial products.
|
||||
</li>
|
||||
|
||||
<li>Grant of patent license. If your contributed work uses a patent, you give Wiretrustee a
|
||||
<li>Grant of patent license. If your contributed work uses a patent, you give NetBird a
|
||||
license to use that patent including within commercial products. You also agree that you
|
||||
have permission to grant this license.
|
||||
</li>
|
||||
@ -45,7 +45,7 @@ more.
|
||||
# Why require a CLA?
|
||||
|
||||
Agreeing to a CLA explicitly states that you are entitled to provide a contribution, that you cannot withdraw permission
|
||||
to use your contribution at a later date, and that Wiretrustee has permission to use your contribution in our commercial
|
||||
to use your contribution at a later date, and that NetBird has permission to use your contribution in our commercial
|
||||
products.
|
||||
|
||||
This removes any ambiguities or uncertainties caused by not having a CLA and allows users and customers to confidently
|
||||
@ -65,25 +65,25 @@ Follow the steps given by the bot to sign the CLA. This will require you to log
|
||||
information from your account) and to fill in a few additional details such as your name and email address. We will only
|
||||
use this information for CLA tracking; none of your submitted information will be used for marketing purposes.
|
||||
|
||||
You only have to sign the CLA once. Once you've signed the CLA, future contributions to any Wiretrustee project will not
|
||||
You only have to sign the CLA once. Once you've signed the CLA, future contributions to any NetBird project will not
|
||||
require you to sign again.
|
||||
|
||||
# Legal Terms and Agreement
|
||||
|
||||
In order to clarify the intellectual property license granted with Contributions from any person or entity, Wiretrustee
|
||||
UG (haftungsbeschränkt) ("Wiretrustee") must have a Contributor License Agreement ("CLA") on file that has been signed
|
||||
In order to clarify the intellectual property license granted with Contributions from any person or entity, NetBird
|
||||
GmbH ("NetBird") must have a Contributor License Agreement ("CLA") on file that has been signed
|
||||
by each Contributor, indicating agreement to the license terms below. This license does not change your rights to use
|
||||
your own Contributions for any other purpose.
|
||||
|
||||
You accept and agree to the following terms and conditions for Your present and future Contributions submitted to
|
||||
Wiretrustee. Except for the license granted herein to Wiretrustee and recipients of software distributed by Wiretrustee,
|
||||
NetBird. Except for the license granted herein to NetBird and recipients of software distributed by NetBird,
|
||||
You reserve all right, title, and interest in and to Your Contributions.
|
||||
|
||||
1. Definitions.
|
||||
|
||||
```
|
||||
"You" (or "Your") shall mean the copyright owner or legal entity authorized by the copyright owner
|
||||
that is making this Agreement with Wiretrustee. For legal entities, the entity making a Contribution and all other
|
||||
that is making this Agreement with NetBird. For legal entities, the entity making a Contribution and all other
|
||||
entities that control, are controlled by, or are under common control with that entity are considered
|
||||
to be a single Contributor. For the purposes of this definition, "control" means (i) the power, direct or indirect,
|
||||
to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty
|
||||
@ -91,23 +91,23 @@ You reserve all right, title, and interest in and to Your Contributions.
|
||||
```
|
||||
```
|
||||
"Contribution" shall mean any original work of authorship, including any modifications or additions to
|
||||
an existing work, that is or previously has been intentionally submitted by You to Wiretrustee for inclusion in,
|
||||
or documentation of, any of the products owned or managed by Wiretrustee (the "Work").
|
||||
an existing work, that is or previously has been intentionally submitted by You to NetBird for inclusion in,
|
||||
or documentation of, any of the products owned or managed by NetBird (the "Work").
|
||||
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication
|
||||
sent to Wiretrustee or its representatives, including but not limited to communication on electronic mailing lists,
|
||||
sent to NetBird or its representatives, including but not limited to communication on electronic mailing lists,
|
||||
source code control systems, and issue tracking systems that are managed by, or on behalf of,
|
||||
Wiretrustee for the purpose of discussing and improving the Work, but excluding communication that is conspicuously
|
||||
NetBird for the purpose of discussing and improving the Work, but excluding communication that is conspicuously
|
||||
marked or otherwise designated in writing by You as "Not a Contribution."
|
||||
```
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of this Agreement, You hereby grant to Wiretrustee
|
||||
and to recipients of software distributed by Wiretrustee a perpetual, worldwide, non-exclusive, no-charge,
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of this Agreement, You hereby grant to NetBird
|
||||
and to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge,
|
||||
royalty-free, irrevocable copyright license to reproduce, prepare derivative works of, publicly display, publicly
|
||||
perform, sublicense, and distribute Your Contributions and such derivative works.
|
||||
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of this Agreement, You hereby grant to Wiretrustee and
|
||||
to recipients of software distributed by Wiretrustee a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
|
||||
3. Grant of Patent License. Subject to the terms and conditions of this Agreement, You hereby grant to NetBird and
|
||||
to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
|
||||
irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import,
|
||||
and otherwise transfer the Work, where such license applies only to those patent claims licensable by You that are
|
||||
necessarily infringed by Your Contribution(s) alone or by combination of Your Contribution(s) with the Work to which
|
||||
@ -121,8 +121,8 @@ You reserve all right, title, and interest in and to Your Contributions.
|
||||
intellectual property that you create that includes your Contributions, you represent that you have received
|
||||
permission to make Contributions on behalf of that employer, that you will have received permission from your current
|
||||
and future employers for all future Contributions, that your applicable employer has waived such rights for all of
|
||||
your current and future Contributions to Wiretrustee, or that your employer has executed a separate Corporate CLA
|
||||
with Wiretrustee.
|
||||
your current and future Contributions to NetBird, or that your employer has executed a separate Corporate CLA
|
||||
with NetBird.
|
||||
|
||||
|
||||
5. You represent that each of Your Contributions is Your original creation (see section 7 for submissions on behalf of
|
||||
@ -138,11 +138,11 @@ You reserve all right, title, and interest in and to Your Contributions.
|
||||
MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
|
||||
|
||||
7. Should You wish to submit work that is not Your original creation, You may submit it to Wiretrustee separately from
|
||||
7. Should You wish to submit work that is not Your original creation, You may submit it to NetBird separately from
|
||||
any Contribution, identifying the complete details of its source and of any license or other restriction (including,
|
||||
but not limited to, related patents, trademarks, and license agreements) of which you are personally aware, and
|
||||
conspicuously marking the work as "Submitted on behalf of a third-party: [named here]".
|
||||
|
||||
|
||||
8. You agree to notify Wiretrustee of any facts or circumstances of which you become aware that would make these
|
||||
representations inaccurate in any respect.
|
||||
8. You agree to notify NetBird of any facts or circumstances of which you become aware that would make these
|
||||
representations inaccurate in any respect.
|
||||
|
4
LICENSE
4
LICENSE
@ -1,6 +1,6 @@
|
||||
BSD 3-Clause License
|
||||
|
||||
Copyright (c) 2022 Wiretrustee UG (haftungsbeschränkt) & AUTHORS
|
||||
Copyright (c) 2022 NetBird GmbH & AUTHORS
|
||||
|
||||
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
||||
|
||||
@ -10,4 +10,4 @@ Redistribution and use in source and binary forms, with or without modification,
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
@ -1,4 +1,9 @@
|
||||
<div align="center">
|
||||
<a href="https://netbird.io/webinars/achieve-zero-trust-access-to-k8s?utm_source=github&utm_campaign=2502%20-%20webinar%20-%20How%20to%20Achieve%20Zero%20Trust%20Access%20to%20Kubernetes%20-%20Effortlessly&utm_medium=github">
|
||||
Webinar: How to Achieve Zero Trust Access to Kubernetes — Effortlessly
|
||||
</a>
|
||||
<br/>
|
||||
<br/>
|
||||
<p align="center">
|
||||
<img width="234" src="docs/media/logo-full.png"/>
|
||||
</p>
|
||||
|
@ -1,4 +1,4 @@
|
||||
FROM alpine:3.21.0
|
||||
FROM alpine:3.21.3
|
||||
RUN apk add --no-cache ca-certificates iptables ip6tables
|
||||
ENV NB_FOREGROUND_MODE=true
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
|
||||
|
@ -9,6 +9,7 @@ USER netbird:netbird
|
||||
|
||||
ENV NB_FOREGROUND_MODE=true
|
||||
ENV NB_USE_NETSTACK_MODE=true
|
||||
ENV NB_ENABLE_NETSTACK_LOCAL_FORWARDING=true
|
||||
ENV NB_CONFIG=config.json
|
||||
ENV NB_DAEMON_ADDR=unix://netbird.sock
|
||||
ENV NB_DISABLE_DNS=true
|
||||
|
@ -162,7 +162,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
||||
|
||||
// check if we need to generate JWT token
|
||||
err := a.withBackOff(a.ctx, func() (err error) {
|
||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config.SSHKey)
|
||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
|
||||
return
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/server"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
)
|
||||
|
||||
const errCloseConnection = "Failed to close connection: %v"
|
||||
@ -85,7 +86,7 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: getStatusOutput(cmd),
|
||||
Status: getStatusOutput(cmd, anonymizeFlag),
|
||||
SystemInfo: debugSystemInfoFlag,
|
||||
})
|
||||
if err != nil {
|
||||
@ -196,7 +197,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
headerPostUp := fmt.Sprintf("----- Netbird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
||||
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd))
|
||||
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd, anonymizeFlag))
|
||||
|
||||
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
|
||||
return waitErr
|
||||
@ -206,7 +207,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
cmd.Println("Creating debug bundle...")
|
||||
|
||||
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
|
||||
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd))
|
||||
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
|
||||
|
||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
@ -271,13 +272,15 @@ func setNetworkMapPersistence(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func getStatusOutput(cmd *cobra.Command) string {
|
||||
func getStatusOutput(cmd *cobra.Command, anon bool) string {
|
||||
var statusOutputString string
|
||||
statusResp, err := getStatus(cmd.Context())
|
||||
if err != nil {
|
||||
cmd.PrintErrf("Failed to get status: %v\n", err)
|
||||
} else {
|
||||
statusOutputString = parseToFullDetailSummary(convertToStatusOutputOverview(statusResp))
|
||||
statusOutputString = nbstatus.ParseToFullDetailSummary(
|
||||
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil),
|
||||
)
|
||||
}
|
||||
return statusOutputString
|
||||
}
|
||||
|
@ -85,11 +85,17 @@ var loginCmd = &cobra.Command{
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
var dnsLabelsReq []string
|
||||
if dnsLabelsValidated != nil {
|
||||
dnsLabelsReq = dnsLabelsValidated.ToSafeStringList()
|
||||
}
|
||||
|
||||
loginRequest := proto.LoginRequest{
|
||||
SetupKey: providedSetupKey,
|
||||
ManagementUrl: managementURL,
|
||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
||||
Hostname: hostName,
|
||||
DnsLabels: dnsLabelsReq,
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
|
@ -38,6 +38,7 @@ const (
|
||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||
dnsRouteIntervalFlag = "dns-router-interval"
|
||||
systemInfoFlag = "system-info"
|
||||
blockLANAccessFlag = "block-lan-access"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -73,6 +74,7 @@ var (
|
||||
anonymizeFlag bool
|
||||
debugSystemInfoFlag bool
|
||||
dnsRouteInterval time.Duration
|
||||
blockLANAccess bool
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "netbird",
|
||||
|
@ -9,7 +9,6 @@ import (
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
@ -73,7 +72,7 @@ var sshCmd = &cobra.Command{
|
||||
go func() {
|
||||
// blocking
|
||||
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
|
||||
log.Debug(err)
|
||||
cmd.Printf("Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
cancel()
|
||||
|
@ -2,107 +2,20 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
type peerStateDetailOutput struct {
|
||||
FQDN string `json:"fqdn" yaml:"fqdn"`
|
||||
IP string `json:"netbirdIp" yaml:"netbirdIp"`
|
||||
PubKey string `json:"publicKey" yaml:"publicKey"`
|
||||
Status string `json:"status" yaml:"status"`
|
||||
LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"`
|
||||
ConnType string `json:"connectionType" yaml:"connectionType"`
|
||||
IceCandidateType iceCandidateType `json:"iceCandidateType" yaml:"iceCandidateType"`
|
||||
IceCandidateEndpoint iceCandidateType `json:"iceCandidateEndpoint" yaml:"iceCandidateEndpoint"`
|
||||
RelayAddress string `json:"relayAddress" yaml:"relayAddress"`
|
||||
LastWireguardHandshake time.Time `json:"lastWireguardHandshake" yaml:"lastWireguardHandshake"`
|
||||
TransferReceived int64 `json:"transferReceived" yaml:"transferReceived"`
|
||||
TransferSent int64 `json:"transferSent" yaml:"transferSent"`
|
||||
Latency time.Duration `json:"latency" yaml:"latency"`
|
||||
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
|
||||
Routes []string `json:"routes" yaml:"routes"`
|
||||
Networks []string `json:"networks" yaml:"networks"`
|
||||
}
|
||||
|
||||
type peersStateOutput struct {
|
||||
Total int `json:"total" yaml:"total"`
|
||||
Connected int `json:"connected" yaml:"connected"`
|
||||
Details []peerStateDetailOutput `json:"details" yaml:"details"`
|
||||
}
|
||||
|
||||
type signalStateOutput struct {
|
||||
URL string `json:"url" yaml:"url"`
|
||||
Connected bool `json:"connected" yaml:"connected"`
|
||||
Error string `json:"error" yaml:"error"`
|
||||
}
|
||||
|
||||
type managementStateOutput struct {
|
||||
URL string `json:"url" yaml:"url"`
|
||||
Connected bool `json:"connected" yaml:"connected"`
|
||||
Error string `json:"error" yaml:"error"`
|
||||
}
|
||||
|
||||
type relayStateOutputDetail struct {
|
||||
URI string `json:"uri" yaml:"uri"`
|
||||
Available bool `json:"available" yaml:"available"`
|
||||
Error string `json:"error" yaml:"error"`
|
||||
}
|
||||
|
||||
type relayStateOutput struct {
|
||||
Total int `json:"total" yaml:"total"`
|
||||
Available int `json:"available" yaml:"available"`
|
||||
Details []relayStateOutputDetail `json:"details" yaml:"details"`
|
||||
}
|
||||
|
||||
type iceCandidateType struct {
|
||||
Local string `json:"local" yaml:"local"`
|
||||
Remote string `json:"remote" yaml:"remote"`
|
||||
}
|
||||
|
||||
type nsServerGroupStateOutput struct {
|
||||
Servers []string `json:"servers" yaml:"servers"`
|
||||
Domains []string `json:"domains" yaml:"domains"`
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
Error string `json:"error" yaml:"error"`
|
||||
}
|
||||
|
||||
type statusOutputOverview struct {
|
||||
Peers peersStateOutput `json:"peers" yaml:"peers"`
|
||||
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
|
||||
DaemonVersion string `json:"daemonVersion" yaml:"daemonVersion"`
|
||||
ManagementState managementStateOutput `json:"management" yaml:"management"`
|
||||
SignalState signalStateOutput `json:"signal" yaml:"signal"`
|
||||
Relays relayStateOutput `json:"relays" yaml:"relays"`
|
||||
IP string `json:"netbirdIp" yaml:"netbirdIp"`
|
||||
PubKey string `json:"publicKey" yaml:"publicKey"`
|
||||
KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"`
|
||||
FQDN string `json:"fqdn" yaml:"fqdn"`
|
||||
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
|
||||
RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"`
|
||||
Routes []string `json:"routes" yaml:"routes"`
|
||||
Networks []string `json:"networks" yaml:"networks"`
|
||||
NSServerGroups []nsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"`
|
||||
}
|
||||
|
||||
var (
|
||||
detailFlag bool
|
||||
ipv4Flag bool
|
||||
@ -173,18 +86,17 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
outputInformationHolder := convertToStatusOutputOverview(resp)
|
||||
|
||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap)
|
||||
var statusOutputString string
|
||||
switch {
|
||||
case detailFlag:
|
||||
statusOutputString = parseToFullDetailSummary(outputInformationHolder)
|
||||
statusOutputString = nbstatus.ParseToFullDetailSummary(outputInformationHolder)
|
||||
case jsonFlag:
|
||||
statusOutputString, err = parseToJSON(outputInformationHolder)
|
||||
statusOutputString, err = nbstatus.ParseToJSON(outputInformationHolder)
|
||||
case yamlFlag:
|
||||
statusOutputString, err = parseToYAML(outputInformationHolder)
|
||||
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
|
||||
default:
|
||||
statusOutputString = parseGeneralSummary(outputInformationHolder, false, false, false)
|
||||
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@ -214,7 +126,6 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
|
||||
}
|
||||
|
||||
func parseFilters() error {
|
||||
|
||||
switch strings.ToLower(statusFilter) {
|
||||
case "", "disconnected", "connected":
|
||||
if strings.ToLower(statusFilter) != "" {
|
||||
@ -251,175 +162,6 @@ func enableDetailFlagWhenFilterFlag() {
|
||||
}
|
||||
}
|
||||
|
||||
func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverview {
|
||||
pbFullStatus := resp.GetFullStatus()
|
||||
|
||||
managementState := pbFullStatus.GetManagementState()
|
||||
managementOverview := managementStateOutput{
|
||||
URL: managementState.GetURL(),
|
||||
Connected: managementState.GetConnected(),
|
||||
Error: managementState.Error,
|
||||
}
|
||||
|
||||
signalState := pbFullStatus.GetSignalState()
|
||||
signalOverview := signalStateOutput{
|
||||
URL: signalState.GetURL(),
|
||||
Connected: signalState.GetConnected(),
|
||||
Error: signalState.Error,
|
||||
}
|
||||
|
||||
relayOverview := mapRelays(pbFullStatus.GetRelays())
|
||||
peersOverview := mapPeers(resp.GetFullStatus().GetPeers())
|
||||
|
||||
overview := statusOutputOverview{
|
||||
Peers: peersOverview,
|
||||
CliVersion: version.NetbirdVersion(),
|
||||
DaemonVersion: resp.GetDaemonVersion(),
|
||||
ManagementState: managementOverview,
|
||||
SignalState: signalOverview,
|
||||
Relays: relayOverview,
|
||||
IP: pbFullStatus.GetLocalPeerState().GetIP(),
|
||||
PubKey: pbFullStatus.GetLocalPeerState().GetPubKey(),
|
||||
KernelInterface: pbFullStatus.GetLocalPeerState().GetKernelInterface(),
|
||||
FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
|
||||
RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(),
|
||||
RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(),
|
||||
Routes: pbFullStatus.GetLocalPeerState().GetNetworks(),
|
||||
Networks: pbFullStatus.GetLocalPeerState().GetNetworks(),
|
||||
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
|
||||
}
|
||||
|
||||
if anonymizeFlag {
|
||||
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
||||
anonymizeOverview(anonymizer, &overview)
|
||||
}
|
||||
|
||||
return overview
|
||||
}
|
||||
|
||||
func mapRelays(relays []*proto.RelayState) relayStateOutput {
|
||||
var relayStateDetail []relayStateOutputDetail
|
||||
|
||||
var relaysAvailable int
|
||||
for _, relay := range relays {
|
||||
available := relay.GetAvailable()
|
||||
relayStateDetail = append(relayStateDetail,
|
||||
relayStateOutputDetail{
|
||||
URI: relay.URI,
|
||||
Available: available,
|
||||
Error: relay.GetError(),
|
||||
},
|
||||
)
|
||||
|
||||
if available {
|
||||
relaysAvailable++
|
||||
}
|
||||
}
|
||||
|
||||
return relayStateOutput{
|
||||
Total: len(relays),
|
||||
Available: relaysAvailable,
|
||||
Details: relayStateDetail,
|
||||
}
|
||||
}
|
||||
|
||||
func mapNSGroups(servers []*proto.NSGroupState) []nsServerGroupStateOutput {
|
||||
mappedNSGroups := make([]nsServerGroupStateOutput, 0, len(servers))
|
||||
for _, pbNsGroupServer := range servers {
|
||||
mappedNSGroups = append(mappedNSGroups, nsServerGroupStateOutput{
|
||||
Servers: pbNsGroupServer.GetServers(),
|
||||
Domains: pbNsGroupServer.GetDomains(),
|
||||
Enabled: pbNsGroupServer.GetEnabled(),
|
||||
Error: pbNsGroupServer.GetError(),
|
||||
})
|
||||
}
|
||||
return mappedNSGroups
|
||||
}
|
||||
|
||||
func mapPeers(peers []*proto.PeerState) peersStateOutput {
|
||||
var peersStateDetail []peerStateDetailOutput
|
||||
peersConnected := 0
|
||||
for _, pbPeerState := range peers {
|
||||
localICE := ""
|
||||
remoteICE := ""
|
||||
localICEEndpoint := ""
|
||||
remoteICEEndpoint := ""
|
||||
relayServerAddress := ""
|
||||
connType := ""
|
||||
lastHandshake := time.Time{}
|
||||
transferReceived := int64(0)
|
||||
transferSent := int64(0)
|
||||
|
||||
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
|
||||
if skipDetailByFilters(pbPeerState, isPeerConnected) {
|
||||
continue
|
||||
}
|
||||
if isPeerConnected {
|
||||
peersConnected++
|
||||
|
||||
localICE = pbPeerState.GetLocalIceCandidateType()
|
||||
remoteICE = pbPeerState.GetRemoteIceCandidateType()
|
||||
localICEEndpoint = pbPeerState.GetLocalIceCandidateEndpoint()
|
||||
remoteICEEndpoint = pbPeerState.GetRemoteIceCandidateEndpoint()
|
||||
connType = "P2P"
|
||||
if pbPeerState.Relayed {
|
||||
connType = "Relayed"
|
||||
}
|
||||
relayServerAddress = pbPeerState.GetRelayAddress()
|
||||
lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local()
|
||||
transferReceived = pbPeerState.GetBytesRx()
|
||||
transferSent = pbPeerState.GetBytesTx()
|
||||
}
|
||||
|
||||
timeLocal := pbPeerState.GetConnStatusUpdate().AsTime().Local()
|
||||
peerState := peerStateDetailOutput{
|
||||
IP: pbPeerState.GetIP(),
|
||||
PubKey: pbPeerState.GetPubKey(),
|
||||
Status: pbPeerState.GetConnStatus(),
|
||||
LastStatusUpdate: timeLocal,
|
||||
ConnType: connType,
|
||||
IceCandidateType: iceCandidateType{
|
||||
Local: localICE,
|
||||
Remote: remoteICE,
|
||||
},
|
||||
IceCandidateEndpoint: iceCandidateType{
|
||||
Local: localICEEndpoint,
|
||||
Remote: remoteICEEndpoint,
|
||||
},
|
||||
RelayAddress: relayServerAddress,
|
||||
FQDN: pbPeerState.GetFqdn(),
|
||||
LastWireguardHandshake: lastHandshake,
|
||||
TransferReceived: transferReceived,
|
||||
TransferSent: transferSent,
|
||||
Latency: pbPeerState.GetLatency().AsDuration(),
|
||||
RosenpassEnabled: pbPeerState.GetRosenpassEnabled(),
|
||||
Routes: pbPeerState.GetNetworks(),
|
||||
Networks: pbPeerState.GetNetworks(),
|
||||
}
|
||||
|
||||
peersStateDetail = append(peersStateDetail, peerState)
|
||||
}
|
||||
|
||||
sortPeersByIP(peersStateDetail)
|
||||
|
||||
peersOverview := peersStateOutput{
|
||||
Total: len(peersStateDetail),
|
||||
Connected: peersConnected,
|
||||
Details: peersStateDetail,
|
||||
}
|
||||
return peersOverview
|
||||
}
|
||||
|
||||
func sortPeersByIP(peersStateDetail []peerStateDetailOutput) {
|
||||
if len(peersStateDetail) > 0 {
|
||||
sort.SliceStable(peersStateDetail, func(i, j int) bool {
|
||||
iAddr, _ := netip.ParseAddr(peersStateDetail[i].IP)
|
||||
jAddr, _ := netip.ParseAddr(peersStateDetail[j].IP)
|
||||
return iAddr.Compare(jAddr) == -1
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func parseInterfaceIP(interfaceIP string) string {
|
||||
ip, _, err := net.ParseCIDR(interfaceIP)
|
||||
if err != nil {
|
||||
@ -427,452 +169,3 @@ func parseInterfaceIP(interfaceIP string) string {
|
||||
}
|
||||
return fmt.Sprintf("%s\n", ip)
|
||||
}
|
||||
|
||||
func parseToJSON(overview statusOutputOverview) (string, error) {
|
||||
jsonBytes, err := json.Marshal(overview)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("json marshal failed")
|
||||
}
|
||||
return string(jsonBytes), err
|
||||
}
|
||||
|
||||
func parseToYAML(overview statusOutputOverview) (string, error) {
|
||||
yamlBytes, err := yaml.Marshal(overview)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("yaml marshal failed")
|
||||
}
|
||||
return string(yamlBytes), nil
|
||||
}
|
||||
|
||||
func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays bool, showNameServers bool) string {
|
||||
var managementConnString string
|
||||
if overview.ManagementState.Connected {
|
||||
managementConnString = "Connected"
|
||||
if showURL {
|
||||
managementConnString = fmt.Sprintf("%s to %s", managementConnString, overview.ManagementState.URL)
|
||||
}
|
||||
} else {
|
||||
managementConnString = "Disconnected"
|
||||
if overview.ManagementState.Error != "" {
|
||||
managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, overview.ManagementState.Error)
|
||||
}
|
||||
}
|
||||
|
||||
var signalConnString string
|
||||
if overview.SignalState.Connected {
|
||||
signalConnString = "Connected"
|
||||
if showURL {
|
||||
signalConnString = fmt.Sprintf("%s to %s", signalConnString, overview.SignalState.URL)
|
||||
}
|
||||
} else {
|
||||
signalConnString = "Disconnected"
|
||||
if overview.SignalState.Error != "" {
|
||||
signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, overview.SignalState.Error)
|
||||
}
|
||||
}
|
||||
|
||||
interfaceTypeString := "Userspace"
|
||||
interfaceIP := overview.IP
|
||||
if overview.KernelInterface {
|
||||
interfaceTypeString = "Kernel"
|
||||
} else if overview.IP == "" {
|
||||
interfaceTypeString = "N/A"
|
||||
interfaceIP = "N/A"
|
||||
}
|
||||
|
||||
var relaysString string
|
||||
if showRelays {
|
||||
for _, relay := range overview.Relays.Details {
|
||||
available := "Available"
|
||||
reason := ""
|
||||
if !relay.Available {
|
||||
available = "Unavailable"
|
||||
reason = fmt.Sprintf(", reason: %s", relay.Error)
|
||||
}
|
||||
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
|
||||
}
|
||||
} else {
|
||||
relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
|
||||
}
|
||||
|
||||
networks := "-"
|
||||
if len(overview.Networks) > 0 {
|
||||
sort.Strings(overview.Networks)
|
||||
networks = strings.Join(overview.Networks, ", ")
|
||||
}
|
||||
|
||||
var dnsServersString string
|
||||
if showNameServers {
|
||||
for _, nsServerGroup := range overview.NSServerGroups {
|
||||
enabled := "Available"
|
||||
if !nsServerGroup.Enabled {
|
||||
enabled = "Unavailable"
|
||||
}
|
||||
errorString := ""
|
||||
if nsServerGroup.Error != "" {
|
||||
errorString = fmt.Sprintf(", reason: %s", nsServerGroup.Error)
|
||||
errorString = strings.TrimSpace(errorString)
|
||||
}
|
||||
|
||||
domainsString := strings.Join(nsServerGroup.Domains, ", ")
|
||||
if domainsString == "" {
|
||||
domainsString = "." // Show "." for the default zone
|
||||
}
|
||||
dnsServersString += fmt.Sprintf(
|
||||
"\n [%s] for [%s] is %s%s",
|
||||
strings.Join(nsServerGroup.Servers, ", "),
|
||||
domainsString,
|
||||
enabled,
|
||||
errorString,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(overview.NSServerGroups), len(overview.NSServerGroups))
|
||||
}
|
||||
|
||||
rosenpassEnabledStatus := "false"
|
||||
if overview.RosenpassEnabled {
|
||||
rosenpassEnabledStatus = "true"
|
||||
if overview.RosenpassPermissive {
|
||||
rosenpassEnabledStatus = "true (permissive)" //nolint:gosec
|
||||
}
|
||||
}
|
||||
|
||||
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
|
||||
|
||||
goos := runtime.GOOS
|
||||
goarch := runtime.GOARCH
|
||||
goarm := ""
|
||||
if goarch == "arm" {
|
||||
goarm = fmt.Sprintf(" (ARMv%s)", os.Getenv("GOARM"))
|
||||
}
|
||||
|
||||
summary := fmt.Sprintf(
|
||||
"OS: %s\n"+
|
||||
"Daemon version: %s\n"+
|
||||
"CLI version: %s\n"+
|
||||
"Management: %s\n"+
|
||||
"Signal: %s\n"+
|
||||
"Relays: %s\n"+
|
||||
"Nameservers: %s\n"+
|
||||
"FQDN: %s\n"+
|
||||
"NetBird IP: %s\n"+
|
||||
"Interface type: %s\n"+
|
||||
"Quantum resistance: %s\n"+
|
||||
"Routes: %s\n"+
|
||||
"Networks: %s\n"+
|
||||
"Peers count: %s\n",
|
||||
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
|
||||
overview.DaemonVersion,
|
||||
version.NetbirdVersion(),
|
||||
managementConnString,
|
||||
signalConnString,
|
||||
relaysString,
|
||||
dnsServersString,
|
||||
overview.FQDN,
|
||||
interfaceIP,
|
||||
interfaceTypeString,
|
||||
rosenpassEnabledStatus,
|
||||
networks,
|
||||
networks,
|
||||
peersCountString,
|
||||
)
|
||||
return summary
|
||||
}
|
||||
|
||||
func parseToFullDetailSummary(overview statusOutputOverview) string {
|
||||
parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive)
|
||||
summary := parseGeneralSummary(overview, true, true, true)
|
||||
|
||||
return fmt.Sprintf(
|
||||
"Peers detail:"+
|
||||
"%s\n"+
|
||||
"%s",
|
||||
parsedPeersString,
|
||||
summary,
|
||||
)
|
||||
}
|
||||
|
||||
func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bool) string {
|
||||
var (
|
||||
peersString = ""
|
||||
)
|
||||
|
||||
for _, peerState := range peers.Details {
|
||||
|
||||
localICE := "-"
|
||||
if peerState.IceCandidateType.Local != "" {
|
||||
localICE = peerState.IceCandidateType.Local
|
||||
}
|
||||
|
||||
remoteICE := "-"
|
||||
if peerState.IceCandidateType.Remote != "" {
|
||||
remoteICE = peerState.IceCandidateType.Remote
|
||||
}
|
||||
|
||||
localICEEndpoint := "-"
|
||||
if peerState.IceCandidateEndpoint.Local != "" {
|
||||
localICEEndpoint = peerState.IceCandidateEndpoint.Local
|
||||
}
|
||||
|
||||
remoteICEEndpoint := "-"
|
||||
if peerState.IceCandidateEndpoint.Remote != "" {
|
||||
remoteICEEndpoint = peerState.IceCandidateEndpoint.Remote
|
||||
}
|
||||
|
||||
rosenpassEnabledStatus := "false"
|
||||
if rosenpassEnabled {
|
||||
if peerState.RosenpassEnabled {
|
||||
rosenpassEnabledStatus = "true"
|
||||
} else {
|
||||
if rosenpassPermissive {
|
||||
rosenpassEnabledStatus = "false (remote didn't enable quantum resistance)"
|
||||
} else {
|
||||
rosenpassEnabledStatus = "false (connection won't work without a permissive mode)"
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if peerState.RosenpassEnabled {
|
||||
rosenpassEnabledStatus = "false (connection might not work without a remote permissive mode)"
|
||||
}
|
||||
}
|
||||
|
||||
networks := "-"
|
||||
if len(peerState.Networks) > 0 {
|
||||
sort.Strings(peerState.Networks)
|
||||
networks = strings.Join(peerState.Networks, ", ")
|
||||
}
|
||||
|
||||
peerString := fmt.Sprintf(
|
||||
"\n %s:\n"+
|
||||
" NetBird IP: %s\n"+
|
||||
" Public key: %s\n"+
|
||||
" Status: %s\n"+
|
||||
" -- detail --\n"+
|
||||
" Connection type: %s\n"+
|
||||
" ICE candidate (Local/Remote): %s/%s\n"+
|
||||
" ICE candidate endpoints (Local/Remote): %s/%s\n"+
|
||||
" Relay server address: %s\n"+
|
||||
" Last connection update: %s\n"+
|
||||
" Last WireGuard handshake: %s\n"+
|
||||
" Transfer status (received/sent) %s/%s\n"+
|
||||
" Quantum resistance: %s\n"+
|
||||
" Routes: %s\n"+
|
||||
" Networks: %s\n"+
|
||||
" Latency: %s\n",
|
||||
peerState.FQDN,
|
||||
peerState.IP,
|
||||
peerState.PubKey,
|
||||
peerState.Status,
|
||||
peerState.ConnType,
|
||||
localICE,
|
||||
remoteICE,
|
||||
localICEEndpoint,
|
||||
remoteICEEndpoint,
|
||||
peerState.RelayAddress,
|
||||
timeAgo(peerState.LastStatusUpdate),
|
||||
timeAgo(peerState.LastWireguardHandshake),
|
||||
toIEC(peerState.TransferReceived),
|
||||
toIEC(peerState.TransferSent),
|
||||
rosenpassEnabledStatus,
|
||||
networks,
|
||||
networks,
|
||||
peerState.Latency.String(),
|
||||
)
|
||||
|
||||
peersString += peerString
|
||||
}
|
||||
return peersString
|
||||
}
|
||||
|
||||
func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
|
||||
statusEval := false
|
||||
ipEval := false
|
||||
nameEval := true
|
||||
|
||||
if statusFilter != "" {
|
||||
lowerStatusFilter := strings.ToLower(statusFilter)
|
||||
if lowerStatusFilter == "disconnected" && isConnected {
|
||||
statusEval = true
|
||||
} else if lowerStatusFilter == "connected" && !isConnected {
|
||||
statusEval = true
|
||||
}
|
||||
}
|
||||
|
||||
if len(ipsFilter) > 0 {
|
||||
_, ok := ipsFilterMap[peerState.IP]
|
||||
if !ok {
|
||||
ipEval = true
|
||||
}
|
||||
}
|
||||
|
||||
if len(prefixNamesFilter) > 0 {
|
||||
for prefixNameFilter := range prefixNamesFilterMap {
|
||||
if strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
|
||||
nameEval = false
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
nameEval = false
|
||||
}
|
||||
|
||||
return statusEval || ipEval || nameEval
|
||||
}
|
||||
|
||||
func toIEC(b int64) string {
|
||||
const unit = 1024
|
||||
if b < unit {
|
||||
return fmt.Sprintf("%d B", b)
|
||||
}
|
||||
div, exp := int64(unit), 0
|
||||
for n := b / unit; n >= unit; n /= unit {
|
||||
div *= unit
|
||||
exp++
|
||||
}
|
||||
return fmt.Sprintf("%.1f %ciB",
|
||||
float64(b)/float64(div), "KMGTPE"[exp])
|
||||
}
|
||||
|
||||
func countEnabled(dnsServers []nsServerGroupStateOutput) int {
|
||||
count := 0
|
||||
for _, server := range dnsServers {
|
||||
if server.Enabled {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// timeAgo returns a string representing the duration since the provided time in a human-readable format.
|
||||
func timeAgo(t time.Time) string {
|
||||
if t.IsZero() || t.Equal(time.Unix(0, 0)) {
|
||||
return "-"
|
||||
}
|
||||
duration := time.Since(t)
|
||||
switch {
|
||||
case duration < time.Second:
|
||||
return "Now"
|
||||
case duration < time.Minute:
|
||||
seconds := int(duration.Seconds())
|
||||
if seconds == 1 {
|
||||
return "1 second ago"
|
||||
}
|
||||
return fmt.Sprintf("%d seconds ago", seconds)
|
||||
case duration < time.Hour:
|
||||
minutes := int(duration.Minutes())
|
||||
seconds := int(duration.Seconds()) % 60
|
||||
if minutes == 1 {
|
||||
if seconds == 1 {
|
||||
return "1 minute, 1 second ago"
|
||||
} else if seconds > 0 {
|
||||
return fmt.Sprintf("1 minute, %d seconds ago", seconds)
|
||||
}
|
||||
return "1 minute ago"
|
||||
}
|
||||
if seconds > 0 {
|
||||
return fmt.Sprintf("%d minutes, %d seconds ago", minutes, seconds)
|
||||
}
|
||||
return fmt.Sprintf("%d minutes ago", minutes)
|
||||
case duration < 24*time.Hour:
|
||||
hours := int(duration.Hours())
|
||||
minutes := int(duration.Minutes()) % 60
|
||||
if hours == 1 {
|
||||
if minutes == 1 {
|
||||
return "1 hour, 1 minute ago"
|
||||
} else if minutes > 0 {
|
||||
return fmt.Sprintf("1 hour, %d minutes ago", minutes)
|
||||
}
|
||||
return "1 hour ago"
|
||||
}
|
||||
if minutes > 0 {
|
||||
return fmt.Sprintf("%d hours, %d minutes ago", hours, minutes)
|
||||
}
|
||||
return fmt.Sprintf("%d hours ago", hours)
|
||||
}
|
||||
|
||||
days := int(duration.Hours()) / 24
|
||||
hours := int(duration.Hours()) % 24
|
||||
if days == 1 {
|
||||
if hours == 1 {
|
||||
return "1 day, 1 hour ago"
|
||||
} else if hours > 0 {
|
||||
return fmt.Sprintf("1 day, %d hours ago", hours)
|
||||
}
|
||||
return "1 day ago"
|
||||
}
|
||||
if hours > 0 {
|
||||
return fmt.Sprintf("%d days, %d hours ago", days, hours)
|
||||
}
|
||||
return fmt.Sprintf("%d days ago", days)
|
||||
}
|
||||
|
||||
func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
|
||||
peer.FQDN = a.AnonymizeDomain(peer.FQDN)
|
||||
if localIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Local); err == nil {
|
||||
peer.IceCandidateEndpoint.Local = fmt.Sprintf("%s:%s", a.AnonymizeIPString(localIP), port)
|
||||
}
|
||||
if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
|
||||
peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
|
||||
}
|
||||
|
||||
peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress)
|
||||
|
||||
for i, route := range peer.Networks {
|
||||
peer.Networks[i] = a.AnonymizeIPString(route)
|
||||
}
|
||||
|
||||
for i, route := range peer.Networks {
|
||||
peer.Networks[i] = a.AnonymizeRoute(route)
|
||||
}
|
||||
|
||||
for i, route := range peer.Routes {
|
||||
peer.Routes[i] = a.AnonymizeIPString(route)
|
||||
}
|
||||
|
||||
for i, route := range peer.Routes {
|
||||
peer.Routes[i] = a.AnonymizeRoute(route)
|
||||
}
|
||||
}
|
||||
|
||||
func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview) {
|
||||
for i, peer := range overview.Peers.Details {
|
||||
peer := peer
|
||||
anonymizePeerDetail(a, &peer)
|
||||
overview.Peers.Details[i] = peer
|
||||
}
|
||||
|
||||
overview.ManagementState.URL = a.AnonymizeURI(overview.ManagementState.URL)
|
||||
overview.ManagementState.Error = a.AnonymizeString(overview.ManagementState.Error)
|
||||
overview.SignalState.URL = a.AnonymizeURI(overview.SignalState.URL)
|
||||
overview.SignalState.Error = a.AnonymizeString(overview.SignalState.Error)
|
||||
|
||||
overview.IP = a.AnonymizeIPString(overview.IP)
|
||||
for i, detail := range overview.Relays.Details {
|
||||
detail.URI = a.AnonymizeURI(detail.URI)
|
||||
detail.Error = a.AnonymizeString(detail.Error)
|
||||
overview.Relays.Details[i] = detail
|
||||
}
|
||||
|
||||
for i, nsGroup := range overview.NSServerGroups {
|
||||
for j, domain := range nsGroup.Domains {
|
||||
overview.NSServerGroups[i].Domains[j] = a.AnonymizeDomain(domain)
|
||||
}
|
||||
for j, ns := range nsGroup.Servers {
|
||||
host, port, err := net.SplitHostPort(ns)
|
||||
if err == nil {
|
||||
overview.NSServerGroups[i].Servers[j] = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i, route := range overview.Networks {
|
||||
overview.Networks[i] = a.AnonymizeRoute(route)
|
||||
}
|
||||
|
||||
for i, route := range overview.Routes {
|
||||
overview.Routes[i] = a.AnonymizeRoute(route)
|
||||
}
|
||||
|
||||
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
|
||||
}
|
||||
|
@ -1,597 +1,11 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
func init() {
|
||||
loc, err := time.LoadLocation("UTC")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
time.Local = loc
|
||||
}
|
||||
|
||||
var resp = &proto.StatusResponse{
|
||||
Status: "Connected",
|
||||
FullStatus: &proto.FullStatus{
|
||||
Peers: []*proto.PeerState{
|
||||
{
|
||||
IP: "192.168.178.101",
|
||||
PubKey: "Pubkey1",
|
||||
Fqdn: "peer-1.awesome-domain.com",
|
||||
ConnStatus: "Connected",
|
||||
ConnStatusUpdate: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 1, 0, time.UTC)),
|
||||
Relayed: false,
|
||||
LocalIceCandidateType: "",
|
||||
RemoteIceCandidateType: "",
|
||||
LocalIceCandidateEndpoint: "",
|
||||
RemoteIceCandidateEndpoint: "",
|
||||
LastWireguardHandshake: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 2, 0, time.UTC)),
|
||||
BytesRx: 200,
|
||||
BytesTx: 100,
|
||||
Networks: []string{
|
||||
"10.1.0.0/24",
|
||||
},
|
||||
Latency: durationpb.New(time.Duration(10000000)),
|
||||
},
|
||||
{
|
||||
IP: "192.168.178.102",
|
||||
PubKey: "Pubkey2",
|
||||
Fqdn: "peer-2.awesome-domain.com",
|
||||
ConnStatus: "Connected",
|
||||
ConnStatusUpdate: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 2, 0, time.UTC)),
|
||||
Relayed: true,
|
||||
LocalIceCandidateType: "relay",
|
||||
RemoteIceCandidateType: "prflx",
|
||||
LocalIceCandidateEndpoint: "10.0.0.1:10001",
|
||||
RemoteIceCandidateEndpoint: "10.0.10.1:10002",
|
||||
LastWireguardHandshake: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 3, 0, time.UTC)),
|
||||
BytesRx: 2000,
|
||||
BytesTx: 1000,
|
||||
Latency: durationpb.New(time.Duration(10000000)),
|
||||
},
|
||||
},
|
||||
ManagementState: &proto.ManagementState{
|
||||
URL: "my-awesome-management.com:443",
|
||||
Connected: true,
|
||||
Error: "",
|
||||
},
|
||||
SignalState: &proto.SignalState{
|
||||
URL: "my-awesome-signal.com:443",
|
||||
Connected: true,
|
||||
Error: "",
|
||||
},
|
||||
Relays: []*proto.RelayState{
|
||||
{
|
||||
URI: "stun:my-awesome-stun.com:3478",
|
||||
Available: true,
|
||||
Error: "",
|
||||
},
|
||||
{
|
||||
URI: "turns:my-awesome-turn.com:443?transport=tcp",
|
||||
Available: false,
|
||||
Error: "context: deadline exceeded",
|
||||
},
|
||||
},
|
||||
LocalPeerState: &proto.LocalPeerState{
|
||||
IP: "192.168.178.100/16",
|
||||
PubKey: "Some-Pub-Key",
|
||||
KernelInterface: true,
|
||||
Fqdn: "some-localhost.awesome-domain.com",
|
||||
Networks: []string{
|
||||
"10.10.0.0/24",
|
||||
},
|
||||
},
|
||||
DnsServers: []*proto.NSGroupState{
|
||||
{
|
||||
Servers: []string{
|
||||
"8.8.8.8:53",
|
||||
},
|
||||
Domains: nil,
|
||||
Enabled: true,
|
||||
Error: "",
|
||||
},
|
||||
{
|
||||
Servers: []string{
|
||||
"1.1.1.1:53",
|
||||
"2.2.2.2:53",
|
||||
},
|
||||
Domains: []string{
|
||||
"example.com",
|
||||
"example.net",
|
||||
},
|
||||
Enabled: false,
|
||||
Error: "timeout",
|
||||
},
|
||||
},
|
||||
},
|
||||
DaemonVersion: "0.14.1",
|
||||
}
|
||||
|
||||
var overview = statusOutputOverview{
|
||||
Peers: peersStateOutput{
|
||||
Total: 2,
|
||||
Connected: 2,
|
||||
Details: []peerStateDetailOutput{
|
||||
{
|
||||
IP: "192.168.178.101",
|
||||
PubKey: "Pubkey1",
|
||||
FQDN: "peer-1.awesome-domain.com",
|
||||
Status: "Connected",
|
||||
LastStatusUpdate: time.Date(2001, 1, 1, 1, 1, 1, 0, time.UTC),
|
||||
ConnType: "P2P",
|
||||
IceCandidateType: iceCandidateType{
|
||||
Local: "",
|
||||
Remote: "",
|
||||
},
|
||||
IceCandidateEndpoint: iceCandidateType{
|
||||
Local: "",
|
||||
Remote: "",
|
||||
},
|
||||
LastWireguardHandshake: time.Date(2001, 1, 1, 1, 1, 2, 0, time.UTC),
|
||||
TransferReceived: 200,
|
||||
TransferSent: 100,
|
||||
Routes: []string{
|
||||
"10.1.0.0/24",
|
||||
},
|
||||
Networks: []string{
|
||||
"10.1.0.0/24",
|
||||
},
|
||||
Latency: time.Duration(10000000),
|
||||
},
|
||||
{
|
||||
IP: "192.168.178.102",
|
||||
PubKey: "Pubkey2",
|
||||
FQDN: "peer-2.awesome-domain.com",
|
||||
Status: "Connected",
|
||||
LastStatusUpdate: time.Date(2002, 2, 2, 2, 2, 2, 0, time.UTC),
|
||||
ConnType: "Relayed",
|
||||
IceCandidateType: iceCandidateType{
|
||||
Local: "relay",
|
||||
Remote: "prflx",
|
||||
},
|
||||
IceCandidateEndpoint: iceCandidateType{
|
||||
Local: "10.0.0.1:10001",
|
||||
Remote: "10.0.10.1:10002",
|
||||
},
|
||||
LastWireguardHandshake: time.Date(2002, 2, 2, 2, 2, 3, 0, time.UTC),
|
||||
TransferReceived: 2000,
|
||||
TransferSent: 1000,
|
||||
Latency: time.Duration(10000000),
|
||||
},
|
||||
},
|
||||
},
|
||||
CliVersion: version.NetbirdVersion(),
|
||||
DaemonVersion: "0.14.1",
|
||||
ManagementState: managementStateOutput{
|
||||
URL: "my-awesome-management.com:443",
|
||||
Connected: true,
|
||||
Error: "",
|
||||
},
|
||||
SignalState: signalStateOutput{
|
||||
URL: "my-awesome-signal.com:443",
|
||||
Connected: true,
|
||||
Error: "",
|
||||
},
|
||||
Relays: relayStateOutput{
|
||||
Total: 2,
|
||||
Available: 1,
|
||||
Details: []relayStateOutputDetail{
|
||||
{
|
||||
URI: "stun:my-awesome-stun.com:3478",
|
||||
Available: true,
|
||||
Error: "",
|
||||
},
|
||||
{
|
||||
URI: "turns:my-awesome-turn.com:443?transport=tcp",
|
||||
Available: false,
|
||||
Error: "context: deadline exceeded",
|
||||
},
|
||||
},
|
||||
},
|
||||
IP: "192.168.178.100/16",
|
||||
PubKey: "Some-Pub-Key",
|
||||
KernelInterface: true,
|
||||
FQDN: "some-localhost.awesome-domain.com",
|
||||
NSServerGroups: []nsServerGroupStateOutput{
|
||||
{
|
||||
Servers: []string{
|
||||
"8.8.8.8:53",
|
||||
},
|
||||
Domains: nil,
|
||||
Enabled: true,
|
||||
Error: "",
|
||||
},
|
||||
{
|
||||
Servers: []string{
|
||||
"1.1.1.1:53",
|
||||
"2.2.2.2:53",
|
||||
},
|
||||
Domains: []string{
|
||||
"example.com",
|
||||
"example.net",
|
||||
},
|
||||
Enabled: false,
|
||||
Error: "timeout",
|
||||
},
|
||||
},
|
||||
Routes: []string{
|
||||
"10.10.0.0/24",
|
||||
},
|
||||
Networks: []string{
|
||||
"10.10.0.0/24",
|
||||
},
|
||||
}
|
||||
|
||||
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
|
||||
convertedResult := convertToStatusOutputOverview(resp)
|
||||
|
||||
assert.Equal(t, overview, convertedResult)
|
||||
}
|
||||
|
||||
func TestSortingOfPeers(t *testing.T) {
|
||||
peers := []peerStateDetailOutput{
|
||||
{
|
||||
IP: "192.168.178.104",
|
||||
},
|
||||
{
|
||||
IP: "192.168.178.102",
|
||||
},
|
||||
{
|
||||
IP: "192.168.178.101",
|
||||
},
|
||||
{
|
||||
IP: "192.168.178.105",
|
||||
},
|
||||
{
|
||||
IP: "192.168.178.103",
|
||||
},
|
||||
}
|
||||
|
||||
sortPeersByIP(peers)
|
||||
|
||||
assert.Equal(t, peers[3].IP, "192.168.178.104")
|
||||
}
|
||||
|
||||
func TestParsingToJSON(t *testing.T) {
|
||||
jsonString, _ := parseToJSON(overview)
|
||||
|
||||
//@formatter:off
|
||||
expectedJSONString := `
|
||||
{
|
||||
"peers": {
|
||||
"total": 2,
|
||||
"connected": 2,
|
||||
"details": [
|
||||
{
|
||||
"fqdn": "peer-1.awesome-domain.com",
|
||||
"netbirdIp": "192.168.178.101",
|
||||
"publicKey": "Pubkey1",
|
||||
"status": "Connected",
|
||||
"lastStatusUpdate": "2001-01-01T01:01:01Z",
|
||||
"connectionType": "P2P",
|
||||
"iceCandidateType": {
|
||||
"local": "",
|
||||
"remote": ""
|
||||
},
|
||||
"iceCandidateEndpoint": {
|
||||
"local": "",
|
||||
"remote": ""
|
||||
},
|
||||
"relayAddress": "",
|
||||
"lastWireguardHandshake": "2001-01-01T01:01:02Z",
|
||||
"transferReceived": 200,
|
||||
"transferSent": 100,
|
||||
"latency": 10000000,
|
||||
"quantumResistance": false,
|
||||
"routes": [
|
||||
"10.1.0.0/24"
|
||||
],
|
||||
"networks": [
|
||||
"10.1.0.0/24"
|
||||
]
|
||||
},
|
||||
{
|
||||
"fqdn": "peer-2.awesome-domain.com",
|
||||
"netbirdIp": "192.168.178.102",
|
||||
"publicKey": "Pubkey2",
|
||||
"status": "Connected",
|
||||
"lastStatusUpdate": "2002-02-02T02:02:02Z",
|
||||
"connectionType": "Relayed",
|
||||
"iceCandidateType": {
|
||||
"local": "relay",
|
||||
"remote": "prflx"
|
||||
},
|
||||
"iceCandidateEndpoint": {
|
||||
"local": "10.0.0.1:10001",
|
||||
"remote": "10.0.10.1:10002"
|
||||
},
|
||||
"relayAddress": "",
|
||||
"lastWireguardHandshake": "2002-02-02T02:02:03Z",
|
||||
"transferReceived": 2000,
|
||||
"transferSent": 1000,
|
||||
"latency": 10000000,
|
||||
"quantumResistance": false,
|
||||
"routes": null,
|
||||
"networks": null
|
||||
}
|
||||
]
|
||||
},
|
||||
"cliVersion": "development",
|
||||
"daemonVersion": "0.14.1",
|
||||
"management": {
|
||||
"url": "my-awesome-management.com:443",
|
||||
"connected": true,
|
||||
"error": ""
|
||||
},
|
||||
"signal": {
|
||||
"url": "my-awesome-signal.com:443",
|
||||
"connected": true,
|
||||
"error": ""
|
||||
},
|
||||
"relays": {
|
||||
"total": 2,
|
||||
"available": 1,
|
||||
"details": [
|
||||
{
|
||||
"uri": "stun:my-awesome-stun.com:3478",
|
||||
"available": true,
|
||||
"error": ""
|
||||
},
|
||||
{
|
||||
"uri": "turns:my-awesome-turn.com:443?transport=tcp",
|
||||
"available": false,
|
||||
"error": "context: deadline exceeded"
|
||||
}
|
||||
]
|
||||
},
|
||||
"netbirdIp": "192.168.178.100/16",
|
||||
"publicKey": "Some-Pub-Key",
|
||||
"usesKernelInterface": true,
|
||||
"fqdn": "some-localhost.awesome-domain.com",
|
||||
"quantumResistance": false,
|
||||
"quantumResistancePermissive": false,
|
||||
"routes": [
|
||||
"10.10.0.0/24"
|
||||
],
|
||||
"networks": [
|
||||
"10.10.0.0/24"
|
||||
],
|
||||
"dnsServers": [
|
||||
{
|
||||
"servers": [
|
||||
"8.8.8.8:53"
|
||||
],
|
||||
"domains": null,
|
||||
"enabled": true,
|
||||
"error": ""
|
||||
},
|
||||
{
|
||||
"servers": [
|
||||
"1.1.1.1:53",
|
||||
"2.2.2.2:53"
|
||||
],
|
||||
"domains": [
|
||||
"example.com",
|
||||
"example.net"
|
||||
],
|
||||
"enabled": false,
|
||||
"error": "timeout"
|
||||
}
|
||||
]
|
||||
}`
|
||||
// @formatter:on
|
||||
|
||||
var expectedJSON bytes.Buffer
|
||||
require.NoError(t, json.Compact(&expectedJSON, []byte(expectedJSONString)))
|
||||
|
||||
assert.Equal(t, expectedJSON.String(), jsonString)
|
||||
}
|
||||
|
||||
func TestParsingToYAML(t *testing.T) {
|
||||
yaml, _ := parseToYAML(overview)
|
||||
|
||||
expectedYAML :=
|
||||
`peers:
|
||||
total: 2
|
||||
connected: 2
|
||||
details:
|
||||
- fqdn: peer-1.awesome-domain.com
|
||||
netbirdIp: 192.168.178.101
|
||||
publicKey: Pubkey1
|
||||
status: Connected
|
||||
lastStatusUpdate: 2001-01-01T01:01:01Z
|
||||
connectionType: P2P
|
||||
iceCandidateType:
|
||||
local: ""
|
||||
remote: ""
|
||||
iceCandidateEndpoint:
|
||||
local: ""
|
||||
remote: ""
|
||||
relayAddress: ""
|
||||
lastWireguardHandshake: 2001-01-01T01:01:02Z
|
||||
transferReceived: 200
|
||||
transferSent: 100
|
||||
latency: 10ms
|
||||
quantumResistance: false
|
||||
routes:
|
||||
- 10.1.0.0/24
|
||||
networks:
|
||||
- 10.1.0.0/24
|
||||
- fqdn: peer-2.awesome-domain.com
|
||||
netbirdIp: 192.168.178.102
|
||||
publicKey: Pubkey2
|
||||
status: Connected
|
||||
lastStatusUpdate: 2002-02-02T02:02:02Z
|
||||
connectionType: Relayed
|
||||
iceCandidateType:
|
||||
local: relay
|
||||
remote: prflx
|
||||
iceCandidateEndpoint:
|
||||
local: 10.0.0.1:10001
|
||||
remote: 10.0.10.1:10002
|
||||
relayAddress: ""
|
||||
lastWireguardHandshake: 2002-02-02T02:02:03Z
|
||||
transferReceived: 2000
|
||||
transferSent: 1000
|
||||
latency: 10ms
|
||||
quantumResistance: false
|
||||
routes: []
|
||||
networks: []
|
||||
cliVersion: development
|
||||
daemonVersion: 0.14.1
|
||||
management:
|
||||
url: my-awesome-management.com:443
|
||||
connected: true
|
||||
error: ""
|
||||
signal:
|
||||
url: my-awesome-signal.com:443
|
||||
connected: true
|
||||
error: ""
|
||||
relays:
|
||||
total: 2
|
||||
available: 1
|
||||
details:
|
||||
- uri: stun:my-awesome-stun.com:3478
|
||||
available: true
|
||||
error: ""
|
||||
- uri: turns:my-awesome-turn.com:443?transport=tcp
|
||||
available: false
|
||||
error: 'context: deadline exceeded'
|
||||
netbirdIp: 192.168.178.100/16
|
||||
publicKey: Some-Pub-Key
|
||||
usesKernelInterface: true
|
||||
fqdn: some-localhost.awesome-domain.com
|
||||
quantumResistance: false
|
||||
quantumResistancePermissive: false
|
||||
routes:
|
||||
- 10.10.0.0/24
|
||||
networks:
|
||||
- 10.10.0.0/24
|
||||
dnsServers:
|
||||
- servers:
|
||||
- 8.8.8.8:53
|
||||
domains: []
|
||||
enabled: true
|
||||
error: ""
|
||||
- servers:
|
||||
- 1.1.1.1:53
|
||||
- 2.2.2.2:53
|
||||
domains:
|
||||
- example.com
|
||||
- example.net
|
||||
enabled: false
|
||||
error: timeout
|
||||
`
|
||||
|
||||
assert.Equal(t, expectedYAML, yaml)
|
||||
}
|
||||
|
||||
func TestParsingToDetail(t *testing.T) {
|
||||
// Calculate time ago based on the fixture dates
|
||||
lastConnectionUpdate1 := timeAgo(overview.Peers.Details[0].LastStatusUpdate)
|
||||
lastHandshake1 := timeAgo(overview.Peers.Details[0].LastWireguardHandshake)
|
||||
lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate)
|
||||
lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake)
|
||||
|
||||
detail := parseToFullDetailSummary(overview)
|
||||
|
||||
expectedDetail := fmt.Sprintf(
|
||||
`Peers detail:
|
||||
peer-1.awesome-domain.com:
|
||||
NetBird IP: 192.168.178.101
|
||||
Public key: Pubkey1
|
||||
Status: Connected
|
||||
-- detail --
|
||||
Connection type: P2P
|
||||
ICE candidate (Local/Remote): -/-
|
||||
ICE candidate endpoints (Local/Remote): -/-
|
||||
Relay server address:
|
||||
Last connection update: %s
|
||||
Last WireGuard handshake: %s
|
||||
Transfer status (received/sent) 200 B/100 B
|
||||
Quantum resistance: false
|
||||
Routes: 10.1.0.0/24
|
||||
Networks: 10.1.0.0/24
|
||||
Latency: 10ms
|
||||
|
||||
peer-2.awesome-domain.com:
|
||||
NetBird IP: 192.168.178.102
|
||||
Public key: Pubkey2
|
||||
Status: Connected
|
||||
-- detail --
|
||||
Connection type: Relayed
|
||||
ICE candidate (Local/Remote): relay/prflx
|
||||
ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002
|
||||
Relay server address:
|
||||
Last connection update: %s
|
||||
Last WireGuard handshake: %s
|
||||
Transfer status (received/sent) 2.0 KiB/1000 B
|
||||
Quantum resistance: false
|
||||
Routes: -
|
||||
Networks: -
|
||||
Latency: 10ms
|
||||
|
||||
OS: %s/%s
|
||||
Daemon version: 0.14.1
|
||||
CLI version: %s
|
||||
Management: Connected to my-awesome-management.com:443
|
||||
Signal: Connected to my-awesome-signal.com:443
|
||||
Relays:
|
||||
[stun:my-awesome-stun.com:3478] is Available
|
||||
[turns:my-awesome-turn.com:443?transport=tcp] is Unavailable, reason: context: deadline exceeded
|
||||
Nameservers:
|
||||
[8.8.8.8:53] for [.] is Available
|
||||
[1.1.1.1:53, 2.2.2.2:53] for [example.com, example.net] is Unavailable, reason: timeout
|
||||
FQDN: some-localhost.awesome-domain.com
|
||||
NetBird IP: 192.168.178.100/16
|
||||
Interface type: Kernel
|
||||
Quantum resistance: false
|
||||
Routes: 10.10.0.0/24
|
||||
Networks: 10.10.0.0/24
|
||||
Peers count: 2/2 Connected
|
||||
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
|
||||
|
||||
assert.Equal(t, expectedDetail, detail)
|
||||
}
|
||||
|
||||
func TestParsingToShortVersion(t *testing.T) {
|
||||
shortVersion := parseGeneralSummary(overview, false, false, false)
|
||||
|
||||
expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
|
||||
Daemon version: 0.14.1
|
||||
CLI version: development
|
||||
Management: Connected
|
||||
Signal: Connected
|
||||
Relays: 1/2 Available
|
||||
Nameservers: 1/2 Available
|
||||
FQDN: some-localhost.awesome-domain.com
|
||||
NetBird IP: 192.168.178.100/16
|
||||
Interface type: Kernel
|
||||
Quantum resistance: false
|
||||
Routes: 10.10.0.0/24
|
||||
Networks: 10.10.0.0/24
|
||||
Peers count: 2/2 Connected
|
||||
`
|
||||
|
||||
assert.Equal(t, expectedString, shortVersion)
|
||||
}
|
||||
|
||||
func TestParsingOfIP(t *testing.T) {
|
||||
InterfaceIP := "192.168.178.123/16"
|
||||
|
||||
@ -599,31 +13,3 @@ func TestParsingOfIP(t *testing.T) {
|
||||
|
||||
assert.Equal(t, "192.168.178.123\n", parsedIP)
|
||||
}
|
||||
|
||||
func TestTimeAgo(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
input time.Time
|
||||
expected string
|
||||
}{
|
||||
{"Now", now, "Now"},
|
||||
{"Seconds ago", now.Add(-10 * time.Second), "10 seconds ago"},
|
||||
{"One minute ago", now.Add(-1 * time.Minute), "1 minute ago"},
|
||||
{"Minutes and seconds ago", now.Add(-(1*time.Minute + 30*time.Second)), "1 minute, 30 seconds ago"},
|
||||
{"One hour ago", now.Add(-1 * time.Hour), "1 hour ago"},
|
||||
{"Hours and minutes ago", now.Add(-(2*time.Hour + 15*time.Minute)), "2 hours, 15 minutes ago"},
|
||||
{"One day ago", now.Add(-24 * time.Hour), "1 day ago"},
|
||||
{"Multiple days ago", now.Add(-(72*time.Hour + 20*time.Minute)), "3 days ago"},
|
||||
{"Zero time", time.Time{}, "-"},
|
||||
{"Unix zero time", time.Unix(0, 0), "-"},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := timeAgo(tc.input)
|
||||
assert.Equal(t, tc.expected, result, "Failed %s", tc.name)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -95,7 +95,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
|
||||
}
|
||||
|
||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
137
client/cmd/trace.go
Normal file
137
client/cmd/trace.go
Normal file
@ -0,0 +1,137 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
var traceCmd = &cobra.Command{
|
||||
Use: "trace <direction> <source-ip> <dest-ip>",
|
||||
Short: "Trace a packet through the firewall",
|
||||
Example: `
|
||||
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
|
||||
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
|
||||
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0
|
||||
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
|
||||
Args: cobra.ExactArgs(3),
|
||||
RunE: tracePacket,
|
||||
}
|
||||
|
||||
func init() {
|
||||
debugCmd.AddCommand(traceCmd)
|
||||
|
||||
traceCmd.Flags().StringP("protocol", "p", "tcp", "Protocol (tcp/udp/icmp)")
|
||||
traceCmd.Flags().Uint16("sport", 0, "Source port")
|
||||
traceCmd.Flags().Uint16("dport", 0, "Destination port")
|
||||
traceCmd.Flags().Uint8("icmp-type", 0, "ICMP type")
|
||||
traceCmd.Flags().Uint8("icmp-code", 0, "ICMP code")
|
||||
traceCmd.Flags().Bool("syn", false, "TCP SYN flag")
|
||||
traceCmd.Flags().Bool("ack", false, "TCP ACK flag")
|
||||
traceCmd.Flags().Bool("fin", false, "TCP FIN flag")
|
||||
traceCmd.Flags().Bool("rst", false, "TCP RST flag")
|
||||
traceCmd.Flags().Bool("psh", false, "TCP PSH flag")
|
||||
traceCmd.Flags().Bool("urg", false, "TCP URG flag")
|
||||
}
|
||||
|
||||
func tracePacket(cmd *cobra.Command, args []string) error {
|
||||
direction := strings.ToLower(args[0])
|
||||
if direction != "in" && direction != "out" {
|
||||
return fmt.Errorf("invalid direction: use 'in' or 'out'")
|
||||
}
|
||||
|
||||
protocol := cmd.Flag("protocol").Value.String()
|
||||
if protocol != "tcp" && protocol != "udp" && protocol != "icmp" {
|
||||
return fmt.Errorf("invalid protocol: use tcp/udp/icmp")
|
||||
}
|
||||
|
||||
sport, err := cmd.Flags().GetUint16("sport")
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid source port: %v", err)
|
||||
}
|
||||
dport, err := cmd.Flags().GetUint16("dport")
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid destination port: %v", err)
|
||||
}
|
||||
|
||||
// For TCP/UDP, generate random ephemeral port (49152-65535) if not specified
|
||||
if protocol != "icmp" {
|
||||
if sport == 0 {
|
||||
sport = uint16(rand.Intn(16383) + 49152)
|
||||
}
|
||||
if dport == 0 {
|
||||
dport = uint16(rand.Intn(16383) + 49152)
|
||||
}
|
||||
}
|
||||
|
||||
var tcpFlags *proto.TCPFlags
|
||||
if protocol == "tcp" {
|
||||
syn, _ := cmd.Flags().GetBool("syn")
|
||||
ack, _ := cmd.Flags().GetBool("ack")
|
||||
fin, _ := cmd.Flags().GetBool("fin")
|
||||
rst, _ := cmd.Flags().GetBool("rst")
|
||||
psh, _ := cmd.Flags().GetBool("psh")
|
||||
urg, _ := cmd.Flags().GetBool("urg")
|
||||
|
||||
tcpFlags = &proto.TCPFlags{
|
||||
Syn: syn,
|
||||
Ack: ack,
|
||||
Fin: fin,
|
||||
Rst: rst,
|
||||
Psh: psh,
|
||||
Urg: urg,
|
||||
}
|
||||
}
|
||||
|
||||
icmpType, _ := cmd.Flags().GetUint32("icmp-type")
|
||||
icmpCode, _ := cmd.Flags().GetUint32("icmp-code")
|
||||
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.TracePacket(cmd.Context(), &proto.TracePacketRequest{
|
||||
SourceIp: args[1],
|
||||
DestinationIp: args[2],
|
||||
Protocol: protocol,
|
||||
SourcePort: uint32(sport),
|
||||
DestinationPort: uint32(dport),
|
||||
Direction: direction,
|
||||
TcpFlags: tcpFlags,
|
||||
IcmpType: &icmpType,
|
||||
IcmpCode: &icmpCode,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("trace failed: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
printTrace(cmd, args[1], args[2], protocol, sport, dport, resp)
|
||||
return nil
|
||||
}
|
||||
|
||||
func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) {
|
||||
cmd.Printf("Packet trace %s:%d -> %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
|
||||
|
||||
for _, stage := range resp.Stages {
|
||||
if stage.ForwardingDetails != nil {
|
||||
cmd.Printf("%s: %s [%s]\n", stage.Name, stage.Message, *stage.ForwardingDetails)
|
||||
} else {
|
||||
cmd.Printf("%s: %s\n", stage.Name, stage.Message)
|
||||
}
|
||||
}
|
||||
|
||||
disposition := map[bool]string{
|
||||
true: "\033[32mALLOWED\033[0m", // Green
|
||||
false: "\033[31mDENIED\033[0m", // Red
|
||||
}[resp.FinalDisposition]
|
||||
|
||||
cmd.Printf("\nFinal disposition: %s\n", disposition)
|
||||
}
|
@ -20,6 +20,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@ -29,9 +30,16 @@ const (
|
||||
interfaceInputType
|
||||
)
|
||||
|
||||
const (
|
||||
dnsLabelsFlag = "extra-dns-labels"
|
||||
)
|
||||
|
||||
var (
|
||||
foregroundMode bool
|
||||
upCmd = &cobra.Command{
|
||||
foregroundMode bool
|
||||
dnsLabels []string
|
||||
dnsLabelsValidated domain.List
|
||||
|
||||
upCmd = &cobra.Command{
|
||||
Use: "up",
|
||||
Short: "install, login and start Netbird client",
|
||||
RunE: upFunc,
|
||||
@ -48,6 +56,15 @@ func init() {
|
||||
)
|
||||
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
|
||||
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
|
||||
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false, "Block access to local networks (LAN) when using this peer as a router or exit node")
|
||||
|
||||
upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil,
|
||||
`Sets DNS labels`+
|
||||
`You can specify a comma-separated list of up to 32 labels. `+
|
||||
`An empty string "" clears the previous configuration. `+
|
||||
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
|
||||
`or --extra-dns-labels ""`,
|
||||
)
|
||||
}
|
||||
|
||||
func upFunc(cmd *cobra.Command, args []string) error {
|
||||
@ -66,6 +83,11 @@ func upFunc(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
dnsLabelsValidated, err = validateDnsLabels(dnsLabels)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
|
||||
if hostName != "" {
|
||||
@ -97,6 +119,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
NATExternalIPs: natExternalIPs,
|
||||
CustomDNSAddress: customDNSAddressConverted,
|
||||
ExtraIFaceBlackList: extraIFaceBlackList,
|
||||
DNSLabels: dnsLabelsValidated,
|
||||
}
|
||||
|
||||
if cmd.Flag(enableRosenpassFlag).Changed {
|
||||
@ -160,6 +183,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
ic.DisableFirewall = &disableFirewall
|
||||
}
|
||||
|
||||
if cmd.Flag(blockLANAccessFlag).Changed {
|
||||
ic.BlockLANAccess = &blockLANAccess
|
||||
}
|
||||
|
||||
providedSetupKey, err := getSetupKey()
|
||||
if err != nil {
|
||||
return err
|
||||
@ -185,7 +212,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
r.GetFullStatus()
|
||||
|
||||
connectClient := internal.NewConnectClient(ctx, config, r)
|
||||
return connectClient.Run()
|
||||
return connectClient.Run(nil)
|
||||
}
|
||||
|
||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
@ -235,6 +262,8 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
||||
Hostname: hostName,
|
||||
ExtraIFaceBlacklist: extraIFaceBlackList,
|
||||
DnsLabels: dnsLabels,
|
||||
CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0,
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
@ -290,6 +319,10 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
loginRequest.DisableFirewall = &disableFirewall
|
||||
}
|
||||
|
||||
if cmd.Flag(blockLANAccessFlag).Changed {
|
||||
loginRequest.BlockLanAccess = &blockLANAccess
|
||||
}
|
||||
|
||||
var loginErr error
|
||||
|
||||
var loginResp *proto.LoginResponse
|
||||
@ -421,6 +454,24 @@ func parseCustomDNSAddress(modified bool) ([]byte, error) {
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
func validateDnsLabels(labels []string) (domain.List, error) {
|
||||
var (
|
||||
domains domain.List
|
||||
err error
|
||||
)
|
||||
|
||||
if len(labels) == 0 {
|
||||
return domains, nil
|
||||
}
|
||||
|
||||
domains, err = domain.ValidateDomains(labels)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate dns labels: %v", err)
|
||||
}
|
||||
|
||||
return domains, nil
|
||||
}
|
||||
|
||||
func isValidAddrPort(input string) bool {
|
||||
if input == "" {
|
||||
return true
|
||||
|
167
client/embed/doc.go
Normal file
167
client/embed/doc.go
Normal file
@ -0,0 +1,167 @@
|
||||
// Package embed provides a way to embed the NetBird client directly
|
||||
// into Go programs without requiring a separate NetBird client installation.
|
||||
package embed
|
||||
|
||||
// Basic Usage:
|
||||
//
|
||||
// client, err := embed.New(embed.Options{
|
||||
// DeviceName: "my-service",
|
||||
// SetupKey: os.Getenv("NB_SETUP_KEY"),
|
||||
// ManagementURL: os.Getenv("NB_MANAGEMENT_URL"),
|
||||
// })
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
// defer cancel()
|
||||
// if err := client.Start(ctx); err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// Complete HTTP Server Example:
|
||||
//
|
||||
// package main
|
||||
//
|
||||
// import (
|
||||
// "context"
|
||||
// "fmt"
|
||||
// "log"
|
||||
// "net/http"
|
||||
// "os"
|
||||
// "os/signal"
|
||||
// "syscall"
|
||||
// "time"
|
||||
//
|
||||
// netbird "github.com/netbirdio/netbird/client/embed"
|
||||
// )
|
||||
//
|
||||
// func main() {
|
||||
// // Create client with setup key and device name
|
||||
// client, err := netbird.New(netbird.Options{
|
||||
// DeviceName: "http-server",
|
||||
// SetupKey: os.Getenv("NB_SETUP_KEY"),
|
||||
// ManagementURL: os.Getenv("NB_MANAGEMENT_URL"),
|
||||
// LogOutput: io.Discard,
|
||||
// })
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// // Start with timeout
|
||||
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
// defer cancel()
|
||||
// if err := client.Start(ctx); err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// // Create HTTP server
|
||||
// mux := http.NewServeMux()
|
||||
// mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
// fmt.Printf("Request from %s: %s %s\n", r.RemoteAddr, r.Method, r.URL.Path)
|
||||
// fmt.Fprintf(w, "Hello from netbird!")
|
||||
// })
|
||||
//
|
||||
// // Listen on netbird network
|
||||
// l, err := client.ListenTCP(":8080")
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// server := &http.Server{Handler: mux}
|
||||
// go func() {
|
||||
// if err := server.Serve(l); !errors.Is(err, http.ErrServerClosed) {
|
||||
// log.Printf("HTTP server error: %v", err)
|
||||
// }
|
||||
// }()
|
||||
//
|
||||
// log.Printf("HTTP server listening on netbird network port 8080")
|
||||
//
|
||||
// // Handle shutdown
|
||||
// stop := make(chan os.Signal, 1)
|
||||
// signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM)
|
||||
// <-stop
|
||||
//
|
||||
// shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
// defer cancel()
|
||||
//
|
||||
// if err := server.Shutdown(shutdownCtx); err != nil {
|
||||
// log.Printf("HTTP shutdown error: %v", err)
|
||||
// }
|
||||
// if err := client.Stop(shutdownCtx); err != nil {
|
||||
// log.Printf("Netbird shutdown error: %v", err)
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// Complete HTTP Client Example:
|
||||
//
|
||||
// package main
|
||||
//
|
||||
// import (
|
||||
// "context"
|
||||
// "fmt"
|
||||
// "io"
|
||||
// "log"
|
||||
// "os"
|
||||
// "time"
|
||||
//
|
||||
// netbird "github.com/netbirdio/netbird/client/embed"
|
||||
// )
|
||||
//
|
||||
// func main() {
|
||||
// // Create client with setup key and device name
|
||||
// client, err := netbird.New(netbird.Options{
|
||||
// DeviceName: "http-client",
|
||||
// SetupKey: os.Getenv("NB_SETUP_KEY"),
|
||||
// ManagementURL: os.Getenv("NB_MANAGEMENT_URL"),
|
||||
// LogOutput: io.Discard,
|
||||
// })
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// // Start with timeout
|
||||
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
// defer cancel()
|
||||
//
|
||||
// if err := client.Start(ctx); err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// // Create HTTP client that uses netbird network
|
||||
// httpClient := client.NewHTTPClient()
|
||||
// httpClient.Timeout = 10 * time.Second
|
||||
//
|
||||
// // Make request to server in netbird network
|
||||
// target := os.Getenv("NB_TARGET")
|
||||
// resp, err := httpClient.Get(target)
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
// defer resp.Body.Close()
|
||||
//
|
||||
// // Read and print response
|
||||
// body, err := io.ReadAll(resp.Body)
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// fmt.Printf("Response from server: %s\n", string(body))
|
||||
//
|
||||
// // Clean shutdown
|
||||
// shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
// defer cancel()
|
||||
//
|
||||
// if err := client.Stop(shutdownCtx); err != nil {
|
||||
// log.Printf("Netbird shutdown error: %v", err)
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// The package provides several methods for network operations:
|
||||
// - Dial: Creates outbound connections
|
||||
// - ListenTCP: Creates TCP listeners
|
||||
// - ListenUDP: Creates UDP listeners
|
||||
//
|
||||
// By default, the embed package uses userspace networking mode, which doesn't
|
||||
// require root/admin privileges. For production deployments, consider setting
|
||||
// appropriate config and state paths for persistence.
|
296
client/embed/embed.go
Normal file
296
client/embed/embed.go
Normal file
@ -0,0 +1,296 @@
|
||||
package embed
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
)
|
||||
|
||||
var ErrClientAlreadyStarted = errors.New("client already started")
|
||||
var ErrClientNotStarted = errors.New("client not started")
|
||||
|
||||
// Client manages a netbird embedded client instance
|
||||
type Client struct {
|
||||
deviceName string
|
||||
config *internal.Config
|
||||
mu sync.Mutex
|
||||
cancel context.CancelFunc
|
||||
setupKey string
|
||||
connect *internal.ConnectClient
|
||||
}
|
||||
|
||||
// Options configures a new Client
|
||||
type Options struct {
|
||||
// DeviceName is this peer's name in the network
|
||||
DeviceName string
|
||||
// SetupKey is used for authentication
|
||||
SetupKey string
|
||||
// ManagementURL overrides the default management server URL
|
||||
ManagementURL string
|
||||
// PreSharedKey is the pre-shared key for the WireGuard interface
|
||||
PreSharedKey string
|
||||
// LogOutput is the output destination for logs (defaults to os.Stderr if nil)
|
||||
LogOutput io.Writer
|
||||
// LogLevel sets the logging level (defaults to info if empty)
|
||||
LogLevel string
|
||||
// NoUserspace disables the userspace networking mode. Needs admin/root privileges
|
||||
NoUserspace bool
|
||||
// ConfigPath is the path to the netbird config file. If empty, the config will be stored in memory and not persisted.
|
||||
ConfigPath string
|
||||
// StatePath is the path to the netbird state file
|
||||
StatePath string
|
||||
// DisableClientRoutes disables the client routes
|
||||
DisableClientRoutes bool
|
||||
}
|
||||
|
||||
// New creates a new netbird embedded client
|
||||
func New(opts Options) (*Client, error) {
|
||||
if opts.LogOutput != nil {
|
||||
logrus.SetOutput(opts.LogOutput)
|
||||
}
|
||||
|
||||
if opts.LogLevel != "" {
|
||||
level, err := logrus.ParseLevel(opts.LogLevel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse log level: %w", err)
|
||||
}
|
||||
logrus.SetLevel(level)
|
||||
}
|
||||
|
||||
if !opts.NoUserspace {
|
||||
if err := os.Setenv(netstack.EnvUseNetstackMode, "true"); err != nil {
|
||||
return nil, fmt.Errorf("setenv: %w", err)
|
||||
}
|
||||
if err := os.Setenv(netstack.EnvSkipProxy, "true"); err != nil {
|
||||
return nil, fmt.Errorf("setenv: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if opts.StatePath != "" {
|
||||
// TODO: Disable state if path not provided
|
||||
if err := os.Setenv("NB_DNS_STATE_FILE", opts.StatePath); err != nil {
|
||||
return nil, fmt.Errorf("setenv: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
t := true
|
||||
var config *internal.Config
|
||||
var err error
|
||||
input := internal.ConfigInput{
|
||||
ConfigPath: opts.ConfigPath,
|
||||
ManagementURL: opts.ManagementURL,
|
||||
PreSharedKey: &opts.PreSharedKey,
|
||||
DisableServerRoutes: &t,
|
||||
DisableClientRoutes: &opts.DisableClientRoutes,
|
||||
}
|
||||
if opts.ConfigPath != "" {
|
||||
config, err = internal.UpdateOrCreateConfig(input)
|
||||
} else {
|
||||
config, err = internal.CreateInMemoryConfig(input)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create config: %w", err)
|
||||
}
|
||||
|
||||
return &Client{
|
||||
deviceName: opts.DeviceName,
|
||||
setupKey: opts.SetupKey,
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start begins client operation and blocks until the engine has been started successfully or a startup error occurs.
|
||||
// Pass a context with a deadline to limit the time spent waiting for the engine to start.
|
||||
func (c *Client) Start(startCtx context.Context) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.cancel != nil {
|
||||
return ErrClientAlreadyStarted
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
||||
if err := internal.Login(ctx, c.config, c.setupKey, ""); err != nil {
|
||||
return fmt.Errorf("login: %w", err)
|
||||
}
|
||||
|
||||
recorder := peer.NewRecorder(c.config.ManagementURL.String())
|
||||
client := internal.NewConnectClient(ctx, c.config, recorder)
|
||||
|
||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||
// TODO: make after-startup backoff err available
|
||||
run := make(chan error, 1)
|
||||
go func() {
|
||||
if err := client.Run(run); err != nil {
|
||||
run <- err
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-startCtx.Done():
|
||||
if stopErr := client.Stop(); stopErr != nil {
|
||||
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
|
||||
}
|
||||
return startCtx.Err()
|
||||
case err := <-run:
|
||||
if err != nil {
|
||||
if stopErr := client.Stop(); stopErr != nil {
|
||||
return fmt.Errorf("stop error after failed to startup. Stop error: %w. Start error: %w", stopErr, err)
|
||||
}
|
||||
return fmt.Errorf("startup: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
c.connect = client
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully stops the client.
|
||||
// Pass a context with a deadline to limit the time spent waiting for the engine to stop.
|
||||
func (c *Client) Stop(ctx context.Context) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.connect == nil {
|
||||
return ErrClientNotStarted
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- c.connect.Stop()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.cancel = nil
|
||||
return ctx.Err()
|
||||
case err := <-done:
|
||||
c.cancel = nil
|
||||
if err != nil {
|
||||
return fmt.Errorf("stop: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Dial dials a network address in the netbird network.
|
||||
// Not applicable if the userspace networking mode is disabled.
|
||||
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
c.mu.Lock()
|
||||
connect := c.connect
|
||||
if connect == nil {
|
||||
c.mu.Unlock()
|
||||
return nil, ErrClientNotStarted
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
engine := connect.Engine()
|
||||
if engine == nil {
|
||||
return nil, errors.New("engine not started")
|
||||
}
|
||||
|
||||
nsnet, err := engine.GetNet()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get net: %w", err)
|
||||
}
|
||||
|
||||
return nsnet.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
// ListenTCP listens on the given address in the netbird network
|
||||
// Not applicable if the userspace networking mode is disabled.
|
||||
func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
||||
nsnet, addr, err := c.getNet()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, port, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("split host port: %w", err)
|
||||
}
|
||||
listenAddr := fmt.Sprintf("%s:%s", addr, port)
|
||||
|
||||
tcpAddr, err := net.ResolveTCPAddr("tcp", listenAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolve: %w", err)
|
||||
}
|
||||
return nsnet.ListenTCP(tcpAddr)
|
||||
}
|
||||
|
||||
// ListenUDP listens on the given address in the netbird network
|
||||
// Not applicable if the userspace networking mode is disabled.
|
||||
func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
|
||||
nsnet, addr, err := c.getNet()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, port, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("split host port: %w", err)
|
||||
}
|
||||
listenAddr := fmt.Sprintf("%s:%s", addr, port)
|
||||
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", listenAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolve: %w", err)
|
||||
}
|
||||
|
||||
return nsnet.ListenUDP(udpAddr)
|
||||
}
|
||||
|
||||
// NewHTTPClient returns a configured http.Client that uses the netbird network for requests.
|
||||
// Not applicable if the userspace networking mode is disabled.
|
||||
func (c *Client) NewHTTPClient() *http.Client {
|
||||
transport := &http.Transport{
|
||||
DialContext: c.Dial,
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Transport: transport,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
|
||||
c.mu.Lock()
|
||||
connect := c.connect
|
||||
if connect == nil {
|
||||
c.mu.Unlock()
|
||||
return nil, netip.Addr{}, errors.New("client not started")
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
engine := connect.Engine()
|
||||
if engine == nil {
|
||||
return nil, netip.Addr{}, errors.New("engine not started")
|
||||
}
|
||||
|
||||
addr, err := engine.Address()
|
||||
if err != nil {
|
||||
return nil, netip.Addr{}, fmt.Errorf("engine address: %w", err)
|
||||
}
|
||||
|
||||
nsnet, err := engine.GetNet()
|
||||
if err != nil {
|
||||
return nil, netip.Addr{}, fmt.Errorf("get net: %w", err)
|
||||
}
|
||||
|
||||
return nsnet, addr, nil
|
||||
}
|
@ -14,13 +14,13 @@ import (
|
||||
)
|
||||
|
||||
// NewFirewall creates a firewall manager instance
|
||||
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) {
|
||||
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
||||
if !iface.IsUserspaceBind() {
|
||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
// use userspace packet filtering firewall
|
||||
fm, err := uspfilter.Create(iface)
|
||||
fm, err := uspfilter.Create(iface, disableServerRoutes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -33,12 +33,12 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
||||
// FWType is the type for the firewall type
|
||||
type FWType int
|
||||
|
||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
|
||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
||||
// on the linux system we try to user nftables or iptables
|
||||
// in any case, because we need to allow netbird interface traffic
|
||||
// so we use AllowNetbird traffic from these firewall managers
|
||||
// for the userspace packet filtering firewall
|
||||
fm, err := createNativeFirewall(iface, stateManager)
|
||||
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes)
|
||||
|
||||
if !iface.IsUserspaceBind() {
|
||||
return fm, err
|
||||
@ -47,10 +47,10 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewal
|
||||
if err != nil {
|
||||
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||
}
|
||||
return createUserspaceFirewall(iface, fm)
|
||||
return createUserspaceFirewall(iface, fm, disableServerRoutes)
|
||||
}
|
||||
|
||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
|
||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
|
||||
fm, err := createFW(iface)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create firewall: %s", err)
|
||||
@ -77,12 +77,12 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) {
|
||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
||||
var errUsp error
|
||||
if fm != nil {
|
||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
|
||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes)
|
||||
} else {
|
||||
fm, errUsp = uspfilter.Create(iface)
|
||||
fm, errUsp = uspfilter.Create(iface, disableServerRoutes)
|
||||
}
|
||||
|
||||
if errUsp != nil {
|
||||
|
@ -1,6 +1,8 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
@ -10,4 +12,6 @@ type IFaceMapper interface {
|
||||
Address() device.WGAddress
|
||||
IsUserspaceBind() bool
|
||||
SetFilter(device.PacketFilter) error
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetWGDevice() *wgdevice.Device
|
||||
}
|
||||
|
@ -3,7 +3,7 @@ package iptables
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"slices"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/uuid"
|
||||
@ -19,8 +19,7 @@ const (
|
||||
tableName = "filter"
|
||||
|
||||
// rules chains contains the effective ACL rules
|
||||
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
||||
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
|
||||
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
||||
)
|
||||
|
||||
type aclEntries map[string][][]string
|
||||
@ -84,28 +83,22 @@ func (m *aclManager) AddPeerFiltering(
|
||||
protocol firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
direction firewall.RuleDirection,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
) ([]firewall.Rule, error) {
|
||||
var dPortVal, sPortVal string
|
||||
if dPort != nil && dPort.Values != nil {
|
||||
// TODO: we support only one port per rule in current implementation of ACLs
|
||||
dPortVal = strconv.Itoa(dPort.Values[0])
|
||||
}
|
||||
if sPort != nil && sPort.Values != nil {
|
||||
sPortVal = strconv.Itoa(sPort.Values[0])
|
||||
}
|
||||
chain := chainNameInputRules
|
||||
|
||||
var chain string
|
||||
if direction == firewall.RuleDirectionOUT {
|
||||
chain = chainNameOutputRules
|
||||
} else {
|
||||
chain = chainNameInputRules
|
||||
}
|
||||
ipsetName = transformIPsetName(ipsetName, sPort, dPort)
|
||||
specs := filterRuleSpecs(ip, string(protocol), sPort, dPort, action, ipsetName)
|
||||
|
||||
ipsetName = transformIPsetName(ipsetName, sPortVal, dPortVal)
|
||||
specs := filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, direction, action, ipsetName)
|
||||
mangleSpecs := slices.Clone(specs)
|
||||
mangleSpecs = append(mangleSpecs,
|
||||
"-i", m.wgIface.Name(),
|
||||
"-m", "addrtype", "--dst-type", "LOCAL",
|
||||
"-j", "MARK", "--set-xmark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
|
||||
)
|
||||
|
||||
specs = append(specs, "-j", actionToStr(action))
|
||||
if ipsetName != "" {
|
||||
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
|
||||
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
||||
@ -137,7 +130,7 @@ func (m *aclManager) AddPeerFiltering(
|
||||
m.ipsetStore.addIpList(ipsetName, ipList)
|
||||
}
|
||||
|
||||
ok, err := m.iptablesClient.Exists("filter", chain, specs...)
|
||||
ok, err := m.iptablesClient.Exists(tableFilter, chain, specs...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check rule: %w", err)
|
||||
}
|
||||
@ -145,16 +138,22 @@ func (m *aclManager) AddPeerFiltering(
|
||||
return nil, fmt.Errorf("rule already exists")
|
||||
}
|
||||
|
||||
if err := m.iptablesClient.Append("filter", chain, specs...); err != nil {
|
||||
if err := m.iptablesClient.Append(tableFilter, chain, specs...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := m.iptablesClient.Append(tableMangle, chainRTPRE, mangleSpecs...); err != nil {
|
||||
log.Errorf("failed to add mangle rule: %v", err)
|
||||
mangleSpecs = nil
|
||||
}
|
||||
|
||||
rule := &Rule{
|
||||
ruleID: uuid.New().String(),
|
||||
specs: specs,
|
||||
ipsetName: ipsetName,
|
||||
ip: ip.String(),
|
||||
chain: chain,
|
||||
ruleID: uuid.New().String(),
|
||||
specs: specs,
|
||||
mangleSpecs: mangleSpecs,
|
||||
ipsetName: ipsetName,
|
||||
ip: ip.String(),
|
||||
chain: chain,
|
||||
}
|
||||
|
||||
m.updateState()
|
||||
@ -197,6 +196,12 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err)
|
||||
}
|
||||
|
||||
if r.mangleSpecs != nil {
|
||||
if err := m.iptablesClient.Delete(tableMangle, chainRTPRE, r.mangleSpecs...); err != nil {
|
||||
log.Errorf("failed to delete mangle rule: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
m.updateState()
|
||||
|
||||
return nil
|
||||
@ -214,28 +219,7 @@ func (m *aclManager) Reset() error {
|
||||
|
||||
// todo write less destructive cleanup mechanism
|
||||
func (m *aclManager) cleanChains() error {
|
||||
ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules)
|
||||
if err != nil {
|
||||
log.Debugf("failed to list chains: %s", err)
|
||||
return err
|
||||
}
|
||||
if ok {
|
||||
rules := m.entries["OUTPUT"]
|
||||
for _, rule := range rules {
|
||||
err := m.iptablesClient.DeleteIfExists(tableName, "OUTPUT", rule...)
|
||||
if err != nil {
|
||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||
}
|
||||
}
|
||||
|
||||
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameOutputRules)
|
||||
if err != nil {
|
||||
log.Debugf("failed to clear and delete %s chain: %s", chainNameOutputRules, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ok, err = m.iptablesClient.ChainExists(tableName, chainNameInputRules)
|
||||
ok, err := m.iptablesClient.ChainExists(tableName, chainNameInputRules)
|
||||
if err != nil {
|
||||
log.Debugf("failed to list chains: %s", err)
|
||||
return err
|
||||
@ -295,12 +279,6 @@ func (m *aclManager) createDefaultChains() error {
|
||||
return err
|
||||
}
|
||||
|
||||
// chain netbird-acl-output-rules
|
||||
if err := m.iptablesClient.NewChain(tableName, chainNameOutputRules); err != nil {
|
||||
log.Debugf("failed to create '%s' chain: %s", chainNameOutputRules, err)
|
||||
return err
|
||||
}
|
||||
|
||||
for chainName, rules := range m.entries {
|
||||
for _, rule := range rules {
|
||||
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
|
||||
@ -329,8 +307,6 @@ func (m *aclManager) createDefaultChains() error {
|
||||
|
||||
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
|
||||
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
||||
|
||||
// The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule.
|
||||
func (m *aclManager) seedInitialEntries() {
|
||||
established := getConntrackEstablished()
|
||||
|
||||
@ -346,17 +322,10 @@ func (m *aclManager) seedInitialEntries() {
|
||||
func (m *aclManager) seedInitialOptionalEntries() {
|
||||
m.optionalEntries["FORWARD"] = []entry{
|
||||
{
|
||||
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", chainNameInputRules},
|
||||
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", "ACCEPT"},
|
||||
position: 2,
|
||||
},
|
||||
}
|
||||
|
||||
m.optionalEntries["PREROUTING"] = []entry{
|
||||
{
|
||||
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected)},
|
||||
position: 1,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *aclManager) appendToEntries(chainName string, spec []string) {
|
||||
@ -390,42 +359,26 @@ func (m *aclManager) updateState() {
|
||||
}
|
||||
|
||||
// filterRuleSpecs returns the specs of a filtering rule
|
||||
func filterRuleSpecs(
|
||||
ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string,
|
||||
) (specs []string) {
|
||||
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
|
||||
matchByIP := true
|
||||
// don't use IP matching if IP is ip 0.0.0.0
|
||||
if ip.String() == "0.0.0.0" {
|
||||
matchByIP = false
|
||||
}
|
||||
switch direction {
|
||||
case firewall.RuleDirectionIN:
|
||||
if matchByIP {
|
||||
if ipsetName != "" {
|
||||
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
|
||||
} else {
|
||||
specs = append(specs, "-s", ip.String())
|
||||
}
|
||||
}
|
||||
case firewall.RuleDirectionOUT:
|
||||
if matchByIP {
|
||||
if ipsetName != "" {
|
||||
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
|
||||
} else {
|
||||
specs = append(specs, "-d", ip.String())
|
||||
}
|
||||
|
||||
if matchByIP {
|
||||
if ipsetName != "" {
|
||||
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
|
||||
} else {
|
||||
specs = append(specs, "-s", ip.String())
|
||||
}
|
||||
}
|
||||
if protocol != "all" {
|
||||
specs = append(specs, "-p", protocol)
|
||||
}
|
||||
if sPort != "" {
|
||||
specs = append(specs, "--sport", sPort)
|
||||
}
|
||||
if dPort != "" {
|
||||
specs = append(specs, "--dport", dPort)
|
||||
}
|
||||
return append(specs, "-j", actionToStr(action))
|
||||
specs = append(specs, applyPort("--sport", sPort)...)
|
||||
specs = append(specs, applyPort("--dport", dPort)...)
|
||||
return specs
|
||||
}
|
||||
|
||||
func actionToStr(action firewall.Action) string {
|
||||
@ -435,15 +388,15 @@ func actionToStr(action firewall.Action) string {
|
||||
return "DROP"
|
||||
}
|
||||
|
||||
func transformIPsetName(ipsetName string, sPort, dPort string) string {
|
||||
func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port) string {
|
||||
switch {
|
||||
case ipsetName == "":
|
||||
return ""
|
||||
case sPort != "" && dPort != "":
|
||||
case sPort != nil && dPort != nil:
|
||||
return ipsetName + "-sport-dport"
|
||||
case sPort != "":
|
||||
case sPort != nil:
|
||||
return ipsetName + "-sport"
|
||||
case dPort != "":
|
||||
case dPort != nil:
|
||||
return ipsetName + "-dport"
|
||||
default:
|
||||
return ipsetName
|
||||
|
@ -100,15 +100,14 @@ func (m *Manager) AddPeerFiltering(
|
||||
protocol firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
direction firewall.RuleDirection,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
_ string,
|
||||
) ([]firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
|
||||
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, action, ipsetName)
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
@ -201,7 +200,6 @@ func (m *Manager) AllowNetbird() error {
|
||||
"all",
|
||||
nil,
|
||||
nil,
|
||||
firewall.RuleDirectionIN,
|
||||
firewall.ActionAccept,
|
||||
"",
|
||||
"",
|
||||
@ -215,6 +213,19 @@ func (m *Manager) AllowNetbird() error {
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
func (m *Manager) Flush() error { return nil }
|
||||
|
||||
// SetLogLevel sets the log level for the firewall manager
|
||||
func (m *Manager) SetLogLevel(log.Level) {
|
||||
// not supported
|
||||
}
|
||||
|
||||
func (m *Manager) EnableRouting() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) DisableRouting() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func getConntrackEstablished() []string {
|
||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||
}
|
||||
|
@ -68,27 +68,14 @@ func TestIptablesManager(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
var rule1 []fw.Rule
|
||||
t.Run("add first rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.2")
|
||||
port := &fw.Port{Values: []int{8080}}
|
||||
rule1, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
for _, r := range rule1 {
|
||||
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
var rule2 []fw.Rule
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
Values: []int{8043: 8046},
|
||||
IsRange: true,
|
||||
Values: []uint16{8043, 8046},
|
||||
}
|
||||
rule2, err = manager.AddPeerFiltering(
|
||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
||||
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
for _, r := range rule2 {
|
||||
@ -97,15 +84,6 @@ func TestIptablesManager(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete first rule", func(t *testing.T) {
|
||||
for _, r := range rule1 {
|
||||
err := manager.DeletePeerRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete second rule", func(t *testing.T) {
|
||||
for _, r := range rule2 {
|
||||
err := manager.DeletePeerRule(r)
|
||||
@ -118,8 +96,8 @@ func TestIptablesManager(t *testing.T) {
|
||||
t.Run("reset check", func(t *testing.T) {
|
||||
// add second rule
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{Values: []int{5353}}
|
||||
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
|
||||
port := &fw.Port{Values: []uint16{5353}}
|
||||
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.ActionAccept, "", "accept Fake DNS traffic")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Reset(nil)
|
||||
@ -135,9 +113,6 @@ func TestIptablesManager(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestIptablesManagerIPSet(t *testing.T) {
|
||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
mock := &iFaceMock{
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
@ -167,33 +142,13 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
var rule1 []fw.Rule
|
||||
t.Run("add first rule with set", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.2")
|
||||
port := &fw.Port{Values: []int{8080}}
|
||||
rule1, err = manager.AddPeerFiltering(
|
||||
ip, "tcp", nil, port, fw.RuleDirectionOUT,
|
||||
fw.ActionAccept, "default", "accept HTTP traffic",
|
||||
)
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
for _, r := range rule1 {
|
||||
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
|
||||
require.Equal(t, r.(*Rule).ipsetName, "default-dport", "ipset name must be set")
|
||||
require.Equal(t, r.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
|
||||
}
|
||||
})
|
||||
|
||||
var rule2 []fw.Rule
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
Values: []int{443},
|
||||
Values: []uint16{443},
|
||||
}
|
||||
rule2, err = manager.AddPeerFiltering(
|
||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
|
||||
"default", "accept HTTPS traffic from ports range",
|
||||
)
|
||||
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "default", "accept HTTPS traffic from ports range")
|
||||
for _, r := range rule2 {
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||
@ -201,15 +156,6 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete first rule", func(t *testing.T) {
|
||||
for _, r := range rule1 {
|
||||
err := manager.DeletePeerRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
|
||||
require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete second rule", func(t *testing.T) {
|
||||
for _, r := range rule2 {
|
||||
err := manager.DeletePeerRule(r)
|
||||
@ -269,12 +215,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.100")
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []int{1000 + i}}
|
||||
if i%2 == 0 {
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
} else {
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
}
|
||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
}
|
||||
|
@ -135,7 +135,16 @@ func (r *router) AddRouteFiltering(
|
||||
}
|
||||
|
||||
rule := genRouteFilteringRuleSpec(params)
|
||||
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
|
||||
// Insert DROP rules at the beginning, append ACCEPT rules at the end
|
||||
var err error
|
||||
if action == firewall.ActionDrop {
|
||||
// after the established rule
|
||||
err = r.iptablesClient.Insert(tableFilter, chainRTFWD, 2, rule...)
|
||||
} else {
|
||||
err = r.iptablesClient.Append(tableFilter, chainRTFWD, rule...)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("add route rule: %v", err)
|
||||
}
|
||||
|
||||
@ -590,10 +599,10 @@ func applyPort(flag string, port *firewall.Port) []string {
|
||||
if len(port.Values) > 1 {
|
||||
portList := make([]string, len(port.Values))
|
||||
for i, p := range port.Values {
|
||||
portList[i] = strconv.Itoa(p)
|
||||
portList[i] = strconv.Itoa(int(p))
|
||||
}
|
||||
return []string{"-m", "multiport", flag, strings.Join(portList, ",")}
|
||||
}
|
||||
|
||||
return []string{flag, strconv.Itoa(port.Values[0])}
|
||||
return []string{flag, strconv.Itoa(int(port.Values[0]))}
|
||||
}
|
||||
|
@ -239,7 +239,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
sPort: nil,
|
||||
dPort: &firewall.Port{Values: []int{80}},
|
||||
dPort: &firewall.Port{Values: []uint16{80}},
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
@ -252,7 +252,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
||||
proto: firewall.ProtocolUDP,
|
||||
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
|
||||
sPort: &firewall.Port{Values: []uint16{1024, 2048}, IsRange: true},
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionDrop,
|
||||
@ -285,7 +285,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
|
||||
destination: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
|
||||
sPort: &firewall.Port{Values: []uint16{80, 443, 8080}},
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionAccept,
|
||||
@ -297,7 +297,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
proto: firewall.ProtocolUDP,
|
||||
sPort: nil,
|
||||
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
|
||||
dPort: &firewall.Port{Values: []uint16{5000, 5100}, IsRange: true},
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionDrop,
|
||||
expectSet: false,
|
||||
@ -307,8 +307,8 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
|
||||
destination: netip.MustParsePrefix("172.16.0.0/16"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
|
||||
dPort: &firewall.Port{Values: []int{22}},
|
||||
sPort: &firewall.Port{Values: []uint16{1024, 65535}, IsRange: true},
|
||||
dPort: &firewall.Port{Values: []uint16{22}},
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
|
@ -5,9 +5,10 @@ type Rule struct {
|
||||
ruleID string
|
||||
ipsetName string
|
||||
|
||||
specs []string
|
||||
ip string
|
||||
chain string
|
||||
specs []string
|
||||
mangleSpecs []string
|
||||
ip string
|
||||
chain string
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
|
@ -69,7 +69,6 @@ type Manager interface {
|
||||
proto Protocol,
|
||||
sPort *Port,
|
||||
dPort *Port,
|
||||
direction RuleDirection,
|
||||
action Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
@ -100,6 +99,12 @@ type Manager interface {
|
||||
|
||||
// Flush the changes to firewall controller
|
||||
Flush() error
|
||||
|
||||
SetLogLevel(log.Level)
|
||||
|
||||
EnableRouting() error
|
||||
|
||||
DisableRouting() error
|
||||
}
|
||||
|
||||
func GenKey(format string, pair RouterPair) string {
|
||||
|
@ -30,7 +30,7 @@ type Port struct {
|
||||
IsRange bool
|
||||
|
||||
// Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports
|
||||
Values []int
|
||||
Values []uint16
|
||||
}
|
||||
|
||||
// String interface implementation
|
||||
@ -40,7 +40,11 @@ func (p *Port) String() string {
|
||||
if ports != "" {
|
||||
ports += ","
|
||||
}
|
||||
ports += strconv.Itoa(port)
|
||||
ports += strconv.Itoa(int(port))
|
||||
}
|
||||
if p.IsRange {
|
||||
ports = "range:" + ports
|
||||
}
|
||||
|
||||
return ports
|
||||
}
|
||||
|
@ -2,9 +2,9 @@ package nftables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@ -22,8 +22,7 @@ import (
|
||||
const (
|
||||
|
||||
// rules chains contains the effective ACL rules
|
||||
chainNameInputRules = "netbird-acl-input-rules"
|
||||
chainNameOutputRules = "netbird-acl-output-rules"
|
||||
chainNameInputRules = "netbird-acl-input-rules"
|
||||
|
||||
// filter chains contains the rules that jump to the rules chains
|
||||
chainNameInputFilter = "netbird-acl-input-filter"
|
||||
@ -45,9 +44,9 @@ type AclManager struct {
|
||||
wgIface iFaceMapper
|
||||
routingFwChainName string
|
||||
|
||||
workTable *nftables.Table
|
||||
chainInputRules *nftables.Chain
|
||||
chainOutputRules *nftables.Chain
|
||||
workTable *nftables.Table
|
||||
chainInputRules *nftables.Chain
|
||||
chainPrerouting *nftables.Chain
|
||||
|
||||
ipsetStore *ipsetStore
|
||||
rules map[string]*Rule
|
||||
@ -89,7 +88,6 @@ func (m *AclManager) AddPeerFiltering(
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
direction firewall.RuleDirection,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
@ -104,7 +102,7 @@ func (m *AclManager) AddPeerFiltering(
|
||||
}
|
||||
|
||||
newRules := make([]firewall.Rule, 0, 2)
|
||||
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, direction, action, ipset, comment)
|
||||
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset, comment)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -121,23 +119,32 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
}
|
||||
|
||||
if r.nftSet == nil {
|
||||
err := m.rConn.DelRule(r.nftRule)
|
||||
if err != nil {
|
||||
if err := m.rConn.DelRule(r.nftRule); err != nil {
|
||||
log.Errorf("failed to delete rule: %v", err)
|
||||
}
|
||||
if r.mangleRule != nil {
|
||||
if err := m.rConn.DelRule(r.mangleRule); err != nil {
|
||||
log.Errorf("failed to delete mangle rule: %v", err)
|
||||
}
|
||||
}
|
||||
delete(m.rules, r.GetRuleID())
|
||||
return m.rConn.Flush()
|
||||
}
|
||||
|
||||
ips, ok := m.ipsetStore.ips(r.nftSet.Name)
|
||||
if !ok {
|
||||
err := m.rConn.DelRule(r.nftRule)
|
||||
if err != nil {
|
||||
if err := m.rConn.DelRule(r.nftRule); err != nil {
|
||||
log.Errorf("failed to delete rule: %v", err)
|
||||
}
|
||||
if r.mangleRule != nil {
|
||||
if err := m.rConn.DelRule(r.mangleRule); err != nil {
|
||||
log.Errorf("failed to delete mangle rule: %v", err)
|
||||
}
|
||||
}
|
||||
delete(m.rules, r.GetRuleID())
|
||||
return m.rConn.Flush()
|
||||
}
|
||||
|
||||
if _, ok := ips[r.ip.String()]; ok {
|
||||
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: r.ip.To4()}})
|
||||
if err != nil {
|
||||
@ -156,12 +163,16 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := m.rConn.DelRule(r.nftRule)
|
||||
if err != nil {
|
||||
if err := m.rConn.DelRule(r.nftRule); err != nil {
|
||||
log.Errorf("failed to delete rule: %v", err)
|
||||
}
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
if r.mangleRule != nil {
|
||||
if err := m.rConn.DelRule(r.mangleRule); err != nil {
|
||||
log.Errorf("failed to delete mangle rule: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -214,38 +225,6 @@ func (m *AclManager) createDefaultAllowRules() error {
|
||||
Exprs: expIn,
|
||||
})
|
||||
|
||||
expOut := []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 16,
|
||||
Len: 4,
|
||||
},
|
||||
// mask
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: []byte{0, 0, 0, 0},
|
||||
Xor: []byte{0, 0, 0, 0},
|
||||
},
|
||||
// net address
|
||||
&expr.Cmp{
|
||||
Register: 1,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
}
|
||||
|
||||
_ = m.rConn.InsertRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: m.chainOutputRules,
|
||||
Position: 0,
|
||||
Exprs: expOut,
|
||||
})
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
@ -260,25 +239,33 @@ func (m *AclManager) Flush() error {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := m.refreshRuleHandles(m.chainInputRules); err != nil {
|
||||
if err := m.refreshRuleHandles(m.chainInputRules, false); err != nil {
|
||||
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
|
||||
}
|
||||
|
||||
if err := m.refreshRuleHandles(m.chainOutputRules); err != nil {
|
||||
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
|
||||
if err := m.refreshRuleHandles(m.chainPrerouting, true); err != nil {
|
||||
log.Errorf("failed to refresh rule handles prerouting chain: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, ipset *nftables.Set, comment string) (*Rule, error) {
|
||||
ruleId := generatePeerRuleId(ip, sPort, dPort, direction, action, ipset)
|
||||
func (m *AclManager) addIOFiltering(
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipset *nftables.Set,
|
||||
comment string,
|
||||
) (*Rule, error) {
|
||||
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
|
||||
if r, ok := m.rules[ruleId]; ok {
|
||||
return &Rule{
|
||||
r.nftRule,
|
||||
r.nftSet,
|
||||
r.ruleID,
|
||||
ip,
|
||||
nftRule: r.nftRule,
|
||||
mangleRule: r.mangleRule,
|
||||
nftSet: r.nftSet,
|
||||
ruleID: r.ruleID,
|
||||
ip: ip,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -310,9 +297,6 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
|
||||
if !bytes.HasPrefix(anyIP, rawIP) {
|
||||
// source address position
|
||||
addrOffset := uint32(12)
|
||||
if direction == firewall.RuleDirectionOUT {
|
||||
addrOffset += 4 // is ipv4 address length
|
||||
}
|
||||
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
@ -342,73 +326,100 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
|
||||
}
|
||||
}
|
||||
|
||||
if sPort != nil && len(sPort.Values) != 0 {
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 0,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: encodePort(*sPort),
|
||||
},
|
||||
)
|
||||
}
|
||||
expressions = append(expressions, applyPort(sPort, true)...)
|
||||
expressions = append(expressions, applyPort(dPort, false)...)
|
||||
|
||||
if dPort != nil && len(dPort.Values) != 0 {
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: encodePort(*dPort),
|
||||
},
|
||||
)
|
||||
}
|
||||
mainExpressions := slices.Clone(expressions)
|
||||
|
||||
switch action {
|
||||
case firewall.ActionAccept:
|
||||
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictAccept})
|
||||
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictAccept})
|
||||
case firewall.ActionDrop:
|
||||
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
||||
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
||||
}
|
||||
|
||||
userData := []byte(strings.Join([]string{ruleId, comment}, " "))
|
||||
|
||||
var chain *nftables.Chain
|
||||
if direction == firewall.RuleDirectionIN {
|
||||
chain = m.chainInputRules
|
||||
} else {
|
||||
chain = m.chainOutputRules
|
||||
}
|
||||
chain := m.chainInputRules
|
||||
nftRule := m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: chain,
|
||||
Exprs: expressions,
|
||||
Exprs: mainExpressions,
|
||||
UserData: userData,
|
||||
})
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
rule := &Rule{
|
||||
nftRule: nftRule,
|
||||
nftSet: ipset,
|
||||
ruleID: ruleId,
|
||||
ip: ip,
|
||||
nftRule: nftRule,
|
||||
mangleRule: m.createPreroutingRule(expressions, userData),
|
||||
nftSet: ipset,
|
||||
ruleID: ruleId,
|
||||
ip: ip,
|
||||
}
|
||||
m.rules[ruleId] = rule
|
||||
if ipset != nil {
|
||||
m.ipsetStore.AddReferenceToIpset(ipset.Name)
|
||||
}
|
||||
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule {
|
||||
if m.chainPrerouting == nil {
|
||||
log.Warn("prerouting chain is not created")
|
||||
return nil
|
||||
}
|
||||
|
||||
preroutingExprs := slices.Clone(expressions)
|
||||
|
||||
// interface
|
||||
preroutingExprs = append([]expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
}, preroutingExprs...)
|
||||
|
||||
// local destination and mark
|
||||
preroutingExprs = append(preroutingExprs,
|
||||
&expr.Fib{
|
||||
Register: 1,
|
||||
ResultADDRTYPE: true,
|
||||
FlagDADDR: true,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
|
||||
},
|
||||
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
SourceRegister: true,
|
||||
},
|
||||
)
|
||||
|
||||
return m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: m.chainPrerouting,
|
||||
Exprs: preroutingExprs,
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *AclManager) createDefaultChains() (err error) {
|
||||
// chainNameInputRules
|
||||
chain := m.createChain(chainNameInputRules)
|
||||
@ -419,15 +430,6 @@ func (m *AclManager) createDefaultChains() (err error) {
|
||||
}
|
||||
m.chainInputRules = chain
|
||||
|
||||
// chainNameOutputRules
|
||||
chain = m.createChain(chainNameOutputRules)
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to create chain (%s): %s", chainNameOutputRules, err)
|
||||
return err
|
||||
}
|
||||
m.chainOutputRules = chain
|
||||
|
||||
// netbird-acl-input-filter
|
||||
// type filter hook input priority filter; policy accept;
|
||||
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
|
||||
@ -461,7 +463,7 @@ func (m *AclManager) createDefaultChains() (err error) {
|
||||
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
|
||||
// netbird peer IP.
|
||||
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
|
||||
preroutingChain := m.rConn.AddChain(&nftables.Chain{
|
||||
m.chainPrerouting = m.rConn.AddChain(&nftables.Chain{
|
||||
Name: chainNamePrerouting,
|
||||
Table: m.workTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
@ -469,8 +471,6 @@ func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
})
|
||||
|
||||
m.addPreroutingRule(preroutingChain)
|
||||
|
||||
m.addFwmarkToForward(chainFwFilter)
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
@ -480,43 +480,6 @@ func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) {
|
||||
m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: preroutingChain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Fib{
|
||||
Register: 1,
|
||||
ResultADDRTYPE: true,
|
||||
FlagDADDR: true,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
SourceRegister: true,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
|
||||
m.rConn.InsertRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
@ -532,8 +495,7 @@ func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictJump,
|
||||
Chain: m.chainInputRules.Name,
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
})
|
||||
@ -680,6 +642,7 @@ func (m *AclManager) flushWithBackoff() (err error) {
|
||||
for i := 0; ; i++ {
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to flush nftables: %v", err)
|
||||
if !strings.Contains(err.Error(), "busy") {
|
||||
return
|
||||
}
|
||||
@ -696,7 +659,7 @@ func (m *AclManager) flushWithBackoff() (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error {
|
||||
func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) error {
|
||||
if m.workTable == nil || chain == nil {
|
||||
return nil
|
||||
}
|
||||
@ -713,22 +676,19 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error {
|
||||
split := bytes.Split(rule.UserData, []byte(" "))
|
||||
r, ok := m.rules[string(split[0])]
|
||||
if ok {
|
||||
*r.nftRule = *rule
|
||||
if mangle {
|
||||
*r.mangleRule = *rule
|
||||
} else {
|
||||
*r.nftRule = *rule
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func generatePeerRuleId(
|
||||
ip net.IP,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
direction firewall.RuleDirection,
|
||||
action firewall.Action,
|
||||
ipset *nftables.Set,
|
||||
) string {
|
||||
rulesetID := ":" + strconv.Itoa(int(direction)) + ":"
|
||||
func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
|
||||
rulesetID := ":"
|
||||
if sPort != nil {
|
||||
rulesetID += sPort.String()
|
||||
}
|
||||
@ -744,12 +704,6 @@ func generatePeerRuleId(
|
||||
return "set:" + ipset.Name + rulesetID
|
||||
}
|
||||
|
||||
func encodePort(port firewall.Port) []byte {
|
||||
bs := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(bs, uint16(port.Values[0]))
|
||||
return bs
|
||||
}
|
||||
|
||||
func ifname(n string) []byte {
|
||||
b := make([]byte, 16)
|
||||
copy(b, n+"\x00")
|
||||
|
@ -117,7 +117,6 @@ func (m *Manager) AddPeerFiltering(
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
direction firewall.RuleDirection,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
@ -130,10 +129,17 @@ func (m *Manager) AddPeerFiltering(
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
||||
}
|
||||
|
||||
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment)
|
||||
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, action, ipsetName, comment)
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@ -312,6 +318,19 @@ func (m *Manager) cleanupNetbirdTables() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetLogLevel sets the log level for the firewall manager
|
||||
func (m *Manager) SetLogLevel(log.Level) {
|
||||
// not supported
|
||||
}
|
||||
|
||||
func (m *Manager) EnableRouting() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) DisableRouting() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush rule/chain/set operations from the buffer
|
||||
//
|
||||
// Method also get all rules after flush and refreshes handle values in the rulesets
|
||||
|
@ -74,16 +74,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
|
||||
testClient := &nftables.Conn{}
|
||||
|
||||
rule, err := manager.AddPeerFiltering(
|
||||
ip,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []int{53}},
|
||||
fw.RuleDirectionIN,
|
||||
fw.ActionDrop,
|
||||
"",
|
||||
"",
|
||||
)
|
||||
rule, err := manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "", "")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Flush()
|
||||
@ -116,7 +107,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
}
|
||||
require.ElementsMatch(t, rules[0].Exprs, expectedExprs1, "expected the same expressions")
|
||||
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
|
||||
|
||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||
add := ipToAdd.Unmap()
|
||||
@ -209,12 +200,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.100")
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []int{1000 + i}}
|
||||
if i%2 == 0 {
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
} else {
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
}
|
||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
if i%100 == 0 {
|
||||
@ -296,16 +283,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
})
|
||||
|
||||
ip := net.ParseIP("100.96.0.1")
|
||||
_, err = manager.AddPeerFiltering(
|
||||
ip,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []int{80}},
|
||||
fw.RuleDirectionIN,
|
||||
fw.ActionAccept,
|
||||
"",
|
||||
"test rule",
|
||||
)
|
||||
_, err = manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "", "test rule")
|
||||
require.NoError(t, err, "failed to add peer filtering rule")
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
@ -313,7 +291,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
netip.MustParsePrefix("10.1.0.0/24"),
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []int{443}},
|
||||
&fw.Port{Values: []uint16{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err, "failed to add route filtering rule")
|
||||
@ -329,3 +307,18 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
stdout, stderr = runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
}
|
||||
|
||||
func compareExprsIgnoringCounters(t *testing.T, got, want []expr.Any) {
|
||||
t.Helper()
|
||||
require.Equal(t, len(got), len(want), "expression count mismatch")
|
||||
|
||||
for i := range got {
|
||||
if _, isCounter := got[i].(*expr.Counter); isCounter {
|
||||
_, wantIsCounter := want[i].(*expr.Counter)
|
||||
require.True(t, wantIsCounter, "expected Counter at index %d", i)
|
||||
continue
|
||||
}
|
||||
|
||||
require.Equal(t, got[i], want[i], "expression mismatch at index %d", i)
|
||||
}
|
||||
}
|
||||
|
@ -233,7 +233,13 @@ func (r *router) AddRouteFiltering(
|
||||
UserData: []byte(ruleKey),
|
||||
}
|
||||
|
||||
rule = r.conn.AddRule(rule)
|
||||
// Insert DROP rules at the beginning, append ACCEPT rules at the end
|
||||
if action == firewall.ActionDrop {
|
||||
// TODO: Insert after the established rule
|
||||
rule = r.conn.InsertRule(rule)
|
||||
} else {
|
||||
rule = r.conn.AddRule(rule)
|
||||
}
|
||||
|
||||
log.Tracef("Adding route rule %s", spew.Sdump(rule))
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
@ -956,12 +962,12 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpGte,
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[0])),
|
||||
Data: binaryutil.BigEndian.PutUint16(port.Values[0]),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpLte,
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[1])),
|
||||
Data: binaryutil.BigEndian.PutUint16(port.Values[1]),
|
||||
},
|
||||
)
|
||||
} else {
|
||||
@ -980,7 +986,7 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
|
||||
exprs = append(exprs, &expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(uint16(p)),
|
||||
Data: binaryutil.BigEndian.PutUint16(p),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -222,7 +222,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
sPort: nil,
|
||||
dPort: &firewall.Port{Values: []int{80}},
|
||||
dPort: &firewall.Port{Values: []uint16{80}},
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
@ -235,7 +235,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
||||
proto: firewall.ProtocolUDP,
|
||||
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
|
||||
sPort: &firewall.Port{Values: []uint16{1024, 2048}, IsRange: true},
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionDrop,
|
||||
@ -268,7 +268,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
|
||||
destination: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
|
||||
sPort: &firewall.Port{Values: []uint16{80, 443, 8080}},
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionAccept,
|
||||
@ -280,7 +280,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
proto: firewall.ProtocolUDP,
|
||||
sPort: nil,
|
||||
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
|
||||
dPort: &firewall.Port{Values: []uint16{5000, 5100}, IsRange: true},
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionDrop,
|
||||
expectSet: false,
|
||||
@ -290,8 +290,8 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
|
||||
destination: netip.MustParsePrefix("172.16.0.0/16"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
|
||||
dPort: &firewall.Port{Values: []int{22}},
|
||||
sPort: &firewall.Port{Values: []uint16{1024, 65535}, IsRange: true},
|
||||
dPort: &firewall.Port{Values: []uint16{22}},
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
|
@ -8,10 +8,11 @@ import (
|
||||
|
||||
// Rule to handle management of rules
|
||||
type Rule struct {
|
||||
nftRule *nftables.Rule
|
||||
nftSet *nftables.Set
|
||||
ruleID string
|
||||
ip net.IP
|
||||
nftRule *nftables.Rule
|
||||
mangleRule *nftables.Rule
|
||||
nftSet *nftables.Set
|
||||
ruleID string
|
||||
ip net.IP
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
|
@ -3,6 +3,11 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
@ -17,17 +22,29 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
|
||||
}
|
||||
|
||||
if m.icmpTracker != nil {
|
||||
m.icmpTracker.Close()
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
|
||||
}
|
||||
|
||||
if m.tcpTracker != nil {
|
||||
m.tcpTracker.Close()
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
||||
}
|
||||
|
||||
if m.forwarder != nil {
|
||||
m.forwarder.Stop()
|
||||
}
|
||||
|
||||
if m.logger != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if err := m.logger.Stop(ctx); err != nil {
|
||||
log.Errorf("failed to shutdown logger: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if m.nativeFirewall != nil {
|
||||
|
@ -1,9 +1,11 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@ -29,17 +31,29 @@ func (m *Manager) Reset(*statemanager.Manager) error {
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
|
||||
}
|
||||
|
||||
if m.icmpTracker != nil {
|
||||
m.icmpTracker.Close()
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
|
||||
}
|
||||
|
||||
if m.tcpTracker != nil {
|
||||
m.tcpTracker.Close()
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
||||
}
|
||||
|
||||
if m.forwarder != nil {
|
||||
m.forwarder.Stop()
|
||||
}
|
||||
|
||||
if m.logger != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if err := m.logger.Stop(ctx); err != nil {
|
||||
log.Errorf("failed to shutdown logger: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if !isWindowsFirewallReachable() {
|
||||
|
16
client/firewall/uspfilter/common/iface.go
Normal file
16
client/firewall/uspfilter/common/iface.go
Normal file
@ -0,0 +1,16 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
SetFilter(device.PacketFilter) error
|
||||
Address() iface.WGAddress
|
||||
GetWGDevice() *wgdevice.Device
|
||||
GetDevice() *device.FilteredDevice
|
||||
}
|
@ -10,12 +10,11 @@ import (
|
||||
|
||||
// BaseConnTrack provides common fields and locking for all connection types
|
||||
type BaseConnTrack struct {
|
||||
SourceIP net.IP
|
||||
DestIP net.IP
|
||||
SourcePort uint16
|
||||
DestPort uint16
|
||||
lastSeen atomic.Int64 // Unix nano for atomic access
|
||||
established atomic.Bool
|
||||
SourceIP net.IP
|
||||
DestIP net.IP
|
||||
SourcePort uint16
|
||||
DestPort uint16
|
||||
lastSeen atomic.Int64 // Unix nano for atomic access
|
||||
}
|
||||
|
||||
// these small methods will be inlined by the compiler
|
||||
@ -25,16 +24,6 @@ func (b *BaseConnTrack) UpdateLastSeen() {
|
||||
b.lastSeen.Store(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// IsEstablished safely checks if connection is established
|
||||
func (b *BaseConnTrack) IsEstablished() bool {
|
||||
return b.established.Load()
|
||||
}
|
||||
|
||||
// SetEstablished safely sets the established state
|
||||
func (b *BaseConnTrack) SetEstablished(state bool) {
|
||||
b.established.Store(state)
|
||||
}
|
||||
|
||||
// GetLastSeen safely gets the last seen timestamp
|
||||
func (b *BaseConnTrack) GetLastSeen() time.Time {
|
||||
return time.Unix(0, b.lastSeen.Load())
|
||||
|
@ -3,8 +3,14 @@ package conntrack
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
)
|
||||
|
||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||
|
||||
func BenchmarkIPOperations(b *testing.B) {
|
||||
b.Run("MakeIPAddr", func(b *testing.B) {
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
@ -34,37 +40,11 @@ func BenchmarkIPOperations(b *testing.B) {
|
||||
})
|
||||
|
||||
}
|
||||
func BenchmarkAtomicOperations(b *testing.B) {
|
||||
conn := &BaseConnTrack{}
|
||||
b.Run("UpdateLastSeen", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
conn.UpdateLastSeen()
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("IsEstablished", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = conn.IsEstablished()
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("SetEstablished", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
conn.SetEstablished(i%2 == 0)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("GetLastSeen", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = conn.GetLastSeen()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Memory pressure tests
|
||||
func BenchmarkMemoryPressure(b *testing.B) {
|
||||
b.Run("TCPHighLoad", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
// Generate different IPs
|
||||
@ -89,7 +69,7 @@ func BenchmarkMemoryPressure(b *testing.B) {
|
||||
})
|
||||
|
||||
b.Run("UDPHighLoad", func(b *testing.B) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout)
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
// Generate different IPs
|
||||
|
@ -6,6 +6,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket/layers"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -33,6 +35,7 @@ type ICMPConnTrack struct {
|
||||
|
||||
// ICMPTracker manages ICMP connection states
|
||||
type ICMPTracker struct {
|
||||
logger *nblog.Logger
|
||||
connections map[ICMPConnKey]*ICMPConnTrack
|
||||
timeout time.Duration
|
||||
cleanupTicker *time.Ticker
|
||||
@ -42,12 +45,13 @@ type ICMPTracker struct {
|
||||
}
|
||||
|
||||
// NewICMPTracker creates a new ICMP connection tracker
|
||||
func NewICMPTracker(timeout time.Duration) *ICMPTracker {
|
||||
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker {
|
||||
if timeout == 0 {
|
||||
timeout = DefaultICMPTimeout
|
||||
}
|
||||
|
||||
tracker := &ICMPTracker{
|
||||
logger: logger,
|
||||
connections: make(map[ICMPConnKey]*ICMPConnTrack),
|
||||
timeout: timeout,
|
||||
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
|
||||
@ -62,7 +66,6 @@ func NewICMPTracker(timeout time.Duration) *ICMPTracker {
|
||||
// TrackOutbound records an outbound ICMP Echo Request
|
||||
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
|
||||
key := makeICMPKey(srcIP, dstIP, id, seq)
|
||||
now := time.Now().UnixNano()
|
||||
|
||||
t.mutex.Lock()
|
||||
conn, exists := t.connections[key]
|
||||
@ -80,24 +83,19 @@ func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq u
|
||||
ID: id,
|
||||
Sequence: seq,
|
||||
}
|
||||
conn.lastSeen.Store(now)
|
||||
conn.established.Store(true)
|
||||
conn.UpdateLastSeen()
|
||||
t.connections[key] = conn
|
||||
|
||||
t.logger.Trace("New ICMP connection %v", key)
|
||||
}
|
||||
t.mutex.Unlock()
|
||||
|
||||
conn.lastSeen.Store(now)
|
||||
conn.UpdateLastSeen()
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
||||
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
|
||||
switch icmpType {
|
||||
case uint8(layers.ICMPv4TypeDestinationUnreachable),
|
||||
uint8(layers.ICMPv4TypeTimeExceeded):
|
||||
return true
|
||||
case uint8(layers.ICMPv4TypeEchoReply):
|
||||
// continue processing
|
||||
default:
|
||||
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
|
||||
return false
|
||||
}
|
||||
|
||||
@ -115,8 +113,7 @@ func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq
|
||||
return false
|
||||
}
|
||||
|
||||
return conn.IsEstablished() &&
|
||||
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
||||
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
||||
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
||||
conn.ID == id &&
|
||||
conn.Sequence == seq
|
||||
@ -141,6 +138,8 @@ func (t *ICMPTracker) cleanup() {
|
||||
t.ipPool.Put(conn.SourceIP)
|
||||
t.ipPool.Put(conn.DestIP)
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Debug("Removed ICMP connection %v (timeout)", key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -7,7 +7,7 @@ import (
|
||||
|
||||
func BenchmarkICMPTracker(b *testing.B) {
|
||||
b.Run("TrackOutbound", func(b *testing.B) {
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout)
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
@ -20,7 +20,7 @@ func BenchmarkICMPTracker(b *testing.B) {
|
||||
})
|
||||
|
||||
b.Run("IsValidInbound", func(b *testing.B) {
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout)
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
|
@ -5,7 +5,10 @@ package conntrack
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -61,12 +64,24 @@ type TCPConnKey struct {
|
||||
// TCPConnTrack represents a TCP connection state
|
||||
type TCPConnTrack struct {
|
||||
BaseConnTrack
|
||||
State TCPState
|
||||
State TCPState
|
||||
established atomic.Bool
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// IsEstablished safely checks if connection is established
|
||||
func (t *TCPConnTrack) IsEstablished() bool {
|
||||
return t.established.Load()
|
||||
}
|
||||
|
||||
// SetEstablished safely sets the established state
|
||||
func (t *TCPConnTrack) SetEstablished(state bool) {
|
||||
t.established.Store(state)
|
||||
}
|
||||
|
||||
// TCPTracker manages TCP connection states
|
||||
type TCPTracker struct {
|
||||
logger *nblog.Logger
|
||||
connections map[ConnKey]*TCPConnTrack
|
||||
mutex sync.RWMutex
|
||||
cleanupTicker *time.Ticker
|
||||
@ -76,8 +91,9 @@ type TCPTracker struct {
|
||||
}
|
||||
|
||||
// NewTCPTracker creates a new TCP connection tracker
|
||||
func NewTCPTracker(timeout time.Duration) *TCPTracker {
|
||||
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker {
|
||||
tracker := &TCPTracker{
|
||||
logger: logger,
|
||||
connections: make(map[ConnKey]*TCPConnTrack),
|
||||
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
||||
done: make(chan struct{}),
|
||||
@ -93,7 +109,6 @@ func NewTCPTracker(timeout time.Duration) *TCPTracker {
|
||||
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
|
||||
// Create key before lock
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
now := time.Now().UnixNano()
|
||||
|
||||
t.mutex.Lock()
|
||||
conn, exists := t.connections[key]
|
||||
@ -113,9 +128,11 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
|
||||
},
|
||||
State: TCPStateNew,
|
||||
}
|
||||
conn.lastSeen.Store(now)
|
||||
conn.UpdateLastSeen()
|
||||
conn.established.Store(false)
|
||||
t.connections[key] = conn
|
||||
|
||||
t.logger.Trace("New TCP connection: %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort)
|
||||
}
|
||||
t.mutex.Unlock()
|
||||
|
||||
@ -123,7 +140,7 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
|
||||
conn.Lock()
|
||||
t.updateState(conn, flags, true)
|
||||
conn.Unlock()
|
||||
conn.lastSeen.Store(now)
|
||||
conn.UpdateLastSeen()
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
||||
@ -171,6 +188,9 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
|
||||
if flags&TCPRst != 0 {
|
||||
conn.State = TCPStateClosed
|
||||
conn.SetEstablished(false)
|
||||
|
||||
t.logger.Trace("TCP connection reset: %s:%d -> %s:%d",
|
||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||
return
|
||||
}
|
||||
|
||||
@ -227,6 +247,9 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
|
||||
if flags&TCPAck != 0 {
|
||||
conn.State = TCPStateTimeWait
|
||||
// Keep established = false from previous state
|
||||
|
||||
t.logger.Trace("TCP connection closed (simultaneous) - %s:%d -> %s:%d",
|
||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||
}
|
||||
|
||||
case TCPStateCloseWait:
|
||||
@ -237,11 +260,17 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
|
||||
case TCPStateLastAck:
|
||||
if flags&TCPAck != 0 {
|
||||
conn.State = TCPStateClosed
|
||||
|
||||
t.logger.Trace("TCP connection gracefully closed: %s:%d -> %s:%d",
|
||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||
}
|
||||
|
||||
case TCPStateTimeWait:
|
||||
// Stay in TIME-WAIT for 2MSL before transitioning to closed
|
||||
// This is handled by the cleanup routine
|
||||
|
||||
t.logger.Trace("TCP connection completed - %s:%d -> %s:%d",
|
||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||
}
|
||||
}
|
||||
|
||||
@ -318,6 +347,8 @@ func (t *TCPTracker) cleanup() {
|
||||
t.ipPool.Put(conn.SourceIP)
|
||||
t.ipPool.Put(conn.DestIP)
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Trace("Cleaned up TCP connection: %s:%d -> %s:%d", conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -9,7 +9,7 @@ import (
|
||||
)
|
||||
|
||||
func TestTCPStateMachine(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("100.64.0.1")
|
||||
@ -154,7 +154,7 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
tracker = NewTCPTracker(DefaultTCPTimeout)
|
||||
tracker = NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
tt.test(t)
|
||||
})
|
||||
}
|
||||
@ -162,7 +162,7 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRSTHandling(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("100.64.0.1")
|
||||
@ -233,7 +233,7 @@ func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP,
|
||||
|
||||
func BenchmarkTCPTracker(b *testing.B) {
|
||||
b.Run("TrackOutbound", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
@ -246,7 +246,7 @@ func BenchmarkTCPTracker(b *testing.B) {
|
||||
})
|
||||
|
||||
b.Run("IsValidInbound", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
@ -264,7 +264,7 @@ func BenchmarkTCPTracker(b *testing.B) {
|
||||
})
|
||||
|
||||
b.Run("ConcurrentAccess", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
@ -287,7 +287,7 @@ func BenchmarkTCPTracker(b *testing.B) {
|
||||
// Benchmark connection cleanup
|
||||
func BenchmarkCleanup(b *testing.B) {
|
||||
b.Run("TCPCleanup", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(100 * time.Millisecond) // Short timeout for testing
|
||||
tracker := NewTCPTracker(100*time.Millisecond, logger) // Short timeout for testing
|
||||
defer tracker.Close()
|
||||
|
||||
// Pre-populate with expired connections
|
||||
|
@ -4,6 +4,8 @@ import (
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -20,6 +22,7 @@ type UDPConnTrack struct {
|
||||
|
||||
// UDPTracker manages UDP connection states
|
||||
type UDPTracker struct {
|
||||
logger *nblog.Logger
|
||||
connections map[ConnKey]*UDPConnTrack
|
||||
timeout time.Duration
|
||||
cleanupTicker *time.Ticker
|
||||
@ -29,12 +32,13 @@ type UDPTracker struct {
|
||||
}
|
||||
|
||||
// NewUDPTracker creates a new UDP connection tracker
|
||||
func NewUDPTracker(timeout time.Duration) *UDPTracker {
|
||||
func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
|
||||
if timeout == 0 {
|
||||
timeout = DefaultUDPTimeout
|
||||
}
|
||||
|
||||
tracker := &UDPTracker{
|
||||
logger: logger,
|
||||
connections: make(map[ConnKey]*UDPConnTrack),
|
||||
timeout: timeout,
|
||||
cleanupTicker: time.NewTicker(UDPCleanupInterval),
|
||||
@ -49,7 +53,6 @@ func NewUDPTracker(timeout time.Duration) *UDPTracker {
|
||||
// TrackOutbound records an outbound UDP connection
|
||||
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
now := time.Now().UnixNano()
|
||||
|
||||
t.mutex.Lock()
|
||||
conn, exists := t.connections[key]
|
||||
@ -67,13 +70,14 @@ func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
|
||||
DestPort: dstPort,
|
||||
},
|
||||
}
|
||||
conn.lastSeen.Store(now)
|
||||
conn.established.Store(true)
|
||||
conn.UpdateLastSeen()
|
||||
t.connections[key] = conn
|
||||
|
||||
t.logger.Trace("New UDP connection: %v", conn)
|
||||
}
|
||||
t.mutex.Unlock()
|
||||
|
||||
conn.lastSeen.Store(now)
|
||||
conn.UpdateLastSeen()
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound packet matches a tracked connection
|
||||
@ -92,8 +96,7 @@ func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
|
||||
return false
|
||||
}
|
||||
|
||||
return conn.IsEstablished() &&
|
||||
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
||||
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
||||
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
||||
conn.DestPort == srcPort &&
|
||||
conn.SourcePort == dstPort
|
||||
@ -120,6 +123,8 @@ func (t *UDPTracker) cleanup() {
|
||||
t.ipPool.Put(conn.SourceIP)
|
||||
t.ipPool.Put(conn.DestIP)
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Trace("Removed UDP connection %v (timeout)", conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -29,7 +29,7 @@ func TestNewUDPTracker(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tracker := NewUDPTracker(tt.timeout)
|
||||
tracker := NewUDPTracker(tt.timeout, logger)
|
||||
assert.NotNil(t, tracker)
|
||||
assert.Equal(t, tt.wantTimeout, tracker.timeout)
|
||||
assert.NotNil(t, tracker.connections)
|
||||
@ -40,7 +40,7 @@ func TestNewUDPTracker(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUDPTracker_TrackOutbound(t *testing.T) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout)
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.2")
|
||||
@ -58,12 +58,11 @@ func TestUDPTracker_TrackOutbound(t *testing.T) {
|
||||
assert.True(t, conn.DestIP.Equal(dstIP))
|
||||
assert.Equal(t, srcPort, conn.SourcePort)
|
||||
assert.Equal(t, dstPort, conn.DestPort)
|
||||
assert.True(t, conn.IsEstablished())
|
||||
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
|
||||
}
|
||||
|
||||
func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||
tracker := NewUDPTracker(1 * time.Second)
|
||||
tracker := NewUDPTracker(1*time.Second, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.2")
|
||||
@ -162,6 +161,7 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
||||
cleanupTicker: time.NewTicker(cleanupInterval),
|
||||
done: make(chan struct{}),
|
||||
ipPool: NewPreallocatedIPs(),
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
// Start cleanup routine
|
||||
@ -211,7 +211,7 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
||||
|
||||
func BenchmarkUDPTracker(b *testing.B) {
|
||||
b.Run("TrackOutbound", func(b *testing.B) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout)
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
@ -224,7 +224,7 @@ func BenchmarkUDPTracker(b *testing.B) {
|
||||
})
|
||||
|
||||
b.Run("IsValidInbound", func(b *testing.B) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout)
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
|
81
client/firewall/uspfilter/forwarder/endpoint.go
Normal file
81
client/firewall/uspfilter/forwarder/endpoint.go
Normal file
@ -0,0 +1,81 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
)
|
||||
|
||||
// endpoint implements stack.LinkEndpoint and handles integration with the wireguard device
|
||||
type endpoint struct {
|
||||
logger *nblog.Logger
|
||||
dispatcher stack.NetworkDispatcher
|
||||
device *wgdevice.Device
|
||||
mtu uint32
|
||||
}
|
||||
|
||||
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
|
||||
e.dispatcher = dispatcher
|
||||
}
|
||||
|
||||
func (e *endpoint) IsAttached() bool {
|
||||
return e.dispatcher != nil
|
||||
}
|
||||
|
||||
func (e *endpoint) MTU() uint32 {
|
||||
return e.mtu
|
||||
}
|
||||
|
||||
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
|
||||
return stack.CapabilityNone
|
||||
}
|
||||
|
||||
func (e *endpoint) MaxHeaderLength() uint16 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (e *endpoint) LinkAddress() tcpip.LinkAddress {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
|
||||
var written int
|
||||
for _, pkt := range pkts.AsSlice() {
|
||||
netHeader := header.IPv4(pkt.NetworkHeader().View().AsSlice())
|
||||
|
||||
data := stack.PayloadSince(pkt.NetworkHeader())
|
||||
if data == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Send the packet through WireGuard
|
||||
address := netHeader.DestinationAddress()
|
||||
err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice())
|
||||
if err != nil {
|
||||
e.logger.Error("CreateOutboundPacket: %v", err)
|
||||
continue
|
||||
}
|
||||
written++
|
||||
}
|
||||
|
||||
return written, nil
|
||||
}
|
||||
|
||||
func (e *endpoint) Wait() {
|
||||
// not required
|
||||
}
|
||||
|
||||
func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
|
||||
return header.ARPHardwareNone
|
||||
}
|
||||
|
||||
func (e *endpoint) AddHeader(*stack.PacketBuffer) {
|
||||
// not required
|
||||
}
|
||||
|
||||
func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
|
||||
return true
|
||||
}
|
166
client/firewall/uspfilter/forwarder/forwarder.go
Normal file
166
client/firewall/uspfilter/forwarder/forwarder.go
Normal file
@ -0,0 +1,166 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultReceiveWindow = 32768
|
||||
defaultMaxInFlight = 1024
|
||||
iosReceiveWindow = 16384
|
||||
iosMaxInFlight = 256
|
||||
)
|
||||
|
||||
type Forwarder struct {
|
||||
logger *nblog.Logger
|
||||
stack *stack.Stack
|
||||
endpoint *endpoint
|
||||
udpForwarder *udpForwarder
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
ip net.IP
|
||||
netstack bool
|
||||
}
|
||||
|
||||
func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwarder, error) {
|
||||
s := stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{
|
||||
tcp.NewProtocol,
|
||||
udp.NewProtocol,
|
||||
icmp.NewProtocol4,
|
||||
},
|
||||
HandleLocal: false,
|
||||
})
|
||||
|
||||
mtu, err := iface.GetDevice().MTU()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get MTU: %w", err)
|
||||
}
|
||||
nicID := tcpip.NICID(1)
|
||||
endpoint := &endpoint{
|
||||
logger: logger,
|
||||
device: iface.GetWGDevice(),
|
||||
mtu: uint32(mtu),
|
||||
}
|
||||
|
||||
if err := s.CreateNIC(nicID, endpoint); err != nil {
|
||||
return nil, fmt.Errorf("failed to create NIC: %v", err)
|
||||
}
|
||||
|
||||
ones, _ := iface.Address().Network.Mask.Size()
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv4.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||
Address: tcpip.AddrFromSlice(iface.Address().IP.To4()),
|
||||
PrefixLen: ones,
|
||||
},
|
||||
}
|
||||
|
||||
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
|
||||
return nil, fmt.Errorf("failed to add protocol address: %s", err)
|
||||
}
|
||||
|
||||
defaultSubnet, err := tcpip.NewSubnet(
|
||||
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
|
||||
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating default subnet: %w", err)
|
||||
}
|
||||
|
||||
if err := s.SetPromiscuousMode(nicID, true); err != nil {
|
||||
return nil, fmt.Errorf("set promiscuous mode: %s", err)
|
||||
}
|
||||
if err := s.SetSpoofing(nicID, true); err != nil {
|
||||
return nil, fmt.Errorf("set spoofing: %s", err)
|
||||
}
|
||||
|
||||
s.SetRouteTable([]tcpip.Route{
|
||||
{
|
||||
Destination: defaultSubnet,
|
||||
NIC: nicID,
|
||||
},
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
f := &Forwarder{
|
||||
logger: logger,
|
||||
stack: s,
|
||||
endpoint: endpoint,
|
||||
udpForwarder: newUDPForwarder(mtu, logger),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
netstack: netstack,
|
||||
ip: iface.Address().IP,
|
||||
}
|
||||
|
||||
receiveWindow := defaultReceiveWindow
|
||||
maxInFlight := defaultMaxInFlight
|
||||
if runtime.GOOS == "ios" {
|
||||
receiveWindow = iosReceiveWindow
|
||||
maxInFlight = iosMaxInFlight
|
||||
}
|
||||
|
||||
tcpForwarder := tcp.NewForwarder(s, receiveWindow, maxInFlight, f.handleTCP)
|
||||
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
|
||||
|
||||
udpForwarder := udp.NewForwarder(s, f.handleUDP)
|
||||
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
|
||||
|
||||
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP)
|
||||
|
||||
log.Debugf("forwarder: Initialization complete with NIC %d", nicID)
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
|
||||
if len(payload) < header.IPv4MinimumSize {
|
||||
return fmt.Errorf("packet too small: %d bytes", len(payload))
|
||||
}
|
||||
|
||||
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(payload),
|
||||
})
|
||||
defer pkt.DecRef()
|
||||
|
||||
if f.endpoint.dispatcher != nil {
|
||||
f.endpoint.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the forwarder
|
||||
func (f *Forwarder) Stop() {
|
||||
f.cancel()
|
||||
|
||||
if f.udpForwarder != nil {
|
||||
f.udpForwarder.Stop()
|
||||
}
|
||||
|
||||
f.stack.Close()
|
||||
f.stack.Wait()
|
||||
}
|
||||
|
||||
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
|
||||
if f.netstack && f.ip.Equal(addr.AsSlice()) {
|
||||
return net.IPv4(127, 0, 0, 1)
|
||||
}
|
||||
return addr.AsSlice()
|
||||
}
|
109
client/firewall/uspfilter/forwarder/icmp.go
Normal file
109
client/firewall/uspfilter/forwarder/icmp.go
Normal file
@ -0,0 +1,109 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
)
|
||||
|
||||
// handleICMP handles ICMP packets from the network stack
|
||||
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
|
||||
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
lc := net.ListenConfig{}
|
||||
// TODO: support non-root
|
||||
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||
if err != nil {
|
||||
f.logger.Error("Failed to create ICMP socket for %v: %v", id, err)
|
||||
|
||||
// This will make netstack reply on behalf of the original destination, that's ok for now
|
||||
return false
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
f.logger.Debug("Failed to close ICMP socket: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||
dst := &net.IPAddr{IP: dstIP}
|
||||
|
||||
// Get the complete ICMP message (header + data)
|
||||
fullPacket := stack.PayloadSince(pkt.TransportHeader())
|
||||
payload := fullPacket.AsSlice()
|
||||
|
||||
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
|
||||
|
||||
// For Echo Requests, send and handle response
|
||||
switch icmpHdr.Type() {
|
||||
case header.ICMPv4Echo:
|
||||
return f.handleEchoResponse(icmpHdr, payload, dst, conn, id)
|
||||
case header.ICMPv4EchoReply:
|
||||
// dont process our own replies
|
||||
return true
|
||||
default:
|
||||
}
|
||||
|
||||
// For other ICMP types (Time Exceeded, Destination Unreachable, etc)
|
||||
_, err = conn.WriteTo(payload, dst)
|
||||
if err != nil {
|
||||
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
|
||||
return true
|
||||
}
|
||||
|
||||
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
|
||||
id, icmpHdr.Type(), icmpHdr.Code())
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, dst *net.IPAddr, conn net.PacketConn, id stack.TransportEndpointID) bool {
|
||||
if _, err := conn.WriteTo(payload, dst); err != nil {
|
||||
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
|
||||
return true
|
||||
}
|
||||
|
||||
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
|
||||
id, icmpHdr.Type(), icmpHdr.Code())
|
||||
|
||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
|
||||
return true
|
||||
}
|
||||
|
||||
response := make([]byte, f.endpoint.mtu)
|
||||
n, _, err := conn.ReadFrom(response)
|
||||
if err != nil {
|
||||
if !isTimeout(err) {
|
||||
f.logger.Error("Failed to read ICMP response: %v", err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
ipHdr := make([]byte, header.IPv4MinimumSize)
|
||||
ip := header.IPv4(ipHdr)
|
||||
ip.Encode(&header.IPv4Fields{
|
||||
TotalLength: uint16(header.IPv4MinimumSize + n),
|
||||
TTL: 64,
|
||||
Protocol: uint8(header.ICMPv4ProtocolNumber),
|
||||
SrcAddr: id.LocalAddress,
|
||||
DstAddr: id.RemoteAddress,
|
||||
})
|
||||
ip.SetChecksum(^ip.CalculateChecksum())
|
||||
|
||||
fullPacket := make([]byte, 0, len(ipHdr)+n)
|
||||
fullPacket = append(fullPacket, ipHdr...)
|
||||
fullPacket = append(fullPacket, response[:n]...)
|
||||
|
||||
if err := f.InjectIncomingPacket(fullPacket); err != nil {
|
||||
f.logger.Error("Failed to inject ICMP response: %v", err)
|
||||
return true
|
||||
}
|
||||
|
||||
f.logger.Trace("Forwarded ICMP echo reply for %v", id)
|
||||
return true
|
||||
}
|
90
client/firewall/uspfilter/forwarder/tcp.go
Normal file
90
client/firewall/uspfilter/forwarder/tcp.go
Normal file
@ -0,0 +1,90 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
)
|
||||
|
||||
// handleTCP is called by the TCP forwarder for new connections.
|
||||
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
id := r.ID()
|
||||
|
||||
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||
|
||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
||||
if err != nil {
|
||||
r.Complete(true)
|
||||
f.logger.Trace("forwarder: dial error for %v: %v", id, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Create wait queue for blocking syscalls
|
||||
wq := waiter.Queue{}
|
||||
|
||||
ep, epErr := r.CreateEndpoint(&wq)
|
||||
if epErr != nil {
|
||||
f.logger.Error("forwarder: failed to create TCP endpoint: %v", epErr)
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: outConn close error: %v", err)
|
||||
}
|
||||
r.Complete(true)
|
||||
return
|
||||
}
|
||||
|
||||
// Complete the handshake
|
||||
r.Complete(false)
|
||||
|
||||
inConn := gonet.NewTCPConn(&wq, ep)
|
||||
|
||||
f.logger.Trace("forwarder: established TCP connection %v", id)
|
||||
|
||||
go f.proxyTCP(id, inConn, outConn, ep)
|
||||
}
|
||||
|
||||
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint) {
|
||||
defer func() {
|
||||
if err := inConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: inConn close error: %v", err)
|
||||
}
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: outConn close error: %v", err)
|
||||
}
|
||||
ep.Close()
|
||||
}()
|
||||
|
||||
// Create context for managing the proxy goroutines
|
||||
ctx, cancel := context.WithCancel(f.ctx)
|
||||
defer cancel()
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(outConn, inConn)
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(inConn, outConn)
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", id)
|
||||
return
|
||||
case err := <-errChan:
|
||||
if err != nil && !isClosedError(err) {
|
||||
f.logger.Error("proxyTCP: copy error: %v", err)
|
||||
}
|
||||
f.logger.Trace("forwarder: tearing down TCP connection %v", id)
|
||||
return
|
||||
}
|
||||
}
|
284
client/firewall/uspfilter/forwarder/udp.go
Normal file
284
client/firewall/uspfilter/forwarder/udp.go
Normal file
@ -0,0 +1,284 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
)
|
||||
|
||||
const (
|
||||
udpTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
type udpPacketConn struct {
|
||||
conn *gonet.UDPConn
|
||||
outConn net.Conn
|
||||
lastSeen atomic.Int64
|
||||
cancel context.CancelFunc
|
||||
ep tcpip.Endpoint
|
||||
}
|
||||
|
||||
type udpForwarder struct {
|
||||
sync.RWMutex
|
||||
logger *nblog.Logger
|
||||
conns map[stack.TransportEndpointID]*udpPacketConn
|
||||
bufPool sync.Pool
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
type idleConn struct {
|
||||
id stack.TransportEndpointID
|
||||
conn *udpPacketConn
|
||||
}
|
||||
|
||||
func newUDPForwarder(mtu int, logger *nblog.Logger) *udpForwarder {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
f := &udpForwarder{
|
||||
logger: logger,
|
||||
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
bufPool: sync.Pool{
|
||||
New: func() any {
|
||||
b := make([]byte, mtu)
|
||||
return &b
|
||||
},
|
||||
},
|
||||
}
|
||||
go f.cleanup()
|
||||
return f
|
||||
}
|
||||
|
||||
// Stop stops the UDP forwarder and all active connections
|
||||
func (f *udpForwarder) Stop() {
|
||||
f.cancel()
|
||||
|
||||
f.Lock()
|
||||
defer f.Unlock()
|
||||
|
||||
for id, conn := range f.conns {
|
||||
conn.cancel()
|
||||
if err := conn.conn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP conn close error for %v: %v", id, err)
|
||||
}
|
||||
if err := conn.outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||
}
|
||||
|
||||
conn.ep.Close()
|
||||
delete(f.conns, id)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup periodically removes idle UDP connections
|
||||
func (f *udpForwarder) cleanup() {
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-f.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
var idleConns []idleConn
|
||||
|
||||
f.RLock()
|
||||
for id, conn := range f.conns {
|
||||
if conn.getIdleDuration() > udpTimeout {
|
||||
idleConns = append(idleConns, idleConn{id, conn})
|
||||
}
|
||||
}
|
||||
f.RUnlock()
|
||||
|
||||
for _, idle := range idleConns {
|
||||
idle.conn.cancel()
|
||||
if err := idle.conn.conn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP conn close error for %v: %v", idle.id, err)
|
||||
}
|
||||
if err := idle.conn.outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", idle.id, err)
|
||||
}
|
||||
|
||||
idle.conn.ep.Close()
|
||||
|
||||
f.Lock()
|
||||
delete(f.conns, idle.id)
|
||||
f.Unlock()
|
||||
|
||||
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", idle.id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleUDP is called by the UDP forwarder for new packets
|
||||
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
if f.ctx.Err() != nil {
|
||||
f.logger.Trace("forwarder: context done, dropping UDP packet")
|
||||
return
|
||||
}
|
||||
|
||||
id := r.ID()
|
||||
|
||||
f.udpForwarder.RLock()
|
||||
_, exists := f.udpForwarder.conns[id]
|
||||
f.udpForwarder.RUnlock()
|
||||
if exists {
|
||||
f.logger.Trace("forwarder: existing UDP connection for %v", id)
|
||||
return
|
||||
}
|
||||
|
||||
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
||||
if err != nil {
|
||||
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)
|
||||
// TODO: Send ICMP error message
|
||||
return
|
||||
}
|
||||
|
||||
// Create wait queue for blocking syscalls
|
||||
wq := waiter.Queue{}
|
||||
ep, epErr := r.CreateEndpoint(&wq)
|
||||
if epErr != nil {
|
||||
f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr)
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
inConn := gonet.NewUDPConn(f.stack, &wq, ep)
|
||||
connCtx, connCancel := context.WithCancel(f.ctx)
|
||||
|
||||
pConn := &udpPacketConn{
|
||||
conn: inConn,
|
||||
outConn: outConn,
|
||||
cancel: connCancel,
|
||||
ep: ep,
|
||||
}
|
||||
pConn.updateLastSeen()
|
||||
|
||||
f.udpForwarder.Lock()
|
||||
// Double-check no connection was created while we were setting up
|
||||
if _, exists := f.udpForwarder.conns[id]; exists {
|
||||
f.udpForwarder.Unlock()
|
||||
pConn.cancel()
|
||||
if err := inConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
|
||||
}
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
f.udpForwarder.conns[id] = pConn
|
||||
f.udpForwarder.Unlock()
|
||||
|
||||
f.logger.Trace("forwarder: established UDP connection to %v", id)
|
||||
go f.proxyUDP(connCtx, pConn, id, ep)
|
||||
}
|
||||
|
||||
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||
defer func() {
|
||||
pConn.cancel()
|
||||
if err := pConn.conn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
|
||||
}
|
||||
if err := pConn.outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||
}
|
||||
|
||||
ep.Close()
|
||||
|
||||
f.udpForwarder.Lock()
|
||||
delete(f.udpForwarder.conns, id)
|
||||
f.udpForwarder.Unlock()
|
||||
}()
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
errChan <- pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
|
||||
}()
|
||||
|
||||
go func() {
|
||||
errChan <- pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", id)
|
||||
return
|
||||
case err := <-errChan:
|
||||
if err != nil && !isClosedError(err) {
|
||||
f.logger.Error("proxyUDP: copy error: %v", err)
|
||||
}
|
||||
f.logger.Trace("forwarder: tearing down UDP connection %v", id)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (c *udpPacketConn) updateLastSeen() {
|
||||
c.lastSeen.Store(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
func (c *udpPacketConn) getIdleDuration() time.Duration {
|
||||
lastSeen := time.Unix(0, c.lastSeen.Load())
|
||||
return time.Since(lastSeen)
|
||||
}
|
||||
|
||||
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) error {
|
||||
bufp := bufPool.Get().(*[]byte)
|
||||
defer bufPool.Put(bufp)
|
||||
buffer := *bufp
|
||||
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
if err := src.SetDeadline(time.Now().Add(udpTimeout)); err != nil {
|
||||
return fmt.Errorf("set read deadline: %w", err)
|
||||
}
|
||||
|
||||
n, err := src.Read(buffer)
|
||||
if err != nil {
|
||||
if isTimeout(err) {
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("read from %s: %w", direction, err)
|
||||
}
|
||||
|
||||
_, err = dst.Write(buffer[:n])
|
||||
if err != nil {
|
||||
return fmt.Errorf("write to %s: %w", direction, err)
|
||||
}
|
||||
|
||||
c.updateLastSeen()
|
||||
}
|
||||
}
|
||||
|
||||
func isClosedError(err error) bool {
|
||||
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled)
|
||||
}
|
||||
|
||||
func isTimeout(err error) bool {
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) {
|
||||
return netErr.Timeout()
|
||||
}
|
||||
return false
|
||||
}
|
134
client/firewall/uspfilter/localip.go
Normal file
134
client/firewall/uspfilter/localip.go
Normal file
@ -0,0 +1,134 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||
)
|
||||
|
||||
type localIPManager struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// Use bitmap for IPv4 (32 bits * 2^16 = 256KB memory)
|
||||
ipv4Bitmap [1 << 16]uint32
|
||||
}
|
||||
|
||||
func newLocalIPManager() *localIPManager {
|
||||
return &localIPManager{}
|
||||
}
|
||||
|
||||
func (m *localIPManager) setBitmapBit(ip net.IP) {
|
||||
ipv4 := ip.To4()
|
||||
if ipv4 == nil {
|
||||
return
|
||||
}
|
||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||
m.ipv4Bitmap[high] |= 1 << (low % 32)
|
||||
}
|
||||
|
||||
func (m *localIPManager) checkBitmapBit(ip net.IP) bool {
|
||||
ipv4 := ip.To4()
|
||||
if ipv4 == nil {
|
||||
return false
|
||||
}
|
||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
|
||||
}
|
||||
|
||||
func (m *localIPManager) processIP(ip net.IP, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||
if int(high) >= len(*newIPv4Bitmap) {
|
||||
return fmt.Errorf("invalid IPv4 address: %s", ip)
|
||||
}
|
||||
ipStr := ip.String()
|
||||
if _, exists := ipv4Set[ipStr]; !exists {
|
||||
ipv4Set[ipStr] = struct{}{}
|
||||
*ipv4Addresses = append(*ipv4Addresses, ipStr)
|
||||
newIPv4Bitmap[high] |= 1 << (low % 32)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, addr := range addrs {
|
||||
var ip net.IP
|
||||
switch v := addr.(type) {
|
||||
case *net.IPNet:
|
||||
ip = v.IP
|
||||
case *net.IPAddr:
|
||||
ip = v.IP
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
if err := m.processIP(ip, newIPv4Bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||
log.Debugf("process IP failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
var newIPv4Bitmap [1 << 16]uint32
|
||||
ipv4Set := make(map[string]struct{})
|
||||
var ipv4Addresses []string
|
||||
|
||||
// 127.0.0.0/8
|
||||
high := uint16(127) << 8
|
||||
for i := uint16(0); i < 256; i++ {
|
||||
newIPv4Bitmap[high|i] = 0xffffffff
|
||||
}
|
||||
|
||||
if iface != nil {
|
||||
if err := m.processIP(iface.Address().IP, &newIPv4Bitmap, ipv4Set, &ipv4Addresses); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get interfaces: %v", err)
|
||||
} else {
|
||||
for _, intf := range interfaces {
|
||||
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
|
||||
}
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
m.ipv4Bitmap = newIPv4Bitmap
|
||||
m.mu.Unlock()
|
||||
|
||||
log.Debugf("Local IPv4 addresses: %v", ipv4Addresses)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *localIPManager) IsLocalIP(ip net.IP) bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
return m.checkBitmapBit(ipv4)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
270
client/firewall/uspfilter/localip_test.go
Normal file
270
client/firewall/uspfilter/localip_test.go
Normal file
@ -0,0 +1,270 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
)
|
||||
|
||||
func TestLocalIPManager(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupAddr iface.WGAddress
|
||||
testIP net.IP
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Localhost range",
|
||||
setupAddr: iface.WGAddress{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("127.0.0.2"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Localhost standard address",
|
||||
setupAddr: iface.WGAddress{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("127.0.0.1"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Localhost range edge",
|
||||
setupAddr: iface.WGAddress{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("127.255.255.255"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Local IP matches",
|
||||
setupAddr: iface.WGAddress{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("192.168.1.1"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Local IP doesn't match",
|
||||
setupAddr: iface.WGAddress{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("192.168.1.2"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 address",
|
||||
setupAddr: iface.WGAddress{
|
||||
IP: net.ParseIP("fe80::1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("fe80::"),
|
||||
Mask: net.CIDRMask(64, 128),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("fe80::1"),
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager := newLocalIPManager()
|
||||
|
||||
mock := &IFaceMock{
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return tt.setupAddr
|
||||
},
|
||||
}
|
||||
|
||||
err := manager.UpdateLocalIPs(mock)
|
||||
require.NoError(t, err)
|
||||
|
||||
result := manager.IsLocalIP(tt.testIP)
|
||||
require.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalIPManager_AllInterfaces(t *testing.T) {
|
||||
manager := newLocalIPManager()
|
||||
mock := &IFaceMock{}
|
||||
|
||||
// Get actual local interfaces
|
||||
interfaces, err := net.Interfaces()
|
||||
require.NoError(t, err)
|
||||
|
||||
var tests []struct {
|
||||
ip string
|
||||
expected bool
|
||||
}
|
||||
|
||||
// Add all local interface IPs to test cases
|
||||
for _, iface := range interfaces {
|
||||
addrs, err := iface.Addrs()
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, addr := range addrs {
|
||||
var ip net.IP
|
||||
switch v := addr.(type) {
|
||||
case *net.IPNet:
|
||||
ip = v.IP
|
||||
case *net.IPAddr:
|
||||
ip = v.IP
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
tests = append(tests, struct {
|
||||
ip string
|
||||
expected bool
|
||||
}{
|
||||
ip: ip4.String(),
|
||||
expected: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add some external IPs as negative test cases
|
||||
externalIPs := []string{
|
||||
"8.8.8.8",
|
||||
"1.1.1.1",
|
||||
"208.67.222.222",
|
||||
}
|
||||
for _, ip := range externalIPs {
|
||||
tests = append(tests, struct {
|
||||
ip string
|
||||
expected bool
|
||||
}{
|
||||
ip: ip,
|
||||
expected: false,
|
||||
})
|
||||
}
|
||||
|
||||
require.NotEmpty(t, tests, "No test cases generated")
|
||||
|
||||
err = manager.UpdateLocalIPs(mock)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Logf("Testing %d IPs", len(tests))
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.ip, func(t *testing.T) {
|
||||
result := manager.IsLocalIP(net.ParseIP(tt.ip))
|
||||
require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// MapImplementation is a version using map[string]struct{}
|
||||
type MapImplementation struct {
|
||||
localIPs map[string]struct{}
|
||||
}
|
||||
|
||||
func BenchmarkIPChecks(b *testing.B) {
|
||||
interfaces := make([]net.IP, 16)
|
||||
for i := range interfaces {
|
||||
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
|
||||
}
|
||||
|
||||
// Setup bitmap version
|
||||
bitmapManager := &localIPManager{
|
||||
ipv4Bitmap: [1 << 16]uint32{},
|
||||
}
|
||||
for _, ip := range interfaces[:8] { // Add half of IPs
|
||||
bitmapManager.setBitmapBit(ip)
|
||||
}
|
||||
|
||||
// Setup map version
|
||||
mapManager := &MapImplementation{
|
||||
localIPs: make(map[string]struct{}),
|
||||
}
|
||||
for _, ip := range interfaces[:8] {
|
||||
mapManager.localIPs[ip.String()] = struct{}{}
|
||||
}
|
||||
|
||||
b.Run("Bitmap_Hit", func(b *testing.B) {
|
||||
ip := interfaces[4]
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
bitmapManager.checkBitmapBit(ip)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Bitmap_Miss", func(b *testing.B) {
|
||||
ip := interfaces[12]
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
bitmapManager.checkBitmapBit(ip)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Map_Hit", func(b *testing.B) {
|
||||
ip := interfaces[4]
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// nolint:gosimple
|
||||
_, _ = mapManager.localIPs[ip.String()]
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Map_Miss", func(b *testing.B) {
|
||||
ip := interfaces[12]
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// nolint:gosimple
|
||||
_, _ = mapManager.localIPs[ip.String()]
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkWGPosition(b *testing.B) {
|
||||
wgIP := net.ParseIP("10.10.0.1")
|
||||
|
||||
// Create two managers - one checks WG IP first, other checks it last
|
||||
b.Run("WG_First", func(b *testing.B) {
|
||||
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
|
||||
bm.setBitmapBit(wgIP)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
bm.checkBitmapBit(wgIP)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("WG_Last", func(b *testing.B) {
|
||||
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
|
||||
// Fill with other IPs first
|
||||
for i := 0; i < 15; i++ {
|
||||
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))
|
||||
}
|
||||
bm.setBitmapBit(wgIP) // Add WG IP last
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
bm.checkBitmapBit(wgIP)
|
||||
}
|
||||
})
|
||||
}
|
196
client/firewall/uspfilter/log/log.go
Normal file
196
client/firewall/uspfilter/log/log.go
Normal file
@ -0,0 +1,196 @@
|
||||
// Package logger provides a high-performance, non-blocking logger for userspace networking
|
||||
package log
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
maxBatchSize = 1024 * 16 // 16KB max batch size
|
||||
maxMessageSize = 1024 * 2 // 2KB per message
|
||||
bufferSize = 1024 * 256 // 256KB ring buffer
|
||||
defaultFlushInterval = 2 * time.Second
|
||||
)
|
||||
|
||||
// Level represents log severity
|
||||
type Level uint32
|
||||
|
||||
const (
|
||||
LevelPanic Level = iota
|
||||
LevelFatal
|
||||
LevelError
|
||||
LevelWarn
|
||||
LevelInfo
|
||||
LevelDebug
|
||||
LevelTrace
|
||||
)
|
||||
|
||||
var levelStrings = map[Level]string{
|
||||
LevelPanic: "PANC",
|
||||
LevelFatal: "FATL",
|
||||
LevelError: "ERRO",
|
||||
LevelWarn: "WARN",
|
||||
LevelInfo: "INFO",
|
||||
LevelDebug: "DEBG",
|
||||
LevelTrace: "TRAC",
|
||||
}
|
||||
|
||||
// Logger is a high-performance, non-blocking logger
|
||||
type Logger struct {
|
||||
output io.Writer
|
||||
level atomic.Uint32
|
||||
buffer *ringBuffer
|
||||
shutdown chan struct{}
|
||||
closeOnce sync.Once
|
||||
wg sync.WaitGroup
|
||||
|
||||
// Reusable buffer pool for formatting messages
|
||||
bufPool sync.Pool
|
||||
}
|
||||
|
||||
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
||||
l := &Logger{
|
||||
output: logrusLogger.Out,
|
||||
buffer: newRingBuffer(bufferSize),
|
||||
shutdown: make(chan struct{}),
|
||||
bufPool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
// Pre-allocate buffer for message formatting
|
||||
b := make([]byte, 0, maxMessageSize)
|
||||
return &b
|
||||
},
|
||||
},
|
||||
}
|
||||
logrusLevel := logrusLogger.GetLevel()
|
||||
l.level.Store(uint32(logrusLevel))
|
||||
level := levelStrings[Level(logrusLevel)]
|
||||
log.Debugf("New uspfilter logger created with loglevel %v", level)
|
||||
|
||||
l.wg.Add(1)
|
||||
go l.worker()
|
||||
|
||||
return l
|
||||
}
|
||||
|
||||
func (l *Logger) SetLevel(level Level) {
|
||||
l.level.Store(uint32(level))
|
||||
|
||||
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
|
||||
}
|
||||
|
||||
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...interface{}) {
|
||||
*buf = (*buf)[:0]
|
||||
|
||||
// Timestamp
|
||||
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
|
||||
*buf = append(*buf, ' ')
|
||||
|
||||
// Level
|
||||
*buf = append(*buf, levelStrings[level]...)
|
||||
*buf = append(*buf, ' ')
|
||||
|
||||
// Message
|
||||
if len(args) > 0 {
|
||||
*buf = append(*buf, fmt.Sprintf(format, args...)...)
|
||||
} else {
|
||||
*buf = append(*buf, format...)
|
||||
}
|
||||
|
||||
*buf = append(*buf, '\n')
|
||||
}
|
||||
|
||||
func (l *Logger) log(level Level, format string, args ...interface{}) {
|
||||
bufp := l.bufPool.Get().(*[]byte)
|
||||
l.formatMessage(bufp, level, format, args...)
|
||||
|
||||
if len(*bufp) > maxMessageSize {
|
||||
*bufp = (*bufp)[:maxMessageSize]
|
||||
}
|
||||
_, _ = l.buffer.Write(*bufp)
|
||||
|
||||
l.bufPool.Put(bufp)
|
||||
}
|
||||
|
||||
func (l *Logger) Error(format string, args ...interface{}) {
|
||||
if l.level.Load() >= uint32(LevelError) {
|
||||
l.log(LevelError, format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Warn(format string, args ...interface{}) {
|
||||
if l.level.Load() >= uint32(LevelWarn) {
|
||||
l.log(LevelWarn, format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Info(format string, args ...interface{}) {
|
||||
if l.level.Load() >= uint32(LevelInfo) {
|
||||
l.log(LevelInfo, format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Debug(format string, args ...interface{}) {
|
||||
if l.level.Load() >= uint32(LevelDebug) {
|
||||
l.log(LevelDebug, format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Trace(format string, args ...interface{}) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
l.log(LevelTrace, format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// worker periodically flushes the buffer
|
||||
func (l *Logger) worker() {
|
||||
defer l.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(defaultFlushInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
buf := make([]byte, 0, maxBatchSize)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-l.shutdown:
|
||||
return
|
||||
case <-ticker.C:
|
||||
// Read accumulated messages
|
||||
n, _ := l.buffer.Read(buf[:cap(buf)])
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Write batch
|
||||
_, _ = l.output.Write(buf[:n])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the logger
|
||||
func (l *Logger) Stop(ctx context.Context) error {
|
||||
done := make(chan struct{})
|
||||
|
||||
l.closeOnce.Do(func() {
|
||||
close(l.shutdown)
|
||||
})
|
||||
|
||||
go func() {
|
||||
l.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-done:
|
||||
return nil
|
||||
}
|
||||
}
|
85
client/firewall/uspfilter/log/ringbuffer.go
Normal file
85
client/firewall/uspfilter/log/ringbuffer.go
Normal file
@ -0,0 +1,85 @@
|
||||
package log
|
||||
|
||||
import "sync"
|
||||
|
||||
// ringBuffer is a simple ring buffer implementation
|
||||
type ringBuffer struct {
|
||||
buf []byte
|
||||
size int
|
||||
r, w int64 // Read and write positions
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newRingBuffer(size int) *ringBuffer {
|
||||
return &ringBuffer{
|
||||
buf: make([]byte, size),
|
||||
size: size,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ringBuffer) Write(p []byte) (n int, err error) {
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if len(p) > r.size {
|
||||
p = p[:r.size]
|
||||
}
|
||||
|
||||
n = len(p)
|
||||
|
||||
// Write data, handling wrap-around
|
||||
pos := int(r.w % int64(r.size))
|
||||
writeLen := min(len(p), r.size-pos)
|
||||
copy(r.buf[pos:], p[:writeLen])
|
||||
|
||||
// If we have more data and need to wrap around
|
||||
if writeLen < len(p) {
|
||||
copy(r.buf, p[writeLen:])
|
||||
}
|
||||
|
||||
// Update write position
|
||||
r.w += int64(n)
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (r *ringBuffer) Read(p []byte) (n int, err error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if r.w == r.r {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Calculate available data accounting for wraparound
|
||||
available := int(r.w - r.r)
|
||||
if available < 0 {
|
||||
available += r.size
|
||||
}
|
||||
available = min(available, r.size)
|
||||
|
||||
// Limit read to buffer size
|
||||
toRead := min(available, len(p))
|
||||
if toRead == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Read data, handling wrap-around
|
||||
pos := int(r.r % int64(r.size))
|
||||
readLen := min(toRead, r.size-pos)
|
||||
n = copy(p, r.buf[pos:pos+readLen])
|
||||
|
||||
// If we need more data and need to wrap around
|
||||
if readLen < toRead {
|
||||
n += copy(p[readLen:toRead], r.buf[:toRead-readLen])
|
||||
}
|
||||
|
||||
// Update read position
|
||||
r.r += int64(n)
|
||||
|
||||
return n, nil
|
||||
}
|
@ -2,22 +2,22 @@ package uspfilter
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
// Rule to handle management of rules
|
||||
type Rule struct {
|
||||
// PeerRule to handle management of rules
|
||||
type PeerRule struct {
|
||||
id string
|
||||
ip net.IP
|
||||
ipLayer gopacket.LayerType
|
||||
matchByIP bool
|
||||
protoLayer gopacket.LayerType
|
||||
direction firewall.RuleDirection
|
||||
sPort uint16
|
||||
dPort uint16
|
||||
sPort *firewall.Port
|
||||
dPort *firewall.Port
|
||||
drop bool
|
||||
comment string
|
||||
|
||||
@ -25,6 +25,21 @@ type Rule struct {
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
func (r *Rule) GetRuleID() string {
|
||||
func (r *PeerRule) GetRuleID() string {
|
||||
return r.id
|
||||
}
|
||||
|
||||
type RouteRule struct {
|
||||
id string
|
||||
sources []netip.Prefix
|
||||
destination netip.Prefix
|
||||
proto firewall.Protocol
|
||||
srcPort *firewall.Port
|
||||
dstPort *firewall.Port
|
||||
action firewall.Action
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
func (r *RouteRule) GetRuleID() string {
|
||||
return r.id
|
||||
}
|
||||
|
390
client/firewall/uspfilter/tracer.go
Normal file
390
client/firewall/uspfilter/tracer.go
Normal file
@ -0,0 +1,390 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
)
|
||||
|
||||
type PacketStage int
|
||||
|
||||
const (
|
||||
StageReceived PacketStage = iota
|
||||
StageConntrack
|
||||
StagePeerACL
|
||||
StageRouting
|
||||
StageRouteACL
|
||||
StageForwarding
|
||||
StageCompleted
|
||||
)
|
||||
|
||||
const msgProcessingCompleted = "Processing completed"
|
||||
|
||||
func (s PacketStage) String() string {
|
||||
return map[PacketStage]string{
|
||||
StageReceived: "Received",
|
||||
StageConntrack: "Connection Tracking",
|
||||
StagePeerACL: "Peer ACL",
|
||||
StageRouting: "Routing",
|
||||
StageRouteACL: "Route ACL",
|
||||
StageForwarding: "Forwarding",
|
||||
StageCompleted: "Completed",
|
||||
}[s]
|
||||
}
|
||||
|
||||
type ForwarderAction struct {
|
||||
Action string
|
||||
RemoteAddr string
|
||||
Error error
|
||||
}
|
||||
|
||||
type TraceResult struct {
|
||||
Timestamp time.Time
|
||||
Stage PacketStage
|
||||
Message string
|
||||
Allowed bool
|
||||
ForwarderAction *ForwarderAction
|
||||
}
|
||||
|
||||
type PacketTrace struct {
|
||||
SourceIP net.IP
|
||||
DestinationIP net.IP
|
||||
Protocol string
|
||||
SourcePort uint16
|
||||
DestinationPort uint16
|
||||
Direction fw.RuleDirection
|
||||
Results []TraceResult
|
||||
}
|
||||
|
||||
type TCPState struct {
|
||||
SYN bool
|
||||
ACK bool
|
||||
FIN bool
|
||||
RST bool
|
||||
PSH bool
|
||||
URG bool
|
||||
}
|
||||
|
||||
type PacketBuilder struct {
|
||||
SrcIP net.IP
|
||||
DstIP net.IP
|
||||
Protocol fw.Protocol
|
||||
SrcPort uint16
|
||||
DstPort uint16
|
||||
ICMPType uint8
|
||||
ICMPCode uint8
|
||||
Direction fw.RuleDirection
|
||||
PayloadSize int
|
||||
TCPState *TCPState
|
||||
}
|
||||
|
||||
func (t *PacketTrace) AddResult(stage PacketStage, message string, allowed bool) {
|
||||
t.Results = append(t.Results, TraceResult{
|
||||
Timestamp: time.Now(),
|
||||
Stage: stage,
|
||||
Message: message,
|
||||
Allowed: allowed,
|
||||
})
|
||||
}
|
||||
|
||||
func (t *PacketTrace) AddResultWithForwarder(stage PacketStage, message string, allowed bool, action *ForwarderAction) {
|
||||
t.Results = append(t.Results, TraceResult{
|
||||
Timestamp: time.Now(),
|
||||
Stage: stage,
|
||||
Message: message,
|
||||
Allowed: allowed,
|
||||
ForwarderAction: action,
|
||||
})
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) Build() ([]byte, error) {
|
||||
ip := p.buildIPLayer()
|
||||
pktLayers := []gopacket.SerializableLayer{ip}
|
||||
|
||||
transportLayer, err := p.buildTransportLayer(ip)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pktLayers = append(pktLayers, transportLayer...)
|
||||
|
||||
if p.PayloadSize > 0 {
|
||||
payload := make([]byte, p.PayloadSize)
|
||||
pktLayers = append(pktLayers, gopacket.Payload(payload))
|
||||
}
|
||||
|
||||
return serializePacket(pktLayers)
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
|
||||
return &layers.IPv4{
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
|
||||
SrcIP: p.SrcIP,
|
||||
DstIP: p.DstIP,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) buildTransportLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
||||
switch p.Protocol {
|
||||
case "tcp":
|
||||
return p.buildTCPLayer(ip)
|
||||
case "udp":
|
||||
return p.buildUDPLayer(ip)
|
||||
case "icmp":
|
||||
return p.buildICMPLayer()
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported protocol: %s", p.Protocol)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) buildTCPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
||||
tcp := &layers.TCP{
|
||||
SrcPort: layers.TCPPort(p.SrcPort),
|
||||
DstPort: layers.TCPPort(p.DstPort),
|
||||
Window: 65535,
|
||||
SYN: p.TCPState != nil && p.TCPState.SYN,
|
||||
ACK: p.TCPState != nil && p.TCPState.ACK,
|
||||
FIN: p.TCPState != nil && p.TCPState.FIN,
|
||||
RST: p.TCPState != nil && p.TCPState.RST,
|
||||
PSH: p.TCPState != nil && p.TCPState.PSH,
|
||||
URG: p.TCPState != nil && p.TCPState.URG,
|
||||
}
|
||||
if err := tcp.SetNetworkLayerForChecksum(ip); err != nil {
|
||||
return nil, fmt.Errorf("set network layer for TCP checksum: %w", err)
|
||||
}
|
||||
return []gopacket.SerializableLayer{tcp}, nil
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) buildUDPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
||||
udp := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(p.SrcPort),
|
||||
DstPort: layers.UDPPort(p.DstPort),
|
||||
}
|
||||
if err := udp.SetNetworkLayerForChecksum(ip); err != nil {
|
||||
return nil, fmt.Errorf("set network layer for UDP checksum: %w", err)
|
||||
}
|
||||
return []gopacket.SerializableLayer{udp}, nil
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) buildICMPLayer() ([]gopacket.SerializableLayer, error) {
|
||||
icmp := &layers.ICMPv4{
|
||||
TypeCode: layers.CreateICMPv4TypeCode(p.ICMPType, p.ICMPCode),
|
||||
}
|
||||
if p.ICMPType == layers.ICMPv4TypeEchoRequest || p.ICMPType == layers.ICMPv4TypeEchoReply {
|
||||
icmp.Id = uint16(1)
|
||||
icmp.Seq = uint16(1)
|
||||
}
|
||||
return []gopacket.SerializableLayer{icmp}, nil
|
||||
}
|
||||
|
||||
func serializePacket(layers []gopacket.SerializableLayer) ([]byte, error) {
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{
|
||||
ComputeChecksums: true,
|
||||
FixLengths: true,
|
||||
}
|
||||
if err := gopacket.SerializeLayers(buf, opts, layers...); err != nil {
|
||||
return nil, fmt.Errorf("serialize packet: %w", err)
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func getIPProtocolNumber(protocol fw.Protocol) int {
|
||||
switch protocol {
|
||||
case fw.ProtocolTCP:
|
||||
return int(layers.IPProtocolTCP)
|
||||
case fw.ProtocolUDP:
|
||||
return int(layers.IPProtocolUDP)
|
||||
case fw.ProtocolICMP:
|
||||
return int(layers.IPProtocolICMPv4)
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) TracePacketFromBuilder(builder *PacketBuilder) (*PacketTrace, error) {
|
||||
packetData, err := builder.Build()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build packet: %w", err)
|
||||
}
|
||||
|
||||
return m.TracePacket(packetData, builder.Direction), nil
|
||||
}
|
||||
|
||||
func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *PacketTrace {
|
||||
|
||||
d := m.decoders.Get().(*decoder)
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
trace := &PacketTrace{Direction: direction}
|
||||
|
||||
// Initial packet decoding
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
trace.AddResult(StageReceived, fmt.Sprintf("Failed to decode packet: %v", err), false)
|
||||
return trace
|
||||
}
|
||||
|
||||
// Extract base packet info
|
||||
srcIP, dstIP := m.extractIPs(d)
|
||||
trace.SourceIP = srcIP
|
||||
trace.DestinationIP = dstIP
|
||||
|
||||
// Determine protocol and ports
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
trace.Protocol = "TCP"
|
||||
trace.SourcePort = uint16(d.tcp.SrcPort)
|
||||
trace.DestinationPort = uint16(d.tcp.DstPort)
|
||||
case layers.LayerTypeUDP:
|
||||
trace.Protocol = "UDP"
|
||||
trace.SourcePort = uint16(d.udp.SrcPort)
|
||||
trace.DestinationPort = uint16(d.udp.DstPort)
|
||||
case layers.LayerTypeICMPv4:
|
||||
trace.Protocol = "ICMP"
|
||||
}
|
||||
|
||||
trace.AddResult(StageReceived, fmt.Sprintf("Received %s packet: %s:%d -> %s:%d",
|
||||
trace.Protocol, srcIP, trace.SourcePort, dstIP, trace.DestinationPort), true)
|
||||
|
||||
if direction == fw.RuleDirectionOUT {
|
||||
return m.traceOutbound(packetData, trace)
|
||||
}
|
||||
|
||||
return m.traceInbound(packetData, trace, d, srcIP, dstIP)
|
||||
}
|
||||
|
||||
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP net.IP, dstIP net.IP) *PacketTrace {
|
||||
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
|
||||
return trace
|
||||
}
|
||||
|
||||
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
|
||||
return trace
|
||||
}
|
||||
|
||||
if !m.handleRouting(trace) {
|
||||
return trace
|
||||
}
|
||||
|
||||
if m.nativeRouter {
|
||||
return m.handleNativeRouter(trace)
|
||||
}
|
||||
|
||||
return m.handleRouteACLs(trace, d, srcIP, dstIP)
|
||||
}
|
||||
|
||||
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) bool {
|
||||
allowed := m.isValidTrackedConnection(d, srcIP, dstIP)
|
||||
msg := "No existing connection found"
|
||||
if allowed {
|
||||
msg = m.buildConntrackStateMessage(d)
|
||||
trace.AddResult(StageConntrack, msg, true)
|
||||
trace.AddResult(StageCompleted, "Packet allowed by connection tracking", true)
|
||||
return true
|
||||
}
|
||||
trace.AddResult(StageConntrack, msg, false)
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) buildConntrackStateMessage(d *decoder) string {
|
||||
msg := "Matched existing connection state"
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
flags := getTCPFlags(&d.tcp)
|
||||
msg += fmt.Sprintf(" (TCP Flags: SYN=%v ACK=%v RST=%v FIN=%v)",
|
||||
flags&conntrack.TCPSyn != 0,
|
||||
flags&conntrack.TCPAck != 0,
|
||||
flags&conntrack.TCPRst != 0,
|
||||
flags&conntrack.TCPFin != 0)
|
||||
case layers.LayerTypeICMPv4:
|
||||
msg += fmt.Sprintf(" (ICMP ID=%d, Seq=%d)", d.icmp4.Id, d.icmp4.Seq)
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP net.IP) bool {
|
||||
if !m.localForwarding {
|
||||
trace.AddResult(StageRouting, "Local forwarding disabled", false)
|
||||
trace.AddResult(StageCompleted, "Packet dropped - local forwarding disabled", false)
|
||||
return true
|
||||
}
|
||||
|
||||
trace.AddResult(StageRouting, "Packet destined for local delivery", true)
|
||||
blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
|
||||
|
||||
msg := "Allowed by peer ACL rules"
|
||||
if blocked {
|
||||
msg = "Blocked by peer ACL rules"
|
||||
}
|
||||
trace.AddResult(StagePeerACL, msg, !blocked)
|
||||
|
||||
if m.netstack {
|
||||
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", !blocked)
|
||||
}
|
||||
|
||||
trace.AddResult(StageCompleted, msgProcessingCompleted, !blocked)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) handleRouting(trace *PacketTrace) bool {
|
||||
if !m.routingEnabled {
|
||||
trace.AddResult(StageRouting, "Routing disabled", false)
|
||||
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
|
||||
return false
|
||||
}
|
||||
trace.AddResult(StageRouting, "Routing enabled, checking ACLs", true)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
|
||||
trace.AddResult(StageRouteACL, "Using native router, skipping ACL checks", true)
|
||||
trace.AddResult(StageForwarding, "Forwarding via native router", true)
|
||||
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
|
||||
return trace
|
||||
}
|
||||
|
||||
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) *PacketTrace {
|
||||
proto := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
||||
|
||||
msg := "Allowed by route ACLs"
|
||||
if !allowed {
|
||||
msg = "Blocked by route ACLs"
|
||||
}
|
||||
trace.AddResult(StageRouteACL, msg, allowed)
|
||||
|
||||
if allowed && m.forwarder != nil {
|
||||
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
|
||||
}
|
||||
|
||||
trace.AddResult(StageCompleted, msgProcessingCompleted, allowed)
|
||||
return trace
|
||||
}
|
||||
|
||||
func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr string, allowed bool) {
|
||||
fwdAction := &ForwarderAction{
|
||||
Action: action,
|
||||
RemoteAddr: remoteAddr,
|
||||
}
|
||||
trace.AddResultWithForwarder(StageForwarding,
|
||||
fmt.Sprintf("Forwarding to %s", fwdAction.Action), allowed, fwdAction)
|
||||
}
|
||||
|
||||
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
|
||||
// will create or update the connection state
|
||||
dropped := m.processOutgoingHooks(packetData)
|
||||
if dropped {
|
||||
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
|
||||
} else {
|
||||
trace.AddResult(StageCompleted, "Packet allowed (outgoing)", true)
|
||||
}
|
||||
return trace
|
||||
}
|
@ -1,11 +1,14 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
@ -14,44 +17,83 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
const layerTypeAll = 0
|
||||
|
||||
const EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
|
||||
const (
|
||||
// EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed.
|
||||
EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
|
||||
|
||||
var (
|
||||
errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall")
|
||||
// EnvDisableUserspaceRouting disables userspace routing, to-be-routed packets will be dropped.
|
||||
EnvDisableUserspaceRouting = "NB_DISABLE_USERSPACE_ROUTING"
|
||||
|
||||
// EnvForceUserspaceRouter forces userspace routing even if native routing is available.
|
||||
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
|
||||
|
||||
// EnvEnableNetstackLocalForwarding enables forwarding of local traffic to the native stack when running netstack
|
||||
// Leaving this on by default introduces a security risk as sockets on listening on localhost only will be accessible
|
||||
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
|
||||
)
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
SetFilter(device.PacketFilter) error
|
||||
Address() iface.WGAddress
|
||||
}
|
||||
|
||||
// RuleSet is a set of rules grouped by a string key
|
||||
type RuleSet map[string]Rule
|
||||
type RuleSet map[string]PeerRule
|
||||
|
||||
type RouteRules []RouteRule
|
||||
|
||||
func (r RouteRules) Sort() {
|
||||
slices.SortStableFunc(r, func(a, b RouteRule) int {
|
||||
// Deny rules come first
|
||||
if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop {
|
||||
return -1
|
||||
}
|
||||
if a.action != firewall.ActionDrop && b.action == firewall.ActionDrop {
|
||||
return 1
|
||||
}
|
||||
return strings.Compare(a.id, b.id)
|
||||
})
|
||||
}
|
||||
|
||||
// Manager userspace firewall manager
|
||||
type Manager struct {
|
||||
outgoingRules map[string]RuleSet
|
||||
// outgoingRules is used for hooks only
|
||||
outgoingRules map[string]RuleSet
|
||||
// incomingRules is used for filtering and hooks
|
||||
incomingRules map[string]RuleSet
|
||||
routeRules RouteRules
|
||||
wgNetwork *net.IPNet
|
||||
decoders sync.Pool
|
||||
wgIface IFaceMapper
|
||||
wgIface common.IFaceMapper
|
||||
nativeFirewall firewall.Manager
|
||||
|
||||
mutex sync.RWMutex
|
||||
|
||||
stateful bool
|
||||
// indicates whether server routes are disabled
|
||||
disableServerRoutes bool
|
||||
// indicates whether we forward packets not destined for ourselves
|
||||
routingEnabled bool
|
||||
// indicates whether we leave forwarding and filtering to the native firewall
|
||||
nativeRouter bool
|
||||
// indicates whether we track outbound connections
|
||||
stateful bool
|
||||
// indicates whether wireguards runs in netstack mode
|
||||
netstack bool
|
||||
// indicates whether we forward local traffic to the native stack
|
||||
localForwarding bool
|
||||
|
||||
localipmanager *localIPManager
|
||||
|
||||
udpTracker *conntrack.UDPTracker
|
||||
icmpTracker *conntrack.ICMPTracker
|
||||
tcpTracker *conntrack.TCPTracker
|
||||
forwarder *forwarder.Forwarder
|
||||
logger *nblog.Logger
|
||||
}
|
||||
|
||||
// decoder for packages
|
||||
@ -68,22 +110,44 @@ type decoder struct {
|
||||
}
|
||||
|
||||
// Create userspace firewall manager constructor
|
||||
func Create(iface IFaceMapper) (*Manager, error) {
|
||||
return create(iface)
|
||||
func Create(iface common.IFaceMapper, disableServerRoutes bool) (*Manager, error) {
|
||||
return create(iface, nil, disableServerRoutes)
|
||||
}
|
||||
|
||||
func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager) (*Manager, error) {
|
||||
mgr, err := create(iface)
|
||||
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
|
||||
if nativeFirewall == nil {
|
||||
return nil, errors.New("native firewall is nil")
|
||||
}
|
||||
|
||||
mgr, err := create(iface, nativeFirewall, disableServerRoutes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mgr.nativeFirewall = nativeFirewall
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
func create(iface IFaceMapper) (*Manager, error) {
|
||||
disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack))
|
||||
func parseCreateEnv() (bool, bool) {
|
||||
var disableConntrack, enableLocalForwarding bool
|
||||
var err error
|
||||
if val := os.Getenv(EnvDisableConntrack); val != "" {
|
||||
disableConntrack, err = strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", EnvDisableConntrack, err)
|
||||
}
|
||||
}
|
||||
if val := os.Getenv(EnvEnableNetstackLocalForwarding); val != "" {
|
||||
enableLocalForwarding, err = strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
|
||||
}
|
||||
}
|
||||
|
||||
return disableConntrack, enableLocalForwarding
|
||||
}
|
||||
|
||||
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
|
||||
disableConntrack, enableLocalForwarding := parseCreateEnv()
|
||||
|
||||
m := &Manager{
|
||||
decoders: sync.Pool{
|
||||
@ -99,52 +163,182 @@ func create(iface IFaceMapper) (*Manager, error) {
|
||||
return d
|
||||
},
|
||||
},
|
||||
outgoingRules: make(map[string]RuleSet),
|
||||
incomingRules: make(map[string]RuleSet),
|
||||
wgIface: iface,
|
||||
stateful: !disableConntrack,
|
||||
nativeFirewall: nativeFirewall,
|
||||
outgoingRules: make(map[string]RuleSet),
|
||||
incomingRules: make(map[string]RuleSet),
|
||||
wgIface: iface,
|
||||
localipmanager: newLocalIPManager(),
|
||||
disableServerRoutes: disableServerRoutes,
|
||||
routingEnabled: false,
|
||||
stateful: !disableConntrack,
|
||||
logger: nblog.NewFromLogrus(log.StandardLogger()),
|
||||
netstack: netstack.IsEnabled(),
|
||||
localForwarding: enableLocalForwarding,
|
||||
}
|
||||
|
||||
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
||||
return nil, fmt.Errorf("update local IPs: %w", err)
|
||||
}
|
||||
|
||||
// Only initialize trackers if stateful mode is enabled
|
||||
if disableConntrack {
|
||||
log.Info("conntrack is disabled")
|
||||
} else {
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
||||
}
|
||||
|
||||
// netstack needs the forwarder for local traffic
|
||||
if m.netstack && m.localForwarding {
|
||||
if err := m.initForwarder(); err != nil {
|
||||
log.Errorf("failed to initialize forwarder: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.blockInvalidRouted(iface); err != nil {
|
||||
log.Errorf("failed to block invalid routed traffic: %v", err)
|
||||
}
|
||||
|
||||
if err := iface.SetFilter(m); err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("set filter: %w", err)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
|
||||
if m.forwarder == nil {
|
||||
return nil
|
||||
}
|
||||
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse wireguard network: %w", err)
|
||||
}
|
||||
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
||||
|
||||
if _, err := m.AddRouteFiltering(
|
||||
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
|
||||
wgPrefix,
|
||||
firewall.ProtocolALL,
|
||||
nil,
|
||||
nil,
|
||||
firewall.ActionDrop,
|
||||
); err != nil {
|
||||
return fmt.Errorf("block wg nte : %w", err)
|
||||
}
|
||||
|
||||
// TODO: Block networks that we're a client of
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) determineRouting() error {
|
||||
var disableUspRouting, forceUserspaceRouter bool
|
||||
var err error
|
||||
if val := os.Getenv(EnvDisableUserspaceRouting); val != "" {
|
||||
disableUspRouting, err = strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", EnvDisableUserspaceRouting, err)
|
||||
}
|
||||
}
|
||||
if val := os.Getenv(EnvForceUserspaceRouter); val != "" {
|
||||
forceUserspaceRouter, err = strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", EnvForceUserspaceRouter, err)
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case disableUspRouting:
|
||||
m.routingEnabled = false
|
||||
m.nativeRouter = false
|
||||
log.Info("userspace routing is disabled")
|
||||
|
||||
case m.disableServerRoutes:
|
||||
// if server routes are disabled we will let packets pass to the native stack
|
||||
m.routingEnabled = true
|
||||
m.nativeRouter = true
|
||||
|
||||
log.Info("server routes are disabled")
|
||||
|
||||
case forceUserspaceRouter:
|
||||
m.routingEnabled = true
|
||||
m.nativeRouter = false
|
||||
|
||||
log.Info("userspace routing is forced")
|
||||
|
||||
case !m.netstack && m.nativeFirewall != nil && m.nativeFirewall.IsServerRouteSupported():
|
||||
// if the OS supports routing natively, then we don't need to filter/route ourselves
|
||||
// netstack mode won't support native routing as there is no interface
|
||||
|
||||
m.routingEnabled = true
|
||||
m.nativeRouter = true
|
||||
|
||||
log.Info("native routing is enabled")
|
||||
|
||||
default:
|
||||
m.routingEnabled = true
|
||||
m.nativeRouter = false
|
||||
|
||||
log.Info("userspace routing enabled by default")
|
||||
}
|
||||
|
||||
if m.routingEnabled && !m.nativeRouter {
|
||||
return m.initForwarder()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// initForwarder initializes the forwarder, it disables routing on errors
|
||||
func (m *Manager) initForwarder() error {
|
||||
if m.forwarder != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Only supported in userspace mode as we need to inject packets back into wireguard directly
|
||||
intf := m.wgIface.GetWGDevice()
|
||||
if intf == nil {
|
||||
m.routingEnabled = false
|
||||
return errors.New("forwarding not supported")
|
||||
}
|
||||
|
||||
forwarder, err := forwarder.New(m.wgIface, m.logger, m.netstack)
|
||||
if err != nil {
|
||||
m.routingEnabled = false
|
||||
return fmt.Errorf("create forwarder: %w", err)
|
||||
}
|
||||
|
||||
m.forwarder = forwarder
|
||||
|
||||
log.Debug("forwarder initialized")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) Init(*statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) IsServerRouteSupported() bool {
|
||||
if m.nativeFirewall == nil {
|
||||
return false
|
||||
} else {
|
||||
return true
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errRouteNotSupported
|
||||
if m.nativeRouter && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.AddNatRule(pair)
|
||||
}
|
||||
return m.nativeFirewall.AddNatRule(pair)
|
||||
|
||||
// userspace routed packets are always SNATed to the inbound direction
|
||||
// TODO: implement outbound SNAT
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveNatRule removes a routing firewall rule
|
||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errRouteNotSupported
|
||||
if m.nativeRouter && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.RemoveNatRule(pair)
|
||||
}
|
||||
return m.nativeFirewall.RemoveNatRule(pair)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddPeerFiltering rule to the firewall
|
||||
@ -156,17 +350,15 @@ func (m *Manager) AddPeerFiltering(
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
direction firewall.RuleDirection,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
_ string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
r := Rule{
|
||||
r := PeerRule{
|
||||
id: uuid.New().String(),
|
||||
ip: ip,
|
||||
ipLayer: layers.LayerTypeIPv6,
|
||||
matchByIP: true,
|
||||
direction: direction,
|
||||
drop: action == firewall.ActionDrop,
|
||||
comment: comment,
|
||||
}
|
||||
@ -179,13 +371,8 @@ func (m *Manager) AddPeerFiltering(
|
||||
r.matchByIP = false
|
||||
}
|
||||
|
||||
if sPort != nil && len(sPort.Values) == 1 {
|
||||
r.sPort = uint16(sPort.Values[0])
|
||||
}
|
||||
|
||||
if dPort != nil && len(dPort.Values) == 1 {
|
||||
r.dPort = uint16(dPort.Values[0])
|
||||
}
|
||||
r.sPort = sPort
|
||||
r.dPort = dPort
|
||||
|
||||
switch proto {
|
||||
case firewall.ProtocolTCP:
|
||||
@ -202,33 +389,64 @@ func (m *Manager) AddPeerFiltering(
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
if direction == firewall.RuleDirectionIN {
|
||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||
m.incomingRules[r.ip.String()] = make(RuleSet)
|
||||
}
|
||||
m.incomingRules[r.ip.String()][r.id] = r
|
||||
} else {
|
||||
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
||||
m.outgoingRules[r.ip.String()] = make(RuleSet)
|
||||
}
|
||||
m.outgoingRules[r.ip.String()][r.id] = r
|
||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||
m.incomingRules[r.ip.String()] = make(RuleSet)
|
||||
}
|
||||
m.incomingRules[r.ip.String()][r.id] = r
|
||||
m.mutex.Unlock()
|
||||
return []firewall.Rule{&r}, nil
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil, errRouteNotSupported
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
if m.nativeRouter && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
ruleID := uuid.New().String()
|
||||
rule := RouteRule{
|
||||
id: ruleID,
|
||||
sources: sources,
|
||||
destination: destination,
|
||||
proto: proto,
|
||||
srcPort: sPort,
|
||||
dstPort: dPort,
|
||||
action: action,
|
||||
}
|
||||
|
||||
m.routeRules = append(m.routeRules, rule)
|
||||
m.routeRules.Sort()
|
||||
|
||||
return &rule, nil
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errRouteNotSupported
|
||||
if m.nativeRouter && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||
}
|
||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
ruleID := rule.GetRuleID()
|
||||
idx := slices.IndexFunc(m.routeRules, func(r RouteRule) bool {
|
||||
return r.id == ruleID
|
||||
})
|
||||
if idx < 0 {
|
||||
return fmt.Errorf("route rule not found: %s", ruleID)
|
||||
}
|
||||
|
||||
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
@ -236,24 +454,15 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
r, ok := rule.(*Rule)
|
||||
r, ok := rule.(*PeerRule)
|
||||
if !ok {
|
||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
||||
}
|
||||
|
||||
if r.direction == firewall.RuleDirectionIN {
|
||||
_, ok := m.incomingRules[r.ip.String()][r.id]
|
||||
if !ok {
|
||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||
}
|
||||
delete(m.incomingRules[r.ip.String()], r.id)
|
||||
} else {
|
||||
_, ok := m.outgoingRules[r.ip.String()][r.id]
|
||||
if !ok {
|
||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||
}
|
||||
delete(m.outgoingRules[r.ip.String()], r.id)
|
||||
if _, ok := m.incomingRules[r.ip.String()][r.id]; !ok {
|
||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||
}
|
||||
delete(m.incomingRules[r.ip.String()], r.id)
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -276,10 +485,14 @@ func (m *Manager) DropOutgoing(packetData []byte) bool {
|
||||
|
||||
// DropIncoming filter incoming packets
|
||||
func (m *Manager) DropIncoming(packetData []byte) bool {
|
||||
return m.dropFilter(packetData, m.incomingRules)
|
||||
return m.dropFilter(packetData)
|
||||
}
|
||||
|
||||
// UpdateLocalIPs updates the list of local IPs
|
||||
func (m *Manager) UpdateLocalIPs() error {
|
||||
return m.localipmanager.UpdateLocalIPs(m.wgIface)
|
||||
}
|
||||
|
||||
// processOutgoingHooks processes UDP hooks for outgoing packets and tracks TCP/UDP/ICMP
|
||||
func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
@ -300,18 +513,11 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Always process UDP hooks
|
||||
if d.decoded[1] == layers.LayerTypeUDP {
|
||||
// Track UDP state only if enabled
|
||||
if m.stateful {
|
||||
m.trackUDPOutbound(d, srcIP, dstIP)
|
||||
}
|
||||
return m.checkUDPHooks(d, dstIP, packetData)
|
||||
}
|
||||
|
||||
// Track other protocols only if stateful mode is enabled
|
||||
// Track all protocols if stateful mode is enabled
|
||||
if m.stateful {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeUDP:
|
||||
m.trackUDPOutbound(d, srcIP, dstIP)
|
||||
case layers.LayerTypeTCP:
|
||||
m.trackTCPOutbound(d, srcIP, dstIP)
|
||||
case layers.LayerTypeICMPv4:
|
||||
@ -319,6 +525,11 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// Process UDP hooks even if stateful mode is disabled
|
||||
if d.decoded[1] == layers.LayerTypeUDP {
|
||||
return m.checkUDPHooks(d, dstIP, packetData)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@ -380,7 +591,7 @@ func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) boo
|
||||
for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} {
|
||||
if rules, exists := m.outgoingRules[ipKey]; exists {
|
||||
for _, rule := range rules {
|
||||
if rule.udpHook != nil && (rule.dPort == 0 || rule.dPort == uint16(d.udp.DstPort)) {
|
||||
if rule.udpHook != nil && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
||||
return rule.udpHook(packetData)
|
||||
}
|
||||
}
|
||||
@ -400,10 +611,9 @@ func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
||||
}
|
||||
}
|
||||
|
||||
// dropFilter implements filtering logic for incoming packets
|
||||
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
|
||||
// TODO: Disable router if --disable-server-router is set
|
||||
|
||||
// dropFilter implements filtering logic for incoming packets.
|
||||
// If it returns true, the packet should be dropped.
|
||||
func (m *Manager) dropFilter(packetData []byte) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
@ -416,39 +626,129 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
|
||||
|
||||
srcIP, dstIP := m.extractIPs(d)
|
||||
if srcIP == nil {
|
||||
log.Errorf("unknown layer: %v", d.decoded[0])
|
||||
m.logger.Error("Unknown network layer: %v", d.decoded[0])
|
||||
return true
|
||||
}
|
||||
|
||||
if !m.isWireguardTraffic(srcIP, dstIP) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check connection state only if enabled
|
||||
// For all inbound traffic, first check if it matches a tracked connection.
|
||||
// This must happen before any other filtering because the packets are statefully tracked.
|
||||
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) {
|
||||
return false
|
||||
}
|
||||
|
||||
return m.applyRules(srcIP, packetData, rules, d)
|
||||
if m.localipmanager.IsLocalIP(dstIP) {
|
||||
return m.handleLocalTraffic(d, srcIP, dstIP, packetData)
|
||||
}
|
||||
|
||||
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData)
|
||||
}
|
||||
|
||||
// handleLocalTraffic handles local traffic.
|
||||
// If it returns true, the packet should be dropped.
|
||||
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
|
||||
if m.peerACLsBlock(srcIP, packetData, m.incomingRules, d) {
|
||||
m.logger.Trace("Dropping local packet (ACL denied): src=%s dst=%s",
|
||||
srcIP, dstIP)
|
||||
return true
|
||||
}
|
||||
|
||||
// if running in netstack mode we need to pass this to the forwarder
|
||||
if m.netstack {
|
||||
return m.handleNetstackLocalTraffic(packetData)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
|
||||
if !m.localForwarding {
|
||||
// pass to virtual tcp/ip stack to be picked up by listeners
|
||||
return false
|
||||
}
|
||||
|
||||
if m.forwarder == nil {
|
||||
m.logger.Trace("Dropping local packet (forwarder not initialized)")
|
||||
return true
|
||||
}
|
||||
|
||||
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
|
||||
m.logger.Error("Failed to inject local packet: %v", err)
|
||||
}
|
||||
|
||||
// don't process this packet further
|
||||
return true
|
||||
}
|
||||
|
||||
// handleRoutedTraffic handles routed traffic.
|
||||
// If it returns true, the packet should be dropped.
|
||||
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
|
||||
// Drop if routing is disabled
|
||||
if !m.routingEnabled {
|
||||
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
|
||||
srcIP, dstIP)
|
||||
return true
|
||||
}
|
||||
|
||||
// Pass to native stack if native router is enabled or forced
|
||||
if m.nativeRouter {
|
||||
return false
|
||||
}
|
||||
|
||||
proto := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
|
||||
if !m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) {
|
||||
m.logger.Trace("Dropping routed packet (ACL denied): src=%s:%d dst=%s:%d proto=%v",
|
||||
srcIP, srcPort, dstIP, dstPort, proto)
|
||||
return true
|
||||
}
|
||||
|
||||
// Let forwarder handle the packet if it passed route ACLs
|
||||
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
|
||||
m.logger.Error("Failed to inject incoming packet: %v", err)
|
||||
}
|
||||
|
||||
// Forwarded packets shouldn't reach the native stack, hence they won't be visible in a packet capture
|
||||
return true
|
||||
}
|
||||
|
||||
func getProtocolFromPacket(d *decoder) firewall.Protocol {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
return firewall.ProtocolTCP
|
||||
case layers.LayerTypeUDP:
|
||||
return firewall.ProtocolUDP
|
||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||
return firewall.ProtocolICMP
|
||||
default:
|
||||
return firewall.ProtocolALL
|
||||
}
|
||||
}
|
||||
|
||||
func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
return uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort)
|
||||
case layers.LayerTypeUDP:
|
||||
return uint16(d.udp.SrcPort), uint16(d.udp.DstPort)
|
||||
default:
|
||||
return 0, 0
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
log.Tracef("couldn't decode layer, err: %s", err)
|
||||
m.logger.Trace("couldn't decode packet, err: %s", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if len(d.decoded) < 2 {
|
||||
log.Tracef("not enough levels in network packet")
|
||||
m.logger.Trace("packet doesn't have network and transport layers")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) isWireguardTraffic(srcIP, dstIP net.IP) bool {
|
||||
return m.wgNetwork.Contains(srcIP) && m.wgNetwork.Contains(dstIP)
|
||||
}
|
||||
|
||||
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
@ -483,7 +783,22 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) applyRules(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool {
|
||||
// isSpecialICMP returns true if the packet is a special ICMP packet that should be allowed
|
||||
func (m *Manager) isSpecialICMP(d *decoder) bool {
|
||||
if d.decoded[1] != layers.LayerTypeICMPv4 {
|
||||
return false
|
||||
}
|
||||
|
||||
icmpType := d.icmp4.TypeCode.Type()
|
||||
return icmpType == layers.ICMPv4TypeDestinationUnreachable ||
|
||||
icmpType == layers.ICMPv4TypeTimeExceeded
|
||||
}
|
||||
|
||||
func (m *Manager) peerACLsBlock(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool {
|
||||
if m.isSpecialICMP(d) {
|
||||
return false
|
||||
}
|
||||
|
||||
if filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok {
|
||||
return filter
|
||||
}
|
||||
@ -500,7 +815,24 @@ func (m *Manager) applyRules(srcIP net.IP, packetData []byte, rules map[string]R
|
||||
return true
|
||||
}
|
||||
|
||||
func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decoder) (bool, bool) {
|
||||
func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
|
||||
if rulePort == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if rulePort.IsRange {
|
||||
return packetPort >= rulePort.Values[0] && packetPort <= rulePort.Values[1]
|
||||
}
|
||||
|
||||
for _, p := range rulePort.Values {
|
||||
if p == packetPort {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *decoder) (bool, bool) {
|
||||
payloadLayer := d.decoded[1]
|
||||
for _, rule := range rules {
|
||||
if rule.matchByIP && !ip.Equal(rule.ip) {
|
||||
@ -517,13 +849,7 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
|
||||
|
||||
switch payloadLayer {
|
||||
case layers.LayerTypeTCP:
|
||||
if rule.sPort == 0 && rule.dPort == 0 {
|
||||
return rule.drop, true
|
||||
}
|
||||
if rule.sPort != 0 && rule.sPort == uint16(d.tcp.SrcPort) {
|
||||
return rule.drop, true
|
||||
}
|
||||
if rule.dPort != 0 && rule.dPort == uint16(d.tcp.DstPort) {
|
||||
if portsMatch(rule.sPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dPort, uint16(d.tcp.DstPort)) {
|
||||
return rule.drop, true
|
||||
}
|
||||
case layers.LayerTypeUDP:
|
||||
@ -533,13 +859,7 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
|
||||
return rule.udpHook(packetData), true
|
||||
}
|
||||
|
||||
if rule.sPort == 0 && rule.dPort == 0 {
|
||||
return rule.drop, true
|
||||
}
|
||||
if rule.sPort != 0 && rule.sPort == uint16(d.udp.SrcPort) {
|
||||
return rule.drop, true
|
||||
}
|
||||
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
|
||||
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
||||
return rule.drop, true
|
||||
}
|
||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||
@ -549,6 +869,51 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
|
||||
return false, false
|
||||
}
|
||||
|
||||
// routeACLsPass returns treu if the packet is allowed by the route ACLs
|
||||
func (m *Manager) routeACLsPass(srcIP, dstIP net.IP, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
srcAddr := netip.AddrFrom4([4]byte(srcIP.To4()))
|
||||
dstAddr := netip.AddrFrom4([4]byte(dstIP.To4()))
|
||||
|
||||
for _, rule := range m.routeRules {
|
||||
if m.ruleMatches(rule, srcAddr, dstAddr, proto, srcPort, dstPort) {
|
||||
return rule.action == firewall.ActionAccept
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
||||
if !rule.destination.Contains(dstAddr) {
|
||||
return false
|
||||
}
|
||||
|
||||
sourceMatched := false
|
||||
for _, src := range rule.sources {
|
||||
if src.Contains(srcAddr) {
|
||||
sourceMatched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !sourceMatched {
|
||||
return false
|
||||
}
|
||||
|
||||
if rule.proto != firewall.ProtocolALL && rule.proto != proto {
|
||||
return false
|
||||
}
|
||||
|
||||
if proto == firewall.ProtocolTCP || proto == firewall.ProtocolUDP {
|
||||
if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// SetNetwork of the wireguard interface to which filtering applied
|
||||
func (m *Manager) SetNetwork(network *net.IPNet) {
|
||||
m.wgNetwork = network
|
||||
@ -560,13 +925,12 @@ func (m *Manager) SetNetwork(network *net.IPNet) {
|
||||
func (m *Manager) AddUDPPacketHook(
|
||||
in bool, ip net.IP, dPort uint16, hook func([]byte) bool,
|
||||
) string {
|
||||
r := Rule{
|
||||
r := PeerRule{
|
||||
id: uuid.New().String(),
|
||||
ip: ip,
|
||||
protoLayer: layers.LayerTypeUDP,
|
||||
dPort: dPort,
|
||||
dPort: &firewall.Port{Values: []uint16{dPort}},
|
||||
ipLayer: layers.LayerTypeIPv6,
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort),
|
||||
udpHook: hook,
|
||||
}
|
||||
@ -577,14 +941,13 @@ func (m *Manager) AddUDPPacketHook(
|
||||
|
||||
m.mutex.Lock()
|
||||
if in {
|
||||
r.direction = firewall.RuleDirectionIN
|
||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||
m.incomingRules[r.ip.String()] = make(map[string]Rule)
|
||||
m.incomingRules[r.ip.String()] = make(map[string]PeerRule)
|
||||
}
|
||||
m.incomingRules[r.ip.String()][r.id] = r
|
||||
} else {
|
||||
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
||||
m.outgoingRules[r.ip.String()] = make(map[string]Rule)
|
||||
m.outgoingRules[r.ip.String()] = make(map[string]PeerRule)
|
||||
}
|
||||
m.outgoingRules[r.ip.String()][r.id] = r
|
||||
}
|
||||
@ -596,21 +959,62 @@ func (m *Manager) AddUDPPacketHook(
|
||||
|
||||
// RemovePacketHook removes packet hook by given ID
|
||||
func (m *Manager) RemovePacketHook(hookID string) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
for _, arr := range m.incomingRules {
|
||||
for _, r := range arr {
|
||||
if r.id == hookID {
|
||||
rule := r
|
||||
return m.DeletePeerRule(&rule)
|
||||
delete(arr, r.id)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, arr := range m.outgoingRules {
|
||||
for _, r := range arr {
|
||||
if r.id == hookID {
|
||||
rule := r
|
||||
return m.DeletePeerRule(&rule)
|
||||
delete(arr, r.id)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("hook with given id not found")
|
||||
}
|
||||
|
||||
// SetLogLevel sets the log level for the firewall manager
|
||||
func (m *Manager) SetLogLevel(level log.Level) {
|
||||
if m.logger != nil {
|
||||
m.logger.SetLevel(nblog.Level(level))
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) EnableRouting() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.determineRouting()
|
||||
}
|
||||
|
||||
func (m *Manager) DisableRouting() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.forwarder == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.routingEnabled = false
|
||||
m.nativeRouter = false
|
||||
|
||||
// don't stop forwarder if in use by netstack
|
||||
if m.netstack && m.localForwarding {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.forwarder.Stop()
|
||||
m.forwarder = nil
|
||||
|
||||
log.Debug("forwarder stopped")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -1,9 +1,12 @@
|
||||
//go:build uspbench
|
||||
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
@ -91,7 +94,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
setupFunc: func(m *Manager) {
|
||||
// Single rule allowing all traffic
|
||||
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil,
|
||||
fw.RuleDirectionIN, fw.ActionAccept, "", "allow all")
|
||||
fw.ActionAccept, "", "allow all")
|
||||
require.NoError(b, err)
|
||||
},
|
||||
desc: "Baseline: Single 'allow all' rule without connection tracking",
|
||||
@ -112,9 +115,9 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
|
||||
ip := generateRandomIPs(1)[0]
|
||||
_, err := m.AddPeerFiltering(ip, fw.ProtocolTCP,
|
||||
&fw.Port{Values: []int{1024 + i}},
|
||||
&fw.Port{Values: []int{80}},
|
||||
fw.RuleDirectionIN, fw.ActionAccept, "", "explicit return")
|
||||
&fw.Port{Values: []uint16{uint16(1024 + i)}},
|
||||
&fw.Port{Values: []uint16{80}},
|
||||
fw.ActionAccept, "", "explicit return")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
},
|
||||
@ -126,7 +129,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
setupFunc: func(m *Manager) {
|
||||
// Add some basic rules but rely on state for established connections
|
||||
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil,
|
||||
fw.RuleDirectionIN, fw.ActionDrop, "", "default drop")
|
||||
fw.ActionDrop, "", "default drop")
|
||||
require.NoError(b, err)
|
||||
},
|
||||
desc: "Connection tracking with established connections",
|
||||
@ -155,7 +158,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
// Create manager and basic setup
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
}, false)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
@ -185,7 +188,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
// Measure inbound packet processing
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(inbound, manager.incomingRules)
|
||||
manager.dropFilter(inbound)
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -200,7 +203,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
||||
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
}, false)
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
@ -228,7 +231,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(testIn, manager.incomingRules)
|
||||
manager.dropFilter(testIn)
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -248,7 +251,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
}, false)
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
@ -269,7 +272,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(inbound, manager.incomingRules)
|
||||
manager.dropFilter(inbound)
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -447,7 +450,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
}, false)
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
@ -472,7 +475,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
manager.processOutgoingHooks(syn)
|
||||
// SYN-ACK
|
||||
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
manager.dropFilter(synack, manager.incomingRules)
|
||||
manager.dropFilter(synack)
|
||||
// ACK
|
||||
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
||||
manager.processOutgoingHooks(ack)
|
||||
@ -481,7 +484,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(inbound, manager.incomingRules)
|
||||
manager.dropFilter(inbound)
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -574,7 +577,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
}, false)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
@ -588,9 +591,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
if sc.rules {
|
||||
// Single rule to allow all return traffic from port 80
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||
&fw.Port{Values: []int{80}},
|
||||
&fw.Port{Values: []uint16{80}},
|
||||
nil,
|
||||
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic")
|
||||
fw.ActionAccept, "", "return traffic")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@ -618,7 +621,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
// SYN-ACK
|
||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
manager.dropFilter(synack, manager.incomingRules)
|
||||
manager.dropFilter(synack)
|
||||
|
||||
// ACK
|
||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
@ -646,7 +649,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
// First outbound data
|
||||
manager.processOutgoingHooks(outPackets[connIdx])
|
||||
// Then inbound response - this is what we're actually measuring
|
||||
manager.dropFilter(inPackets[connIdx], manager.incomingRules)
|
||||
manager.dropFilter(inPackets[connIdx])
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -665,7 +668,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
}, false)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
@ -679,9 +682,9 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
if sc.rules {
|
||||
// Single rule to allow all return traffic from port 80
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||
&fw.Port{Values: []int{80}},
|
||||
&fw.Port{Values: []uint16{80}},
|
||||
nil,
|
||||
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic")
|
||||
fw.ActionAccept, "", "return traffic")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@ -754,17 +757,17 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
|
||||
// Connection establishment
|
||||
manager.processOutgoingHooks(p.syn)
|
||||
manager.dropFilter(p.synAck, manager.incomingRules)
|
||||
manager.dropFilter(p.synAck)
|
||||
manager.processOutgoingHooks(p.ack)
|
||||
|
||||
// Data transfer
|
||||
manager.processOutgoingHooks(p.request)
|
||||
manager.dropFilter(p.response, manager.incomingRules)
|
||||
manager.dropFilter(p.response)
|
||||
|
||||
// Connection teardown
|
||||
manager.processOutgoingHooks(p.finClient)
|
||||
manager.dropFilter(p.ackServer, manager.incomingRules)
|
||||
manager.dropFilter(p.finServer, manager.incomingRules)
|
||||
manager.dropFilter(p.ackServer)
|
||||
manager.dropFilter(p.finServer)
|
||||
manager.processOutgoingHooks(p.ackClient)
|
||||
}
|
||||
})
|
||||
@ -784,7 +787,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
}, false)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
@ -797,9 +800,9 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||
&fw.Port{Values: []int{80}},
|
||||
&fw.Port{Values: []uint16{80}},
|
||||
nil,
|
||||
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic")
|
||||
fw.ActionAccept, "", "return traffic")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@ -825,7 +828,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
|
||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
manager.dropFilter(synack, manager.incomingRules)
|
||||
manager.dropFilter(synack)
|
||||
|
||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||
@ -852,7 +855,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
|
||||
// Simulate bidirectional traffic
|
||||
manager.processOutgoingHooks(outPackets[connIdx])
|
||||
manager.dropFilter(inPackets[connIdx], manager.incomingRules)
|
||||
manager.dropFilter(inPackets[connIdx])
|
||||
}
|
||||
})
|
||||
})
|
||||
@ -872,7 +875,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
}, false)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
@ -884,9 +887,9 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
|
||||
if sc.rules {
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||
&fw.Port{Values: []int{80}},
|
||||
&fw.Port{Values: []uint16{80}},
|
||||
nil,
|
||||
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic")
|
||||
fw.ActionAccept, "", "return traffic")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@ -949,15 +952,15 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
|
||||
// Full connection lifecycle
|
||||
manager.processOutgoingHooks(p.syn)
|
||||
manager.dropFilter(p.synAck, manager.incomingRules)
|
||||
manager.dropFilter(p.synAck)
|
||||
manager.processOutgoingHooks(p.ack)
|
||||
|
||||
manager.processOutgoingHooks(p.request)
|
||||
manager.dropFilter(p.response, manager.incomingRules)
|
||||
manager.dropFilter(p.response)
|
||||
|
||||
manager.processOutgoingHooks(p.finClient)
|
||||
manager.dropFilter(p.ackServer, manager.incomingRules)
|
||||
manager.dropFilter(p.finServer, manager.incomingRules)
|
||||
manager.dropFilter(p.ackServer)
|
||||
manager.dropFilter(p.finServer)
|
||||
manager.processOutgoingHooks(p.ackClient)
|
||||
}
|
||||
})
|
||||
@ -996,3 +999,72 @@ func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstP
|
||||
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test")))
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func BenchmarkRouteACLs(b *testing.B) {
|
||||
manager := setupRoutedManager(b, "10.10.0.100/16")
|
||||
|
||||
// Add several route rules to simulate real-world scenario
|
||||
rules := []struct {
|
||||
sources []netip.Prefix
|
||||
dest netip.Prefix
|
||||
proto fw.Protocol
|
||||
port *fw.Port
|
||||
}{
|
||||
{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
proto: fw.ProtocolTCP,
|
||||
port: &fw.Port{Values: []uint16{80, 443}},
|
||||
},
|
||||
{
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("172.16.0.0/12"),
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
},
|
||||
dest: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
proto: fw.ProtocolICMP,
|
||||
},
|
||||
{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
dest: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
proto: fw.ProtocolUDP,
|
||||
port: &fw.Port{Values: []uint16{53}},
|
||||
},
|
||||
}
|
||||
|
||||
for _, r := range rules {
|
||||
_, err := manager.AddRouteFiltering(
|
||||
r.sources,
|
||||
r.dest,
|
||||
r.proto,
|
||||
nil,
|
||||
r.port,
|
||||
fw.ActionAccept,
|
||||
)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Test cases that exercise different matching scenarios
|
||||
cases := []struct {
|
||||
srcIP string
|
||||
dstIP string
|
||||
proto fw.Protocol
|
||||
dstPort uint16
|
||||
}{
|
||||
{"100.10.0.1", "192.168.1.100", fw.ProtocolTCP, 443}, // Match first rule
|
||||
{"172.16.0.1", "8.8.8.8", fw.ProtocolICMP, 0}, // Match second rule
|
||||
{"1.1.1.1", "192.168.1.53", fw.ProtocolUDP, 53}, // Match third rule
|
||||
{"192.168.1.1", "10.0.0.1", fw.ProtocolTCP, 8080}, // No match
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, tc := range cases {
|
||||
srcIP := net.ParseIP(tc.srcIP)
|
||||
dstIP := net.ParseIP(tc.dstIP)
|
||||
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
1015
client/firewall/uspfilter/uspfilter_filter_test.go
Normal file
1015
client/firewall/uspfilter/uspfilter_filter_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@ -9,17 +9,38 @@ import (
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/require"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||
|
||||
type IFaceMock struct {
|
||||
SetFilterFunc func(device.PacketFilter) error
|
||||
AddressFunc func() iface.WGAddress
|
||||
SetFilterFunc func(device.PacketFilter) error
|
||||
AddressFunc func() iface.WGAddress
|
||||
GetWGDeviceFunc func() *wgdevice.Device
|
||||
GetDeviceFunc func() *device.FilteredDevice
|
||||
}
|
||||
|
||||
func (i *IFaceMock) GetWGDevice() *wgdevice.Device {
|
||||
if i.GetWGDeviceFunc == nil {
|
||||
return nil
|
||||
}
|
||||
return i.GetWGDeviceFunc()
|
||||
}
|
||||
|
||||
func (i *IFaceMock) GetDevice() *device.FilteredDevice {
|
||||
if i.GetDeviceFunc == nil {
|
||||
return nil
|
||||
}
|
||||
return i.GetDeviceFunc()
|
||||
}
|
||||
|
||||
func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
|
||||
@ -41,7 +62,7 @@ func TestManagerCreate(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
m, err := Create(ifaceMock, false)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@ -61,7 +82,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
m, err := Create(ifaceMock, false)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@ -69,12 +90,11 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []int{80}}
|
||||
direction := fw.RuleDirectionOUT
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
|
||||
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
rule, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@ -96,7 +116,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
m, err := Create(ifaceMock, false)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@ -104,38 +124,16 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []int{80}}
|
||||
direction := fw.RuleDirectionOUT
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
comment := "Test rule 2"
|
||||
|
||||
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ip = net.ParseIP("192.168.1.1")
|
||||
proto = fw.ProtocolTCP
|
||||
port = &fw.Port{Values: []int{80}}
|
||||
direction = fw.RuleDirectionIN
|
||||
action = fw.ActionDrop
|
||||
comment = "Test rule 2"
|
||||
|
||||
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, r := range rule {
|
||||
err = m.DeletePeerRule(r)
|
||||
if err != nil {
|
||||
t.Errorf("failed to delete rule: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for _, r := range rule2 {
|
||||
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok {
|
||||
t.Errorf("rule2 is not in the incomingRules")
|
||||
@ -189,12 +187,12 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||
|
||||
var addedRule Rule
|
||||
var addedRule PeerRule
|
||||
if tt.in {
|
||||
if len(manager.incomingRules[tt.ip.String()]) != 1 {
|
||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
||||
@ -217,18 +215,14 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
|
||||
return
|
||||
}
|
||||
if tt.dPort != addedRule.dPort {
|
||||
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort)
|
||||
if tt.dPort != addedRule.dPort.Values[0] {
|
||||
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort.Values[0])
|
||||
return
|
||||
}
|
||||
if layers.LayerTypeUDP != addedRule.protoLayer {
|
||||
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
|
||||
return
|
||||
}
|
||||
if tt.expDir != addedRule.direction {
|
||||
t.Errorf("expected direction %d, got %d", tt.expDir, addedRule.direction)
|
||||
return
|
||||
}
|
||||
if addedRule.udpHook == nil {
|
||||
t.Errorf("expected udpHook to be set")
|
||||
return
|
||||
@ -242,7 +236,7 @@ func TestManagerReset(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
m, err := Create(ifaceMock, false)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@ -250,12 +244,11 @@ func TestManagerReset(t *testing.T) {
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []int{80}}
|
||||
direction := fw.RuleDirectionOUT
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
|
||||
_, err = m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
_, err = m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@ -275,9 +268,18 @@ func TestManagerReset(t *testing.T) {
|
||||
func TestNotMatchByIP(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("100.10.0.100"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
m, err := Create(ifaceMock, false)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@ -289,11 +291,10 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
|
||||
ip := net.ParseIP("0.0.0.0")
|
||||
proto := fw.ProtocolUDP
|
||||
direction := fw.RuleDirectionOUT
|
||||
action := fw.ActionAccept
|
||||
comment := "Test rule"
|
||||
|
||||
_, err = m.AddPeerFiltering(ip, proto, nil, nil, direction, action, "", comment)
|
||||
_, err = m.AddPeerFiltering(ip, proto, nil, nil, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@ -327,7 +328,7 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if m.dropFilter(buf.Bytes(), m.outgoingRules) {
|
||||
if m.dropFilter(buf.Bytes()) {
|
||||
t.Errorf("expected packet to be accepted")
|
||||
return
|
||||
}
|
||||
@ -346,7 +347,7 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
}
|
||||
|
||||
// creating manager instance
|
||||
manager, err := Create(iface)
|
||||
manager, err := Create(iface, false)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create Manager: %s", err)
|
||||
}
|
||||
@ -392,7 +393,7 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
func TestProcessOutgoingHooks(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
@ -400,7 +401,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
}
|
||||
manager.udpTracker.Close()
|
||||
manager.udpTracker = conntrack.NewUDPTracker(100 * time.Millisecond)
|
||||
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Reset(nil))
|
||||
}()
|
||||
@ -478,7 +479,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
manager, err := Create(ifaceMock)
|
||||
manager, err := Create(ifaceMock, false)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second)
|
||||
|
||||
@ -492,12 +493,8 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.100")
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []int{1000 + i}}
|
||||
if i%2 == 0 {
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
} else {
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
}
|
||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
}
|
||||
@ -509,7 +506,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
@ -518,7 +515,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
}
|
||||
|
||||
manager.udpTracker.Close() // Close the existing tracker
|
||||
manager.udpTracker = conntrack.NewUDPTracker(200 * time.Millisecond)
|
||||
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger)
|
||||
manager.decoders = sync.Pool{
|
||||
New: func() any {
|
||||
d := &decoder{
|
||||
@ -639,7 +636,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
for _, cp := range checkPoints {
|
||||
time.Sleep(cp.sleep)
|
||||
|
||||
drop = manager.dropFilter(inboundBuf.Bytes(), manager.incomingRules)
|
||||
drop = manager.dropFilter(inboundBuf.Bytes())
|
||||
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
||||
|
||||
// If the connection should still be valid, verify it exists
|
||||
@ -710,7 +707,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the invalid packet is dropped
|
||||
drop = manager.dropFilter(testBuf.Bytes(), manager.incomingRules)
|
||||
drop = manager.dropFilter(testBuf.Bytes())
|
||||
require.True(t, drop, tc.description)
|
||||
})
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@ -152,46 +153,7 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
||||
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
|
||||
}
|
||||
|
||||
var localAddrsForUnspecified []net.Addr
|
||||
if addr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
|
||||
params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", params.UDPConn.LocalAddr())
|
||||
} else if ok && addr.IP.IsUnspecified() {
|
||||
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
|
||||
// it will break the applications that are already using unspecified UDP connection
|
||||
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
|
||||
params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
|
||||
var networks []ice.NetworkType
|
||||
switch {
|
||||
|
||||
case addr.IP.To16() != nil:
|
||||
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
|
||||
|
||||
case addr.IP.To4() != nil:
|
||||
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
|
||||
|
||||
default:
|
||||
params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr())
|
||||
}
|
||||
if len(networks) > 0 {
|
||||
if params.Net == nil {
|
||||
var err error
|
||||
if params.Net, err = stdnet.NewNet(); err != nil {
|
||||
params.Logger.Errorf("failed to get create network: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
ips, err := localInterfaces(params.Net, params.InterfaceFilter, nil, networks, true)
|
||||
if err == nil {
|
||||
for _, ip := range ips {
|
||||
localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})
|
||||
}
|
||||
} else {
|
||||
params.Logger.Errorf("failed to get local interfaces for unspecified addr: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &UDPMuxDefault{
|
||||
mux := &UDPMuxDefault{
|
||||
addressMap: map[string][]*udpMuxedConn{},
|
||||
params: params,
|
||||
connsIPv4: make(map[string]*udpMuxedConn),
|
||||
@ -203,8 +165,55 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
||||
return newBufferHolder(receiveMTU + maxAddrSize)
|
||||
},
|
||||
},
|
||||
localAddrsForUnspecified: localAddrsForUnspecified,
|
||||
}
|
||||
|
||||
mux.updateLocalAddresses()
|
||||
return mux
|
||||
}
|
||||
|
||||
func (m *UDPMuxDefault) updateLocalAddresses() {
|
||||
var localAddrsForUnspecified []net.Addr
|
||||
if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
|
||||
m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr())
|
||||
} else if ok && addr.IP.IsUnspecified() {
|
||||
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
|
||||
// it will break the applications that are already using unspecified UDP connection
|
||||
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
|
||||
m.params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
|
||||
var networks []ice.NetworkType
|
||||
switch {
|
||||
|
||||
case addr.IP.To16() != nil:
|
||||
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
|
||||
|
||||
case addr.IP.To4() != nil:
|
||||
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
|
||||
|
||||
default:
|
||||
m.params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", m.params.UDPConn.LocalAddr())
|
||||
}
|
||||
if len(networks) > 0 {
|
||||
if m.params.Net == nil {
|
||||
var err error
|
||||
if m.params.Net, err = stdnet.NewNet(); err != nil {
|
||||
m.params.Logger.Errorf("failed to get create network: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
ips, err := localInterfaces(m.params.Net, m.params.InterfaceFilter, nil, networks, true)
|
||||
if err == nil {
|
||||
for _, ip := range ips {
|
||||
localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})
|
||||
}
|
||||
} else {
|
||||
m.params.Logger.Errorf("failed to get local interfaces for unspecified addr: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
m.localAddrsForUnspecified = localAddrsForUnspecified
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
// LocalAddr returns the listening address of this UDPMuxDefault
|
||||
@ -214,8 +223,12 @@ func (m *UDPMuxDefault) LocalAddr() net.Addr {
|
||||
|
||||
// GetListenAddresses returns the list of addresses that this mux is listening on
|
||||
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
|
||||
m.updateLocalAddresses()
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if len(m.localAddrsForUnspecified) > 0 {
|
||||
return m.localAddrsForUnspecified
|
||||
return slices.Clone(m.localAddrsForUnspecified)
|
||||
}
|
||||
|
||||
return []net.Addr{m.LocalAddr()}
|
||||
@ -225,7 +238,10 @@ func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
|
||||
// creates the connection if an existing one can't be found
|
||||
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
|
||||
// don't check addr for mux using unspecified address
|
||||
if len(m.localAddrsForUnspecified) == 0 && m.params.UDPConn.LocalAddr().String() != addr.String() {
|
||||
m.mu.Lock()
|
||||
lenLocalAddrs := len(m.localAddrsForUnspecified)
|
||||
m.mu.Unlock()
|
||||
if lenLocalAddrs == 0 && m.params.UDPConn.LocalAddr().String() != addr.String() {
|
||||
return nil, fmt.Errorf("invalid address %s", addr.String())
|
||||
}
|
||||
|
||||
|
@ -2,5 +2,5 @@
|
||||
|
||||
package configurer
|
||||
|
||||
// WgInterfaceDefault is a default interface name of Wiretrustee
|
||||
// WgInterfaceDefault is a default interface name of Netbird
|
||||
const WgInterfaceDefault = "wt0"
|
||||
|
@ -2,5 +2,5 @@
|
||||
|
||||
package configurer
|
||||
|
||||
// WgInterfaceDefault is a default interface name of Wiretrustee
|
||||
// WgInterfaceDefault is a default interface name of Netbird
|
||||
const WgInterfaceDefault = "utun100"
|
||||
|
@ -362,7 +362,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
||||
}
|
||||
|
||||
func getFwmark() int {
|
||||
if runtime.GOOS == "linux" && !nbnet.CustomRoutingDisabled() {
|
||||
if nbnet.AdvancedRouting() {
|
||||
return nbnet.NetbirdFwmark
|
||||
}
|
||||
return 0
|
||||
|
@ -3,6 +3,10 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
@ -15,4 +19,6 @@ type WGTunDevice interface {
|
||||
DeviceName() string
|
||||
Close() error
|
||||
FilteredDevice() *device.FilteredDevice
|
||||
Device() *wgdevice.Device
|
||||
GetNet() *netstack.Net
|
||||
}
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
@ -63,7 +64,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
|
||||
t.filteredDevice = newDeviceFilter(tunDevice)
|
||||
|
||||
log.Debugf("attaching to interface %v", name)
|
||||
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
|
||||
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "))
|
||||
// without this property mobile devices can discover remote endpoints if the configured one was wrong.
|
||||
// this helps with support for the older NetBird clients that had a hardcoded direct mode
|
||||
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
||||
@ -130,6 +131,10 @@ func (t *WGTunDevice) FilteredDevice() *FilteredDevice {
|
||||
return t.filteredDevice
|
||||
}
|
||||
|
||||
func (t *WGTunDevice) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
||||
func routesToString(routes []string) string {
|
||||
return strings.Join(routes, ";")
|
||||
}
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
@ -117,6 +118,11 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
|
||||
return t.filteredDevice
|
||||
}
|
||||
|
||||
// Device returns the wireguard device
|
||||
func (t *TunDevice) Device() *device.Device {
|
||||
return t.device
|
||||
}
|
||||
|
||||
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
|
||||
func (t *TunDevice) assignAddr() error {
|
||||
cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String())
|
||||
@ -138,3 +144,7 @@ func (t *TunDevice) assignAddr() error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TunDevice) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
@ -64,7 +65,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
||||
|
||||
t.filteredDevice = newDeviceFilter(tunDevice)
|
||||
log.Debug("Attaching to interface")
|
||||
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
|
||||
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "))
|
||||
// without this property mobile devices can discover remote endpoints if the configured one was wrong.
|
||||
// this helps with support for the older NetBird clients that had a hardcoded direct mode
|
||||
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
||||
@ -131,3 +132,7 @@ func (t *TunDevice) UpdateAddr(addr WGAddress) error {
|
||||
func (t *TunDevice) FilteredDevice() *FilteredDevice {
|
||||
return t.filteredDevice
|
||||
}
|
||||
|
||||
func (t *TunDevice) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
@ -9,6 +9,8 @@ import (
|
||||
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
@ -33,8 +35,6 @@ type TunKernelDevice struct {
|
||||
}
|
||||
|
||||
func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
|
||||
checkUser()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &TunKernelDevice{
|
||||
ctx: ctx,
|
||||
@ -153,6 +153,11 @@ func (t *TunKernelDevice) DeviceName() string {
|
||||
return t.name
|
||||
}
|
||||
|
||||
// Device returns the wireguard device, not applicable for kernel devices
|
||||
func (t *TunKernelDevice) Device() *device.Device {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
|
||||
return nil
|
||||
}
|
||||
@ -161,3 +166,7 @@ func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
|
||||
func (t *TunKernelDevice) assignAddr() error {
|
||||
return t.link.assignAddr(t.address)
|
||||
}
|
||||
|
||||
func (t *TunKernelDevice) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
@ -8,10 +8,12 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type TunNetstackDevice struct {
|
||||
@ -25,9 +27,11 @@ type TunNetstackDevice struct {
|
||||
|
||||
device *device.Device
|
||||
filteredDevice *FilteredDevice
|
||||
nsTun *netstack.NetStackTun
|
||||
nsTun *nbnetstack.NetStackTun
|
||||
udpMux *bind.UniversalUDPMuxDefault
|
||||
configurer WGConfigurer
|
||||
|
||||
net *netstack.Net
|
||||
}
|
||||
|
||||
func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
|
||||
@ -43,13 +47,19 @@ func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, m
|
||||
}
|
||||
|
||||
func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
|
||||
log.Info("create netstack tun interface")
|
||||
t.nsTun = netstack.NewNetStackTun(t.listenAddress, t.address.IP.String(), t.mtu)
|
||||
tunIface, err := t.nsTun.Create()
|
||||
log.Info("create nbnetstack tun interface")
|
||||
|
||||
// TODO: get from service listener runtime IP
|
||||
dnsAddr := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
|
||||
log.Debugf("netstack using address: %s", t.address.IP)
|
||||
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu)
|
||||
log.Debugf("netstack using dns address: %s", dnsAddr)
|
||||
tunIface, net, err := t.nsTun.Create()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating tun device: %s", err)
|
||||
}
|
||||
t.filteredDevice = newDeviceFilter(tunIface)
|
||||
t.net = net
|
||||
|
||||
t.device = device.NewDevice(
|
||||
t.filteredDevice,
|
||||
@ -117,3 +127,12 @@ func (t *TunNetstackDevice) DeviceName() string {
|
||||
func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice {
|
||||
return t.filteredDevice
|
||||
}
|
||||
|
||||
// Device returns the wireguard device
|
||||
func (t *TunNetstackDevice) Device() *device.Device {
|
||||
return t.device
|
||||
}
|
||||
|
||||
func (t *TunNetstackDevice) GetNet() *netstack.Net {
|
||||
return t.net
|
||||
}
|
||||
|
@ -4,12 +4,11 @@ package device
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
@ -32,8 +31,6 @@ type USPDevice struct {
|
||||
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
|
||||
log.Infof("using userspace bind mode")
|
||||
|
||||
checkUser()
|
||||
|
||||
return &USPDevice{
|
||||
name: name,
|
||||
address: address,
|
||||
@ -128,6 +125,11 @@ func (t *USPDevice) FilteredDevice() *FilteredDevice {
|
||||
return t.filteredDevice
|
||||
}
|
||||
|
||||
// Device returns the wireguard device
|
||||
func (t *USPDevice) Device() *device.Device {
|
||||
return t.device
|
||||
}
|
||||
|
||||
// assignAddr Adds IP address to the tunnel interface
|
||||
func (t *USPDevice) assignAddr() error {
|
||||
link := newWGLink(t.name)
|
||||
@ -135,11 +137,6 @@ func (t *USPDevice) assignAddr() error {
|
||||
return link.assignAddr(t.address)
|
||||
}
|
||||
|
||||
func checkUser() {
|
||||
if runtime.GOOS == "freebsd" {
|
||||
euid := os.Geteuid()
|
||||
if euid != 0 {
|
||||
log.Warn("newTunUSPDevice: on netbird must run as root to be able to assign address to the tun interface with ifconfig")
|
||||
}
|
||||
}
|
||||
func (t *USPDevice) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
@ -150,6 +151,11 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
|
||||
return t.filteredDevice
|
||||
}
|
||||
|
||||
// Device returns the wireguard device
|
||||
func (t *TunDevice) Device() *device.Device {
|
||||
return t.device
|
||||
}
|
||||
|
||||
func (t *TunDevice) GetInterfaceGUIDString() (string, error) {
|
||||
if t.nativeTunDevice == nil {
|
||||
return "", fmt.Errorf("interface has not been initialized yet")
|
||||
@ -169,3 +175,7 @@ func (t *TunDevice) assignAddr() error {
|
||||
log.Debugf("adding address %s to interface: %s", t.address.IP, t.name)
|
||||
return luid.SetIPAddresses([]netip.Prefix{netip.MustParsePrefix(t.address.String())})
|
||||
}
|
||||
|
||||
func (t *TunDevice) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
@ -1,6 +1,10 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
@ -13,4 +17,6 @@ type WGTunDevice interface {
|
||||
DeviceName() string
|
||||
Close() error
|
||||
FilteredDevice() *device.FilteredDevice
|
||||
Device() *wgdevice.Device
|
||||
GetNet() *netstack.Net
|
||||
}
|
||||
|
@ -203,6 +203,11 @@ func (l *Link) setAddr(ip, netmask string) error {
|
||||
return fmt.Errorf("set interface addr: %w", err)
|
||||
}
|
||||
|
||||
cmd = exec.Command("ifconfig", l.name, "inet6", "fe80::/64")
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
log.Debugf("adding address command '%v' failed with output: %s", cmd.String(), out)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -9,8 +9,11 @@ import (
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
"github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
@ -203,6 +206,11 @@ func (w *WGIface) GetDevice() *device.FilteredDevice {
|
||||
return w.tun.FilteredDevice()
|
||||
}
|
||||
|
||||
// GetWGDevice returns the WireGuard device
|
||||
func (w *WGIface) GetWGDevice() *wgdevice.Device {
|
||||
return w.tun.Device()
|
||||
}
|
||||
|
||||
// GetStats returns the last handshake time, rx and tx bytes for the given peer
|
||||
func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
||||
return w.configurer.GetStats(peerKey)
|
||||
@ -234,3 +242,11 @@ func (w *WGIface) waitUntilRemoved() error {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetNet returns the netstack.Net for the netstack device
|
||||
func (w *WGIface) GetNet() *netstack.Net {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
return w.tun.GetNet()
|
||||
}
|
||||
|
@ -1,112 +0,0 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
type MockWGIface struct {
|
||||
CreateFunc func() error
|
||||
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
|
||||
IsUserspaceBindFunc func() bool
|
||||
NameFunc func() string
|
||||
AddressFunc func() device.WGAddress
|
||||
ToInterfaceFunc func() *net.Interface
|
||||
UpFunc func() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddrFunc func(newAddr string) error
|
||||
UpdatePeerFunc func(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
RemovePeerFunc func(peerKey string) error
|
||||
AddAllowedIPFunc func(peerKey string, allowedIP string) error
|
||||
RemoveAllowedIPFunc func(peerKey string, allowedIP string) error
|
||||
CloseFunc func() error
|
||||
SetFilterFunc func(filter device.PacketFilter) error
|
||||
GetFilterFunc func() device.PacketFilter
|
||||
GetDeviceFunc func() *device.FilteredDevice
|
||||
GetStatsFunc func(peerKey string) (configurer.WGStats, error)
|
||||
GetInterfaceGUIDStringFunc func() (string, error)
|
||||
GetProxyFunc func() wgproxy.Proxy
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetInterfaceGUIDString() (string, error) {
|
||||
return m.GetInterfaceGUIDStringFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) Create() error {
|
||||
return m.CreateFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains []string) error {
|
||||
return m.CreateOnAndroidFunc(routeRange, ip, domains)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) IsUserspaceBind() bool {
|
||||
return m.IsUserspaceBindFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) Name() string {
|
||||
return m.NameFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) Address() device.WGAddress {
|
||||
return m.AddressFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) ToInterface() *net.Interface {
|
||||
return m.ToInterfaceFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
return m.UpFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) UpdateAddr(newAddr string) error {
|
||||
return m.UpdateAddrFunc(newAddr)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||
return m.UpdatePeerFunc(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) RemovePeer(peerKey string) error {
|
||||
return m.RemovePeerFunc(peerKey)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) AddAllowedIP(peerKey string, allowedIP string) error {
|
||||
return m.AddAllowedIPFunc(peerKey, allowedIP)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
|
||||
return m.RemoveAllowedIPFunc(peerKey, allowedIP)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) Close() error {
|
||||
return m.CloseFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) SetFilter(filter device.PacketFilter) error {
|
||||
return m.SetFilterFunc(filter)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetFilter() device.PacketFilter {
|
||||
return m.GetFilterFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetDevice() *device.FilteredDevice {
|
||||
return m.GetDeviceFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
||||
return m.GetStatsFunc(peerKey)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetProxy() wgproxy.Proxy {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
@ -1,35 +0,0 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
type IWGIface interface {
|
||||
Create() error
|
||||
CreateOnAndroid(routeRange []string, ip string, domains []string) error
|
||||
IsUserspaceBind() bool
|
||||
Name() string
|
||||
Address() device.WGAddress
|
||||
ToInterface() *net.Interface
|
||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(newAddr string) error
|
||||
GetProxy() wgproxy.Proxy
|
||||
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
RemovePeer(peerKey string) error
|
||||
AddAllowedIP(peerKey string, allowedIP string) error
|
||||
RemoveAllowedIP(peerKey string, allowedIP string) error
|
||||
Close() error
|
||||
SetFilter(filter device.PacketFilter) error
|
||||
GetFilter() device.PacketFilter
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetStats(peerKey string) (configurer.WGStats, error)
|
||||
GetInterfaceGUIDString() (string, error)
|
||||
}
|
@ -8,9 +8,11 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const EnvUseNetstackMode = "NB_USE_NETSTACK_MODE"
|
||||
|
||||
// IsEnabled todo: move these function to cmd layer
|
||||
func IsEnabled() bool {
|
||||
return os.Getenv("NB_USE_NETSTACK_MODE") == "true"
|
||||
return os.Getenv(EnvUseNetstackMode) == "true"
|
||||
}
|
||||
|
||||
func ListenAddr() string {
|
||||
|
@ -1,15 +1,22 @@
|
||||
package netstack
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
)
|
||||
|
||||
const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY"
|
||||
|
||||
type NetStackTun struct { //nolint:revive
|
||||
address string
|
||||
address net.IP
|
||||
dnsAddress net.IP
|
||||
mtu int
|
||||
listenAddress string
|
||||
|
||||
@ -17,29 +24,48 @@ type NetStackTun struct { //nolint:revive
|
||||
tundev tun.Device
|
||||
}
|
||||
|
||||
func NewNetStackTun(listenAddress string, address string, mtu int) *NetStackTun {
|
||||
func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu int) *NetStackTun {
|
||||
return &NetStackTun{
|
||||
address: address,
|
||||
dnsAddress: dnsAddress,
|
||||
mtu: mtu,
|
||||
listenAddress: listenAddress,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *NetStackTun) Create() (tun.Device, error) {
|
||||
func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
|
||||
addr, ok := netip.AddrFromSlice(t.address)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("convert address to netip.Addr: %v", t.address)
|
||||
}
|
||||
|
||||
dnsAddr, ok := netip.AddrFromSlice(t.dnsAddress)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("convert dns address to netip.Addr: %v", t.dnsAddress)
|
||||
}
|
||||
|
||||
nsTunDev, tunNet, err := netstack.CreateNetTUN(
|
||||
[]netip.Addr{netip.MustParseAddr(t.address)},
|
||||
[]netip.Addr{},
|
||||
[]netip.Addr{addr.Unmap()},
|
||||
[]netip.Addr{dnsAddr.Unmap()},
|
||||
t.mtu)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
t.tundev = nsTunDev
|
||||
|
||||
skipProxy, err := strconv.ParseBool(os.Getenv(EnvSkipProxy))
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse NB_ETSTACK_SKIP_PROXY: %s", err)
|
||||
}
|
||||
if skipProxy {
|
||||
return nsTunDev, tunNet, nil
|
||||
}
|
||||
|
||||
dialer := NewNSDialer(tunNet)
|
||||
t.proxy, err = NewSocks5(dialer)
|
||||
if err != nil {
|
||||
_ = t.tundev.Close()
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
go func() {
|
||||
@ -49,7 +75,7 @@ func (t *NetStackTun) Create() (tun.Device, error) {
|
||||
}
|
||||
}()
|
||||
|
||||
return nsTunDev, nil
|
||||
return nsTunDev, tunNet, nil
|
||||
}
|
||||
|
||||
func (t *NetStackTun) Close() error {
|
||||
|
@ -151,7 +151,7 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
d.rollBack(newRulePairs)
|
||||
break
|
||||
}
|
||||
if len(rules) > 0 {
|
||||
if len(rulePair) > 0 {
|
||||
d.peerRulesPairs[pairID] = rulePair
|
||||
newRulePairs[pairID] = rulePair
|
||||
}
|
||||
@ -268,13 +268,16 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
}
|
||||
|
||||
var port *firewall.Port
|
||||
if r.Port != "" {
|
||||
if !portInfoEmpty(r.PortInfo) {
|
||||
port = convertPortInfo(r.PortInfo)
|
||||
} else if r.Port != "" {
|
||||
// old version of management, single port
|
||||
value, err := strconv.Atoi(r.Port)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("invalid port, skipping firewall rule")
|
||||
return "", nil, fmt.Errorf("invalid port: %w", err)
|
||||
}
|
||||
port = &firewall.Port{
|
||||
Values: []int{value},
|
||||
Values: []uint16{uint16(value)},
|
||||
}
|
||||
}
|
||||
|
||||
@ -288,6 +291,8 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
case mgmProto.RuleDirection_IN:
|
||||
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
|
||||
case mgmProto.RuleDirection_OUT:
|
||||
// TODO: Remove this soon. Outbound rules are obsolete.
|
||||
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already
|
||||
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
|
||||
default:
|
||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||
@ -300,6 +305,22 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
return ruleID, rules, nil
|
||||
}
|
||||
|
||||
func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
|
||||
if portInfo == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
switch portInfo.GetPortSelection().(type) {
|
||||
case *mgmProto.PortInfo_Port:
|
||||
return portInfo.GetPort() == 0
|
||||
case *mgmProto.PortInfo_Range_:
|
||||
r := portInfo.GetRange()
|
||||
return r == nil || r.Start == 0 || r.End == 0
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultManager) addInRules(
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
@ -308,25 +329,12 @@ func (d *DefaultManager) addInRules(
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
var rules []firewall.Rule
|
||||
rule, err := d.firewall.AddPeerFiltering(
|
||||
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
|
||||
rule, err := d.firewall.AddPeerFiltering(ip, protocol, nil, port, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
}
|
||||
rules = append(rules, rule...)
|
||||
|
||||
if shouldSkipInvertedRule(protocol, port) {
|
||||
return rules, nil
|
||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||
}
|
||||
|
||||
rule, err = d.firewall.AddPeerFiltering(
|
||||
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
}
|
||||
|
||||
return append(rules, rule...), nil
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (d *DefaultManager) addOutRules(
|
||||
@ -337,25 +345,16 @@ func (d *DefaultManager) addOutRules(
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
var rules []firewall.Rule
|
||||
rule, err := d.firewall.AddPeerFiltering(
|
||||
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
}
|
||||
rules = append(rules, rule...)
|
||||
|
||||
if shouldSkipInvertedRule(protocol, port) {
|
||||
return rules, nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
rule, err = d.firewall.AddPeerFiltering(
|
||||
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
|
||||
rule, err := d.firewall.AddPeerFiltering(ip, protocol, port, nil, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||
}
|
||||
|
||||
return append(rules, rule...), nil
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
// getPeerRuleID() returns unique ID for the rule based on its parameters.
|
||||
@ -508,7 +507,7 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
|
||||
// getRuleGroupingSelector takes all rule properties except IP address to build selector
|
||||
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
|
||||
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
|
||||
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)
|
||||
}
|
||||
|
||||
func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) {
|
||||
@ -559,14 +558,14 @@ func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port {
|
||||
|
||||
if portInfo.GetPort() != 0 {
|
||||
return &firewall.Port{
|
||||
Values: []int{int(portInfo.GetPort())},
|
||||
Values: []uint16{uint16(int(portInfo.GetPort()))},
|
||||
}
|
||||
}
|
||||
|
||||
if portInfo.GetRange() != nil {
|
||||
return &firewall.Port{
|
||||
IsRange: true,
|
||||
Values: []int{int(portInfo.GetRange().Start), int(portInfo.GetRange().End)},
|
||||
Values: []uint16{uint16(portInfo.GetRange().Start), uint16(portInfo.GetRange().End)},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -49,9 +49,10 @@ func TestDefaultManager(t *testing.T) {
|
||||
IP: ip,
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
// we receive one rule from the management so for testing purposes ignore it
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil)
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, false)
|
||||
if err != nil {
|
||||
t.Errorf("create firewall: %v", err)
|
||||
return
|
||||
@ -119,8 +120,8 @@ func TestDefaultManager(t *testing.T) {
|
||||
|
||||
networkMap.FirewallRulesIsEmpty = false
|
||||
acl.ApplyFiltering(networkMap)
|
||||
if len(acl.peerRulesPairs) != 2 {
|
||||
t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
|
||||
if len(acl.peerRulesPairs) != 1 {
|
||||
t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
|
||||
return
|
||||
}
|
||||
})
|
||||
@ -342,9 +343,10 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
IP: ip,
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
// we receive one rule from the management so for testing purposes ignore it
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil)
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, false)
|
||||
if err != nil {
|
||||
t.Errorf("create firewall: %v", err)
|
||||
return
|
||||
@ -356,8 +358,8 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
|
||||
acl.ApplyFiltering(networkMap)
|
||||
|
||||
if len(acl.peerRulesPairs) != 4 {
|
||||
t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
|
||||
if len(acl.peerRulesPairs) != 3 {
|
||||
t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -8,6 +8,8 @@ import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
iface "github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
@ -90,3 +92,31 @@ func (mr *MockIFaceMapperMockRecorder) SetFilter(arg0 interface{}) *gomock.Call
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetFilter", reflect.TypeOf((*MockIFaceMapper)(nil).SetFilter), arg0)
|
||||
}
|
||||
|
||||
// GetDevice mocks base method.
|
||||
func (m *MockIFaceMapper) GetDevice() *device.FilteredDevice {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetDevice")
|
||||
ret0, _ := ret[0].(*device.FilteredDevice)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetDevice indicates an expected call of GetDevice.
|
||||
func (mr *MockIFaceMapperMockRecorder) GetDevice() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDevice", reflect.TypeOf((*MockIFaceMapper)(nil).GetDevice))
|
||||
}
|
||||
|
||||
// GetWGDevice mocks base method.
|
||||
func (m *MockIFaceMapper) GetWGDevice() *wgdevice.Device {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetWGDevice")
|
||||
ret0, _ := ret[0].(*wgdevice.Device)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetWGDevice indicates an expected call of GetWGDevice.
|
||||
func (mr *MockIFaceMapperMockRecorder) GetWGDevice() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWGDevice", reflect.TypeOf((*MockIFaceMapper)(nil).GetWGDevice))
|
||||
}
|
||||
|
@ -2,6 +2,8 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@ -11,7 +13,10 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||
)
|
||||
|
||||
// HostedGrantType grant type for device flow on Hosted
|
||||
@ -56,6 +61,18 @@ func NewDeviceAuthorizationFlow(config internal.DeviceAuthProviderConfig) (*Devi
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
certPool, err := x509.SystemCertPool()
|
||||
if err != nil || certPool == nil {
|
||||
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
|
||||
certPool = embeddedroots.Get()
|
||||
} else {
|
||||
log.Debug("Using system certificate pool.")
|
||||
}
|
||||
|
||||
httpTransport.TLSClientConfig = &tls.Config{
|
||||
RootCAs: certPool,
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: httpTransport,
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"os"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@ -20,6 +21,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@ -66,6 +68,12 @@ type ConfigInput struct {
|
||||
DisableServerRoutes *bool
|
||||
DisableDNS *bool
|
||||
DisableFirewall *bool
|
||||
|
||||
BlockLANAccess *bool
|
||||
|
||||
DisableNotifications *bool
|
||||
|
||||
DNSLabels domain.List
|
||||
}
|
||||
|
||||
// Config Configuration type
|
||||
@ -89,6 +97,12 @@ type Config struct {
|
||||
DisableDNS bool
|
||||
DisableFirewall bool
|
||||
|
||||
BlockLANAccess bool
|
||||
|
||||
DisableNotifications *bool
|
||||
|
||||
DNSLabels domain.List
|
||||
|
||||
// SSHKey is a private SSH key in a PEM format
|
||||
SSHKey string
|
||||
|
||||
@ -455,6 +469,33 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.BlockLANAccess != nil && *input.BlockLANAccess != config.BlockLANAccess {
|
||||
if *input.BlockLANAccess {
|
||||
log.Infof("blocking LAN access")
|
||||
} else {
|
||||
log.Infof("allowing LAN access")
|
||||
}
|
||||
config.BlockLANAccess = *input.BlockLANAccess
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.DisableNotifications != nil && input.DisableNotifications != config.DisableNotifications {
|
||||
if *input.DisableNotifications {
|
||||
log.Infof("disabling notifications")
|
||||
} else {
|
||||
log.Infof("enabling notifications")
|
||||
}
|
||||
config.DisableNotifications = input.DisableNotifications
|
||||
updated = true
|
||||
}
|
||||
|
||||
if config.DisableNotifications == nil {
|
||||
disabled := true
|
||||
config.DisableNotifications = &disabled
|
||||
log.Infof("setting notifications to disabled by default")
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.ClientCertKeyPath != "" {
|
||||
config.ClientCertKeyPath = input.ClientCertKeyPath
|
||||
updated = true
|
||||
@ -475,6 +516,14 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
if input.DNSLabels != nil && !slices.Equal(config.DNSLabels, input.DNSLabels) {
|
||||
log.Infof("updating DNS labels [ %s ] (old value: [ %s ])",
|
||||
input.DNSLabels.SafeString(),
|
||||
config.DNSLabels.SafeString())
|
||||
config.DNSLabels = input.DNSLabels
|
||||
updated = true
|
||||
}
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
|
@ -23,6 +23,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
@ -31,6 +32,7 @@ import (
|
||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||
signal "github.com/netbirdio/netbird/signal/client"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@ -59,13 +61,8 @@ func NewConnectClient(
|
||||
}
|
||||
|
||||
// Run with main logic.
|
||||
func (c *ConnectClient) Run() error {
|
||||
return c.run(MobileDependency{}, nil, nil)
|
||||
}
|
||||
|
||||
// RunWithProbes runs the client's main logic with probes attached
|
||||
func (c *ConnectClient) RunWithProbes(probes *ProbeHolder, runningChan chan error) error {
|
||||
return c.run(MobileDependency{}, probes, runningChan)
|
||||
func (c *ConnectClient) Run(runningChan chan error) error {
|
||||
return c.run(MobileDependency{}, runningChan)
|
||||
}
|
||||
|
||||
// RunOnAndroid with main logic on mobile system
|
||||
@ -84,7 +81,7 @@ func (c *ConnectClient) RunOnAndroid(
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
DnsReadyListener: dnsReadyListener,
|
||||
}
|
||||
return c.run(mobileDependency, nil, nil)
|
||||
return c.run(mobileDependency, nil)
|
||||
}
|
||||
|
||||
func (c *ConnectClient) RunOniOS(
|
||||
@ -102,18 +99,30 @@ func (c *ConnectClient) RunOniOS(
|
||||
DnsManager: dnsManager,
|
||||
StateFilePath: stateFilePath,
|
||||
}
|
||||
return c.run(mobileDependency, nil, nil)
|
||||
return c.run(mobileDependency, nil)
|
||||
}
|
||||
|
||||
func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHolder, runningChan chan error) error {
|
||||
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan error) error {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
rec := c.statusRecorder
|
||||
if rec != nil {
|
||||
rec.PublishEvent(
|
||||
cProto.SystemEvent_CRITICAL, cProto.SystemEvent_SYSTEM,
|
||||
"panic occurred",
|
||||
"The Netbird service panicked. Please restart the service and submit a bug report with the client logs.",
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
|
||||
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
|
||||
|
||||
nbnet.Init()
|
||||
|
||||
backOff := &backoff.ExponentialBackOff{
|
||||
InitialInterval: time.Second,
|
||||
RandomizationFactor: 1,
|
||||
@ -182,8 +191,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
||||
}
|
||||
}()
|
||||
|
||||
// connect (just a connection, no stream yet) and login to Management Service to get an initial global Wiretrustee config
|
||||
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey)
|
||||
// connect (just a connection, no stream yet) and login to Management Service to get an initial global Netbird config
|
||||
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, c.config)
|
||||
if err != nil {
|
||||
log.Debug(err)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||
@ -204,8 +213,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
||||
c.statusRecorder.UpdateLocalPeerState(localPeerState)
|
||||
|
||||
signalURL := fmt.Sprintf("%s://%s",
|
||||
strings.ToLower(loginResp.GetWiretrusteeConfig().GetSignal().GetProtocol().String()),
|
||||
loginResp.GetWiretrusteeConfig().GetSignal().GetUri(),
|
||||
strings.ToLower(loginResp.GetNetbirdConfig().GetSignal().GetProtocol().String()),
|
||||
loginResp.GetNetbirdConfig().GetSignal().GetUri(),
|
||||
)
|
||||
|
||||
c.statusRecorder.UpdateSignalAddress(signalURL)
|
||||
@ -216,8 +225,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
||||
c.statusRecorder.MarkSignalDisconnected(err)
|
||||
}()
|
||||
|
||||
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal
|
||||
signalClient, err := connectToSignal(engineCtx, loginResp.GetWiretrusteeConfig(), myPrivateKey)
|
||||
// with the global Netbird config in hand connect (just a connection, no stream yet) Signal
|
||||
signalClient, err := connectToSignal(engineCtx, loginResp.GetNetbirdConfig(), myPrivateKey)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return wrapErr(err)
|
||||
@ -261,7 +270,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
||||
checks := loginResp.GetChecks()
|
||||
|
||||
c.engineMutex.Lock()
|
||||
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
|
||||
c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
|
||||
c.engine.SetNetworkMapPersistence(c.persistNetworkMap)
|
||||
c.engineMutex.Unlock()
|
||||
|
||||
@ -316,7 +325,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
||||
}
|
||||
|
||||
func parseRelayInfo(loginResp *mgmProto.LoginResponse) ([]string, *hmac.Token) {
|
||||
relayCfg := loginResp.GetWiretrusteeConfig().GetRelay()
|
||||
relayCfg := loginResp.GetNetbirdConfig().GetRelay()
|
||||
if relayCfg == nil {
|
||||
return nil, nil
|
||||
}
|
||||
@ -420,6 +429,8 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
|
||||
DisableServerRoutes: config.DisableServerRoutes,
|
||||
DisableDNS: config.DisableDNS,
|
||||
DisableFirewall: config.DisableFirewall,
|
||||
|
||||
BlockLANAccess: config.BlockLANAccess,
|
||||
}
|
||||
|
||||
if config.PreSharedKey != "" {
|
||||
@ -443,7 +454,7 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
|
||||
}
|
||||
|
||||
// connectToSignal creates Signal Service client and established a connection
|
||||
func connectToSignal(ctx context.Context, wtConfig *mgmProto.WiretrusteeConfig, ourPrivateKey wgtypes.Key) (*signal.GrpcClient, error) {
|
||||
func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourPrivateKey wgtypes.Key) (*signal.GrpcClient, error) {
|
||||
var sigTLSEnabled bool
|
||||
if wtConfig.Signal.Protocol == mgmProto.HostConfig_HTTPS {
|
||||
sigTLSEnabled = true
|
||||
@ -460,8 +471,8 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.WiretrusteeConfig,
|
||||
return signalClient, nil
|
||||
}
|
||||
|
||||
// loginToManagement creates Management Services client, establishes a connection, logs-in and gets a global Wiretrustee config (signal, turn, stun hosts, etc)
|
||||
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
||||
// loginToManagement creates Management Services client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
|
||||
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *Config) (*mgmProto.LoginResponse, error) {
|
||||
|
||||
serverPublicKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
@ -469,7 +480,16 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte)
|
||||
}
|
||||
|
||||
sysInfo := system.GetInfo(ctx)
|
||||
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey)
|
||||
sysInfo.SetFlags(
|
||||
config.RosenpassEnabled,
|
||||
config.RosenpassPermissive,
|
||||
config.ServerSSHAllowed,
|
||||
config.DisableClientRoutes,
|
||||
config.DisableServerRoutes,
|
||||
config.DisableDNS,
|
||||
config.DisableFirewall,
|
||||
)
|
||||
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
111
client/internal/dns.go
Normal file
111
client/internal/dns.go
Normal file
@ -0,0 +1,111 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.SimpleRecord, bool) {
|
||||
ip := net.ParseIP(aRecord.RData)
|
||||
if ip == nil || ip.To4() == nil {
|
||||
return nbdns.SimpleRecord{}, false
|
||||
}
|
||||
|
||||
if !ipNet.Contains(ip) {
|
||||
return nbdns.SimpleRecord{}, false
|
||||
}
|
||||
|
||||
ipOctets := strings.Split(ip.String(), ".")
|
||||
slices.Reverse(ipOctets)
|
||||
rdnsName := dns.Fqdn(strings.Join(ipOctets, ".") + ".in-addr.arpa")
|
||||
|
||||
return nbdns.SimpleRecord{
|
||||
Name: rdnsName,
|
||||
Type: int(dns.TypePTR),
|
||||
Class: aRecord.Class,
|
||||
TTL: aRecord.TTL,
|
||||
RData: dns.Fqdn(aRecord.Name),
|
||||
}, true
|
||||
}
|
||||
|
||||
// generateReverseZoneName creates the reverse DNS zone name for a given network
|
||||
func generateReverseZoneName(ipNet *net.IPNet) (string, error) {
|
||||
networkIP := ipNet.IP.Mask(ipNet.Mask)
|
||||
maskOnes, _ := ipNet.Mask.Size()
|
||||
|
||||
// round up to nearest byte
|
||||
octetsToUse := (maskOnes + 7) / 8
|
||||
|
||||
octets := strings.Split(networkIP.String(), ".")
|
||||
if octetsToUse > len(octets) {
|
||||
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", maskOnes)
|
||||
}
|
||||
|
||||
reverseOctets := make([]string, octetsToUse)
|
||||
for i := 0; i < octetsToUse; i++ {
|
||||
reverseOctets[octetsToUse-1-i] = octets[i]
|
||||
}
|
||||
|
||||
return dns.Fqdn(strings.Join(reverseOctets, ".") + ".in-addr.arpa"), nil
|
||||
}
|
||||
|
||||
// zoneExists checks if a zone with the given name already exists in the configuration
|
||||
func zoneExists(config *nbdns.Config, zoneName string) bool {
|
||||
for _, zone := range config.CustomZones {
|
||||
if zone.Domain == zoneName {
|
||||
log.Debugf("reverse DNS zone %s already exists", zoneName)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// collectPTRRecords gathers all PTR records for the given network from A records
|
||||
func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRecord {
|
||||
var records []nbdns.SimpleRecord
|
||||
|
||||
for _, zone := range config.CustomZones {
|
||||
for _, record := range zone.Records {
|
||||
if record.Type != int(dns.TypeA) {
|
||||
continue
|
||||
}
|
||||
|
||||
if ptrRecord, ok := createPTRRecord(record, ipNet); ok {
|
||||
records = append(records, ptrRecord)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return records
|
||||
}
|
||||
|
||||
// addReverseZone adds a reverse DNS zone to the configuration for the given network
|
||||
func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) {
|
||||
zoneName, err := generateReverseZoneName(ipNet)
|
||||
if err != nil {
|
||||
log.Warn(err)
|
||||
return
|
||||
}
|
||||
|
||||
if zoneExists(config, zoneName) {
|
||||
log.Debugf("reverse DNS zone %s already exists", zoneName)
|
||||
return
|
||||
}
|
||||
|
||||
records := collectPTRRecords(config, ipNet)
|
||||
|
||||
reverseZone := nbdns.CustomZone{
|
||||
Domain: zoneName,
|
||||
Records: records,
|
||||
}
|
||||
|
||||
config.CustomZones = append(config.CustomZones, reverseZone)
|
||||
log.Debugf("added reverse DNS zone: %s with %d records", zoneName, len(records))
|
||||
}
|
@ -58,7 +58,7 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *st
|
||||
return fmt.Errorf("restoring the original resolv.conf file return err: %w", err)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
|
||||
return ErrRouteAllWithoutNameserverGroup
|
||||
}
|
||||
|
||||
if !backupFileExist {
|
||||
@ -121,6 +121,10 @@ func (f *fileConfigurator) restoreHostDNS() error {
|
||||
return f.restore()
|
||||
}
|
||||
|
||||
func (f *fileConfigurator) string() string {
|
||||
return "file"
|
||||
}
|
||||
|
||||
func (f *fileConfigurator) backup() error {
|
||||
stats, err := os.Stat(defaultResolvConfPath)
|
||||
if err != nil {
|
||||
|
@ -12,7 +12,7 @@ import (
|
||||
const (
|
||||
PriorityDNSRoute = 100
|
||||
PriorityMatchDomain = 50
|
||||
PriorityDefault = 0
|
||||
PriorityDefault = 1
|
||||
)
|
||||
|
||||
type SubdomainMatcher interface {
|
||||
@ -26,7 +26,6 @@ type HandlerEntry struct {
|
||||
Pattern string
|
||||
OrigPattern string
|
||||
IsWildcard bool
|
||||
StopHandler handlerWithStop
|
||||
MatchSubdomains bool
|
||||
}
|
||||
|
||||
@ -64,7 +63,7 @@ func (w *ResponseWriterChain) GetOrigPattern() string {
|
||||
}
|
||||
|
||||
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
|
||||
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) {
|
||||
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
@ -78,9 +77,6 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
||||
// First remove any existing handler with same pattern (case-insensitive) and priority
|
||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
|
||||
if c.handlers[i].StopHandler != nil {
|
||||
c.handlers[i].StopHandler.stop()
|
||||
}
|
||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||
break
|
||||
}
|
||||
@ -101,21 +97,33 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
||||
Pattern: pattern,
|
||||
OrigPattern: origPattern,
|
||||
IsWildcard: isWildcard,
|
||||
StopHandler: stopHandler,
|
||||
MatchSubdomains: matchSubdomains,
|
||||
}
|
||||
|
||||
// Insert handler in priority order
|
||||
pos := 0
|
||||
pos := c.findHandlerPosition(entry)
|
||||
c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
|
||||
}
|
||||
|
||||
// findHandlerPosition determines where to insert a new handler based on priority and specificity
|
||||
func (c *HandlerChain) findHandlerPosition(newEntry HandlerEntry) int {
|
||||
for i, h := range c.handlers {
|
||||
if h.Priority < priority {
|
||||
pos = i
|
||||
break
|
||||
// prio first
|
||||
if h.Priority < newEntry.Priority {
|
||||
return i
|
||||
}
|
||||
|
||||
// domain specificity next
|
||||
if h.Priority == newEntry.Priority {
|
||||
newDots := strings.Count(newEntry.Pattern, ".")
|
||||
existingDots := strings.Count(h.Pattern, ".")
|
||||
if newDots > existingDots {
|
||||
return i
|
||||
}
|
||||
}
|
||||
pos = i + 1
|
||||
}
|
||||
|
||||
c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
|
||||
// add at end
|
||||
return len(c.handlers)
|
||||
}
|
||||
|
||||
// RemoveHandler removes a handler for the given pattern and priority
|
||||
@ -129,9 +137,6 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
|
||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||
entry := c.handlers[i]
|
||||
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
||||
if entry.StopHandler != nil {
|
||||
entry.StopHandler.stop()
|
||||
}
|
||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||
return
|
||||
}
|
||||
@ -167,8 +172,8 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if log.IsLevelEnabled(log.TraceLevel) {
|
||||
log.Tracef("current handlers (%d):", len(handlers))
|
||||
for _, h := range handlers {
|
||||
log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v priority=%d",
|
||||
h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority)
|
||||
log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d",
|
||||
h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority)
|
||||
}
|
||||
}
|
||||
|
||||
@ -193,13 +198,13 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
}
|
||||
|
||||
if !matched {
|
||||
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v matched=false",
|
||||
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard)
|
||||
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d matched=false",
|
||||
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard, entry.Priority)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v",
|
||||
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains)
|
||||
log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d",
|
||||
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
|
||||
|
||||
chainWriter := &ResponseWriterChain{
|
||||
ResponseWriter: w,
|
||||
|
@ -21,9 +21,9 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
|
||||
dnsRouteHandler := &nbdns.MockHandler{}
|
||||
|
||||
// Setup handlers with different priorities
|
||||
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil)
|
||||
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain, nil)
|
||||
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute, nil)
|
||||
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault)
|
||||
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain)
|
||||
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute)
|
||||
|
||||
// Create test request
|
||||
r := new(dns.Msg)
|
||||
@ -138,7 +138,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
||||
pattern = "*." + tt.handlerDomain[2:]
|
||||
}
|
||||
|
||||
chain.AddHandler(pattern, handler, nbdns.PriorityDefault, nil)
|
||||
chain.AddHandler(pattern, handler, nbdns.PriorityDefault)
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
||||
@ -253,7 +253,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
||||
handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe()
|
||||
}
|
||||
|
||||
chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority, nil)
|
||||
chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority)
|
||||
}
|
||||
|
||||
// Create and execute request
|
||||
@ -280,9 +280,9 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
|
||||
handler3 := &nbdns.MockHandler{}
|
||||
|
||||
// Add handlers in priority order
|
||||
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil)
|
||||
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain, nil)
|
||||
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault, nil)
|
||||
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute)
|
||||
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain)
|
||||
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault)
|
||||
|
||||
// Create test request
|
||||
r := new(dns.Msg)
|
||||
@ -416,7 +416,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
||||
if op.action == "add" {
|
||||
handler := &nbdns.MockHandler{}
|
||||
handlers[op.priority] = handler
|
||||
chain.AddHandler(op.pattern, handler, op.priority, nil)
|
||||
chain.AddHandler(op.pattern, handler, op.priority)
|
||||
} else {
|
||||
chain.RemoveHandler(op.pattern, op.priority)
|
||||
}
|
||||
@ -471,9 +471,9 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
||||
r.SetQuestion(testQuery, dns.TypeA)
|
||||
|
||||
// Add handlers in mixed order
|
||||
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault, nil)
|
||||
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute, nil)
|
||||
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain, nil)
|
||||
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
|
||||
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
|
||||
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
|
||||
|
||||
// Test 1: Initial state with all three handlers
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
@ -653,7 +653,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
||||
handler = mockHandler
|
||||
}
|
||||
|
||||
chain.AddHandler(pattern, handler, h.priority, nil)
|
||||
chain.AddHandler(pattern, handler, h.priority)
|
||||
}
|
||||
|
||||
// Execute request
|
||||
@ -677,3 +677,156 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
scenario string
|
||||
ops []struct {
|
||||
action string
|
||||
pattern string
|
||||
priority int
|
||||
subdomain bool
|
||||
}
|
||||
query string
|
||||
expectedMatch string
|
||||
}{
|
||||
{
|
||||
name: "more specific domain matches first",
|
||||
scenario: "sub.example.com should match before example.com",
|
||||
ops: []struct {
|
||||
action string
|
||||
pattern string
|
||||
priority int
|
||||
subdomain bool
|
||||
}{
|
||||
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
|
||||
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
|
||||
},
|
||||
query: "sub.example.com.",
|
||||
expectedMatch: "sub.example.com.",
|
||||
},
|
||||
{
|
||||
name: "more specific domain matches first, both match subdomains",
|
||||
scenario: "sub.example.com should match before example.com",
|
||||
ops: []struct {
|
||||
action string
|
||||
pattern string
|
||||
priority int
|
||||
subdomain bool
|
||||
}{
|
||||
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
|
||||
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, true},
|
||||
},
|
||||
query: "sub.example.com.",
|
||||
expectedMatch: "sub.example.com.",
|
||||
},
|
||||
{
|
||||
name: "maintain specificity order after removal",
|
||||
scenario: "after removing most specific, should fall back to less specific",
|
||||
ops: []struct {
|
||||
action string
|
||||
pattern string
|
||||
priority int
|
||||
subdomain bool
|
||||
}{
|
||||
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
|
||||
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, true},
|
||||
{"add", "test.sub.example.com.", nbdns.PriorityMatchDomain, false},
|
||||
{"remove", "test.sub.example.com.", nbdns.PriorityMatchDomain, false},
|
||||
},
|
||||
query: "test.sub.example.com.",
|
||||
expectedMatch: "sub.example.com.",
|
||||
},
|
||||
{
|
||||
name: "priority overrides specificity",
|
||||
scenario: "less specific domain with higher priority should match first",
|
||||
ops: []struct {
|
||||
action string
|
||||
pattern string
|
||||
priority int
|
||||
subdomain bool
|
||||
}{
|
||||
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
|
||||
{"add", "example.com.", nbdns.PriorityDNSRoute, true},
|
||||
},
|
||||
query: "sub.example.com.",
|
||||
expectedMatch: "example.com.",
|
||||
},
|
||||
{
|
||||
name: "equal priority respects specificity",
|
||||
scenario: "with equal priority, more specific domain should match",
|
||||
ops: []struct {
|
||||
action string
|
||||
pattern string
|
||||
priority int
|
||||
subdomain bool
|
||||
}{
|
||||
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
|
||||
{"add", "other.example.com.", nbdns.PriorityMatchDomain, true},
|
||||
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
|
||||
},
|
||||
query: "sub.example.com.",
|
||||
expectedMatch: "sub.example.com.",
|
||||
},
|
||||
{
|
||||
name: "specific matches before wildcard",
|
||||
scenario: "specific domain should match before wildcard at same priority",
|
||||
ops: []struct {
|
||||
action string
|
||||
pattern string
|
||||
priority int
|
||||
subdomain bool
|
||||
}{
|
||||
{"add", "*.example.com.", nbdns.PriorityDNSRoute, false},
|
||||
{"add", "sub.example.com.", nbdns.PriorityDNSRoute, false},
|
||||
},
|
||||
query: "sub.example.com.",
|
||||
expectedMatch: "sub.example.com.",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
handlers := make(map[string]*nbdns.MockSubdomainHandler)
|
||||
|
||||
for _, op := range tt.ops {
|
||||
if op.action == "add" {
|
||||
handler := &nbdns.MockSubdomainHandler{Subdomains: op.subdomain}
|
||||
handlers[op.pattern] = handler
|
||||
chain.AddHandler(op.pattern, handler, op.priority)
|
||||
} else {
|
||||
chain.RemoveHandler(op.pattern, op.priority)
|
||||
}
|
||||
}
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tt.query, dns.TypeA)
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
|
||||
// Setup handler expectations
|
||||
for pattern, handler := range handlers {
|
||||
if pattern == tt.expectedMatch {
|
||||
handler.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
|
||||
w := args.Get(0).(dns.ResponseWriter)
|
||||
r := args.Get(1).(*dns.Msg)
|
||||
resp := new(dns.Msg)
|
||||
resp.SetReply(r)
|
||||
assert.NoError(t, w.WriteMsg(resp))
|
||||
}).Once()
|
||||
}
|
||||
}
|
||||
|
||||
chain.ServeDNS(w, r)
|
||||
|
||||
for pattern, handler := range handlers {
|
||||
if pattern == tt.expectedMatch {
|
||||
handler.AssertNumberOfCalls(t, "ServeDNS", 1)
|
||||
} else {
|
||||
handler.AssertNumberOfCalls(t, "ServeDNS", 0)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user