Merge branch 'feature/optimize-network-map-updates' into feature/validate-group-association

# Conflicts:
#	management/server/account.go
#	management/server/peer.go
#	management/server/peer_test.go
#	management/server/policy.go
#	management/server/route.go
#	management/server/route_test.go
This commit is contained in:
bcmmbaga 2024-10-04 10:46:41 +03:00
commit 716009b791
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
357 changed files with 22111 additions and 7920 deletions

View File

@ -35,7 +35,7 @@ Please specify whether you use NetBird Cloud or self-host NetBird's control plan
If applicable, add the `netbird status -dA' command output.
**Do you face any client issues on desktop?**
**Do you face any (non-mobile) client issues?**
Please provide the file created by `netbird debug for 1m -AS`.
We advise reviewing the anonymized files for any remaining PII.

View File

@ -18,14 +18,14 @@ jobs:
runs-on: macos-latest
steps:
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: "1.21.x"
go-version: "1.23.x"
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Cache Go modules
uses: actions/cache@v3
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: macos-go-${{ hashFiles('**/go.sum') }}

View File

@ -38,7 +38,7 @@ jobs:
time go test -timeout 1m -failfast ./dns/...
time go test -timeout 1m -failfast ./encryption/...
time go test -timeout 1m -failfast ./formatter/...
time go test -timeout 1m -failfast ./iface/...
time go test -timeout 1m -failfast ./client/iface/...
time go test -timeout 1m -failfast ./route/...
time go test -timeout 1m -failfast ./sharedsock/...
time go test -timeout 1m -failfast ./signal/...

View File

@ -19,13 +19,13 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: "1.21.x"
go-version: "1.23.x"
- name: Cache Go modules
uses: actions/cache@v3
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@ -33,7 +33,7 @@ jobs:
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
@ -49,18 +49,18 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./...
test_client_on_docker:
runs-on: ubuntu-20.04
steps:
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: "1.21.x"
go-version: "1.23.x"
- name: Cache Go modules
uses: actions/cache@v3
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@ -68,7 +68,7 @@ jobs:
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
@ -80,7 +80,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Generate Iface Test bin
run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./iface/
run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./client/iface/
- name: Generate Shared Sock Test bin
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock

View File

@ -17,13 +17,13 @@ jobs:
runs-on: windows-latest
steps:
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
id: go
with:
go-version: "1.21.x"
go-version: "1.23.x"
- name: Download wintun
uses: carlosperate/download-file-action@v2

View File

@ -15,11 +15,11 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,
ignore_words_list: erro,clienta,hastable,iif
skip: go.mod,go.sum
only_warn: 1
golangci:
@ -32,15 +32,15 @@ jobs:
timeout-minutes: 15
steps:
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Check for duplicate constants
if: matrix.os == 'ubuntu-latest'
run: |
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: "1.21.x"
go-version: "1.23.x"
cache: false
- name: Install dependencies
if: matrix.os == 'ubuntu-latest'

View File

@ -13,6 +13,7 @@ concurrency:
jobs:
test-install-script:
strategy:
fail-fast: false
max-parallel: 2
matrix:
os: [ubuntu-latest, macos-latest]
@ -21,7 +22,7 @@ jobs:
runs-on: ${{ matrix.os }}
steps:
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: run install script
env:

View File

@ -15,23 +15,23 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: "1.21.x"
go-version: "1.23.x"
- name: Setup Android SDK
uses: android-actions/setup-android@v3
with:
cmdline-tools-version: 8512546
- name: Setup Java
uses: actions/setup-java@v3
uses: actions/setup-java@v4
with:
java-version: "11"
distribution: "adopt"
- name: NDK Cache
id: ndk-cache
uses: actions/cache@v3
uses: actions/cache@v4
with:
path: /usr/local/lib/android/sdk/ndk
key: ndk-cache-23.1.7779620
@ -50,11 +50,11 @@ jobs:
runs-on: macos-latest
steps:
- name: Checkout repository
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: "1.21.x"
go-version: "1.23.x"
- name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
- name: gomobile init

View File

@ -3,15 +3,14 @@ name: Release
on:
push:
tags:
- 'v*'
- "v*"
branches:
- main
pull_request:
env:
SIGN_PIPE_VER: "v0.0.12"
GORELEASER_VER: "v1.14.1"
SIGN_PIPE_VER: "v0.0.14"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
@ -34,20 +33,17 @@ jobs:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
-
name: Checkout
uses: actions/checkout@v3
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
-
name: Set up Go
uses: actions/setup-go@v4
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: "1.21"
go-version: "1.23"
cache: false
-
name: Cache Go modules
uses: actions/cache@v3
- name: Cache Go modules
uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
@ -55,24 +51,19 @@ jobs:
key: ${{ runner.os }}-go-releaser-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-releaser-
-
name: Install modules
- name: Install modules
run: go mod tidy
-
name: check git status
- name: check git status
run: git --no-pager diff --exit-code
-
name: Set up QEMU
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
-
name: Set up Docker Buildx
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
-
name: Login to Docker hub
- name: Login to Docker hub
if: github.event_name != 'pull_request'
uses: docker/login-action@v1
with:
username: netbirdio
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Install OS build dependencies
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
@ -85,36 +76,32 @@ jobs:
uses: goreleaser/goreleaser-action@v4
with:
version: ${{ env.GORELEASER_VER }}
args: release --rm-dist ${{ env.flags }}
args: release --clean ${{ env.flags }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
-
name: upload non tags for debug purposes
uses: actions/upload-artifact@v3
- name: upload non tags for debug purposes
uses: actions/upload-artifact@v4
with:
name: release
path: dist/
retention-days: 3
-
name: upload linux packages
uses: actions/upload-artifact@v3
- name: upload linux packages
uses: actions/upload-artifact@v4
with:
name: linux-packages
path: dist/netbird_linux**
retention-days: 3
-
name: upload windows packages
uses: actions/upload-artifact@v3
- name: upload windows packages
uses: actions/upload-artifact@v4
with:
name: windows-packages
path: dist/netbird_windows**
retention-days: 3
-
name: upload macos packages
uses: actions/upload-artifact@v3
- name: upload macos packages
uses: actions/upload-artifact@v4
with:
name: macos-packages
path: dist/netbird_darwin**
@ -133,17 +120,17 @@ jobs:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v3
uses: actions/checkout@v4
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Set up Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: "1.21"
go-version: "1.23"
cache: false
- name: Cache Go modules
uses: actions/cache@v3
uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
@ -169,14 +156,14 @@ jobs:
uses: goreleaser/goreleaser-action@v4
with:
version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui.yaml --rm-dist ${{ env.flags }}
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
- name: upload non tags for debug purposes
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v4
with:
name: release-ui
path: dist/
@ -187,20 +174,17 @@ jobs:
steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
-
name: Checkout
uses: actions/checkout@v3
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
-
name: Set up Go
uses: actions/setup-go@v4
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: "1.21"
go-version: "1.23"
cache: false
-
name: Cache Go modules
uses: actions/cache@v3
- name: Cache Go modules
uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
@ -208,52 +192,34 @@ jobs:
key: ${{ runner.os }}-ui-go-releaser-darwin-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-ui-go-releaser-darwin-
-
name: Install modules
- name: Install modules
run: go mod tidy
-
name: check git status
- name: check git status
run: git --no-pager diff --exit-code
-
name: Run GoReleaser
- name: Run GoReleaser
id: goreleaser
uses: goreleaser/goreleaser-action@v4
with:
version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui_darwin.yaml --rm-dist ${{ env.flags }}
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-
name: upload non tags for debug purposes
uses: actions/upload-artifact@v3
- name: upload non tags for debug purposes
uses: actions/upload-artifact@v4
with:
name: release-ui-darwin
path: dist/
retention-days: 3
trigger_windows_signer:
trigger_signer:
runs-on: ubuntu-latest
needs: [release,release_ui]
needs: [release, release_ui, release_ui_darwin]
if: startsWith(github.ref, 'refs/tags/')
steps:
- name: Trigger Windows binaries sign pipeline
- name: Trigger binaries sign pipelines
uses: benc-uk/workflow-dispatch@v1
with:
workflow: Sign windows bin and installer
repo: netbirdio/sign-pipelines
ref: ${{ env.SIGN_PIPE_VER }}
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref }}" }'
trigger_darwin_signer:
runs-on: ubuntu-latest
needs: [release,release_ui_darwin]
if: startsWith(github.ref, 'refs/tags/')
steps:
- name: Trigger Darwin App binaries sign pipeline
uses: benc-uk/workflow-dispatch@v1
with:
workflow: Sign darwin ui app with dispatch
workflow: Sign bin and installer
repo: netbirdio/sign-pipelines
ref: ${{ env.SIGN_PIPE_VER }}
token: ${{ secrets.SIGN_GITHUB_TOKEN }}

View File

@ -18,7 +18,31 @@ concurrency:
jobs:
test-docker-compose:
runs-on: ubuntu-latest
strategy:
matrix:
store: [ 'sqlite', 'postgres' ]
services:
postgres:
image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }}
env:
POSTGRES_USER: netbird
POSTGRES_PASSWORD: postgres
POSTGRES_DB: netbird
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
ports:
- 5432:5432
steps:
- name: Set Database Connection String
run: |
if [ "${{ matrix.store }}" == "postgres" ]; then
echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN=host=$(hostname -I | awk '{print $1}') user=netbird password=postgres dbname=netbird port=5432" >> $GITHUB_ENV
else
echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN==" >> $GITHUB_ENV
fi
- name: Install jq
run: sudo apt-get install -y jq
@ -26,12 +50,12 @@ jobs:
run: sudo apt-get install -y curl
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: "1.21.x"
go-version: "1.23.x"
- name: Cache Go modules
uses: actions/cache@v3
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@ -39,7 +63,7 @@ jobs:
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: cp setup.env
run: cp infrastructure_files/tests/setup.env infrastructure_files/
@ -58,7 +82,8 @@ jobs:
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
CI_NETBIRD_STORE_CONFIG_ENGINE: "sqlite"
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
NETBIRD_STORE_ENGINE_POSTGRES_DSN: ${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
- name: check values
@ -85,7 +110,8 @@ jobs:
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_SIGNAL_PORT: 12345
CI_NETBIRD_STORE_CONFIG_ENGINE: "sqlite"
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$'
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
@ -123,6 +149,14 @@ jobs:
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000"
grep "external-ip" turnserver.conf | grep $CI_NETBIRD_TURN_EXTERNAL_IP
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
# check relay values
grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
grep "NB_LISTEN_ADDRESS=:33445" docker-compose.yml
grep '33445:33445' docker-compose.yml
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
- name: Install modules
run: go mod tidy
@ -148,6 +182,15 @@ jobs:
run: |
docker build -t netbirdio/signal:latest .
- name: Build relay binary
working-directory: relay
run: CGO_ENABLED=0 go build -o netbird-relay main.go
- name: Build relay docker image
working-directory: relay
run: |
docker build -t netbirdio/relay:latest .
- name: run docker compose up
working-directory: infrastructure_files/artifacts
run: |
@ -159,15 +202,15 @@ jobs:
- name: test running containers
run: |
count=$(docker compose ps --format json | jq '. | select(.Name | contains("artifacts")) | .State' | grep -c running)
test $count -eq 4
test $count -eq 5 || docker compose logs
working-directory: infrastructure_files/artifacts
- name: test geolocation databases
working-directory: infrastructure_files/artifacts
run: |
sleep 30
docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City.mmdb
docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames.db
docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City_[0-9]*.mmdb
docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames_[0-9]*.db
test-getting-started-script:
runs-on: ubuntu-latest
@ -176,7 +219,7 @@ jobs:
run: sudo apt-get install -y jq
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: run script with Zitadel PostgreSQL
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
@ -202,6 +245,9 @@ jobs:
- name: test dashboard.env file gen postgres
run: test -f dashboard.env
- name: test relay.env file gen postgres
run: test -f relay.env
- name: test zdb.env file gen postgres
run: test -f zdb.env
@ -237,20 +283,5 @@ jobs:
- name: test dashboard.env file gen CockroachDB
run: test -f dashboard.env
test-download-geolite2-script:
runs-on: ubuntu-latest
steps:
- name: Install jq
run: sudo apt-get update && sudo apt-get install -y unzip sqlite3
- name: Checkout code
uses: actions/checkout@v3
- name: test script
run: bash -x infrastructure_files/download-geolite2.sh
- name: test mmdb file exists
run: test -f GeoLite2-City.mmdb
- name: test geonames file exists
run: test -f geonames.db
- name: test relay.env file gen CockroachDB
run: test -f relay.env

1
.gitignore vendored
View File

@ -29,4 +29,3 @@ infrastructure_files/setup.env
infrastructure_files/setup-*.env
.vscode
.DS_Store
GeoLite2-City*

View File

@ -1,3 +1,5 @@
version: 2
project_name: netbird
builds:
- id: netbird
@ -22,7 +24,7 @@ builds:
goarch: 386
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: '{{ .CommitTimestamp }}'
mod_timestamp: "{{ .CommitTimestamp }}"
tags:
- load_wgnt_from_rsrc
@ -42,7 +44,7 @@ builds:
- softfloat
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: '{{ .CommitTimestamp }}'
mod_timestamp: "{{ .CommitTimestamp }}"
tags:
- load_wgnt_from_rsrc
@ -64,7 +66,7 @@ builds:
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: '{{ .CommitTimestamp }}'
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-signal
dir: signal
@ -78,7 +80,21 @@ builds:
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: '{{ .CommitTimestamp }}'
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-relay
dir: relay
env: [CGO_ENABLED=0]
binary: netbird-relay
goos:
- linux
goarch:
- amd64
- arm64
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
archives:
- builds:
@ -86,7 +102,6 @@ archives:
- netbird-static
nfpms:
- maintainer: Netbird <dev@netbird.io>
description: Netbird client.
homepage: https://netbird.io/
@ -161,6 +176,52 @@ dockers:
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/relay:{{ .Version }}-amd64
ids:
- netbird-relay
goarch: amd64
use: buildx
dockerfile: relay/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8
ids:
- netbird-relay
goarch: arm64
use: buildx
dockerfile: relay/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/relay:{{ .Version }}-arm
ids:
- netbird-relay
goarch: arm
goarm: 6
use: buildx
dockerfile: relay/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/signal:{{ .Version }}-amd64
ids:
@ -313,6 +374,18 @@ docker_manifests:
- netbirdio/netbird:{{ .Version }}-arm
- netbirdio/netbird:{{ .Version }}-amd64
- name_template: netbirdio/relay:{{ .Version }}
image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8
- netbirdio/relay:{{ .Version }}-arm
- netbirdio/relay:{{ .Version }}-amd64
- name_template: netbirdio/relay:latest
image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8
- netbirdio/relay:{{ .Version }}-arm
- netbirdio/relay:{{ .Version }}-amd64
- name_template: netbirdio/signal:{{ .Version }}
image_templates:
- netbirdio/signal:{{ .Version }}-arm64v8
@ -344,10 +417,9 @@ docker_manifests:
- netbirdio/management:{{ .Version }}-debug-amd64
brews:
-
ids:
- ids:
- default
tap:
repository:
owner: netbirdio
name: homebrew-tap
token: "{{ .Env.HOMEBREW_TAP_GITHUB_TOKEN }}"

View File

@ -1,3 +1,5 @@
version: 2
project_name: netbird-ui
builds:
- id: netbird-ui
@ -11,7 +13,7 @@ builds:
- amd64
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: '{{ .CommitTimestamp }}'
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-ui-windows
dir: client/ui
@ -26,7 +28,7 @@ builds:
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
- -H windowsgui
mod_timestamp: '{{ .CommitTimestamp }}'
mod_timestamp: "{{ .CommitTimestamp }}"
archives:
- id: linux-arch
@ -39,7 +41,6 @@ archives:
- netbird-ui-windows
nfpms:
- maintainer: Netbird <dev@netbird.io>
description: Netbird client UI.
homepage: https://netbird.io/

View File

@ -1,3 +1,5 @@
version: 2
project_name: netbird-ui
builds:
- id: netbird-ui-darwin
@ -17,7 +19,7 @@ builds:
- softfloat
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: '{{ .CommitTimestamp }}'
mod_timestamp: "{{ .CommitTimestamp }}"
tags:
- load_wgnt_from_rsrc
@ -28,4 +30,4 @@ archives:
checksum:
name_template: "{{ .ProjectName }}_darwin_checksums.txt"
changelog:
skip: true
disable: true

View File

@ -96,7 +96,7 @@ They can be executed from the repository root before every push or PR:
**Goreleaser**
```shell
goreleaser --snapshot --rm-dist
goreleaser build --snapshot --clean
```
**golangci-lint**
```shell

View File

@ -17,7 +17,7 @@
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
</a>
<br>
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A">
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a>
</p>
@ -30,7 +30,7 @@
<br/>
See <a href="https://netbird.io/docs/">Documentation</a>
<br/>
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A">Slack channel</a>
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">Slack channel</a>
<br/>
</strong>

View File

@ -1,4 +1,4 @@
FROM alpine:3.19
FROM alpine:3.20
RUN apk add --no-cache ca-certificates iptables ip6tables
ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"]

View File

@ -8,6 +8,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
@ -15,7 +16,6 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/util/net"
)
@ -26,7 +26,7 @@ type ConnectionListener interface {
// TunAdapter export internal TunAdapter for mobile
type TunAdapter interface {
iface.TunAdapter
device.TunAdapter
}
// IFaceDiscover export internal IFaceDiscover for mobile
@ -51,7 +51,7 @@ func init() {
// Client struct manage the life circle of background service
type Client struct {
cfgFile string
tunAdapter iface.TunAdapter
tunAdapter device.TunAdapter
iFaceDiscover IFaceDiscover
recorder *peer.Status
ctxCancel context.CancelFunc

View File

@ -42,6 +42,8 @@ var downCmd = &cobra.Command{
log.Errorf("call service down method: %v", err)
return err
}
cmd.Println("Disconnected")
return nil
},
}

View File

@ -5,8 +5,8 @@ import (
"strings"
"testing"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/util"
)

View File

@ -4,6 +4,10 @@ import (
"fmt"
"io"
"testing"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/iface"
)
func TestInitCommands(t *testing.T) {
@ -34,3 +38,44 @@ func TestInitCommands(t *testing.T) {
})
}
}
func TestSetFlagsFromEnvVars(t *testing.T) {
var cmd = &cobra.Command{
Use: "netbird",
Long: "test",
SilenceUsage: true,
Run: func(cmd *cobra.Command, args []string) {
SetFlagsFromEnvVars(cmd)
},
}
cmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
`comma separated list of external IPs to map to the Wireguard interface`)
cmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
cmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "Enable Rosenpass feature Rosenpass.")
cmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
t.Setenv("NB_EXTERNAL_IP_MAP", "abc,dec")
t.Setenv("NB_INTERFACE_NAME", "test-name")
t.Setenv("NB_ENABLE_ROSENPASS", "true")
t.Setenv("NB_WIREGUARD_PORT", "10000")
err := cmd.Execute()
if err != nil {
t.Fatalf("expected no error while running netbird command, got %v", err)
}
if len(natExternalIPs) != 2 {
t.Errorf("expected 2 external ips, got %d", len(natExternalIPs))
}
if natExternalIPs[0] != "abc" || natExternalIPs[1] != "dec" {
t.Errorf("expected abc,dec, got %s,%s", natExternalIPs[0], natExternalIPs[1])
}
if interfaceName != "test-name" {
t.Errorf("expected test-name, got %s", interfaceName)
}
if !rosenpassEnabled {
t.Errorf("expected rosenpassEnabled to be true, got false")
}
if wireguardPort != 10000 {
t.Errorf("expected wireguardPort to be 10000, got %d", wireguardPort)
}
}

View File

@ -2,18 +2,21 @@ package cmd
import (
"context"
"github.com/kardianos/service"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/server"
)
type program struct {
ctx context.Context
cancel context.CancelFunc
serv *grpc.Server
serverInstance *server.Server
}
func newProgram(ctx context.Context, cancel context.CancelFunc) *program {

View File

@ -61,6 +61,8 @@ func (p *program) Start(svc service.Service) error {
}
proto.RegisterDaemonServiceServer(p.serv, serverInstance)
p.serverInstance = serverInstance
log.Printf("started daemon server: %v", split[1])
if err := p.serv.Serve(listen); err != nil {
log.Errorf("failed to serve daemon requests: %v", err)
@ -70,6 +72,14 @@ func (p *program) Start(svc service.Service) error {
}
func (p *program) Stop(srv service.Service) error {
if p.serverInstance != nil {
in := new(proto.DownRequest)
_, err := p.serverInstance.Down(p.ctx, in)
if err != nil {
log.Errorf("failed to stop daemon: %v", err)
}
}
p.cancel()
if p.serv != nil {

View File

@ -31,9 +31,9 @@ type peerStateDetailOutput struct {
Status string `json:"status" yaml:"status"`
LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"`
ConnType string `json:"connectionType" yaml:"connectionType"`
Direct bool `json:"direct" yaml:"direct"`
IceCandidateType iceCandidateType `json:"iceCandidateType" yaml:"iceCandidateType"`
IceCandidateEndpoint iceCandidateType `json:"iceCandidateEndpoint" yaml:"iceCandidateEndpoint"`
RelayAddress string `json:"relayAddress" yaml:"relayAddress"`
LastWireguardHandshake time.Time `json:"lastWireguardHandshake" yaml:"lastWireguardHandshake"`
TransferReceived int64 `json:"transferReceived" yaml:"transferReceived"`
TransferSent int64 `json:"transferSent" yaml:"transferSent"`
@ -335,16 +335,18 @@ func mapNSGroups(servers []*proto.NSGroupState) []nsServerGroupStateOutput {
func mapPeers(peers []*proto.PeerState) peersStateOutput {
var peersStateDetail []peerStateDetailOutput
peersConnected := 0
for _, pbPeerState := range peers {
localICE := ""
remoteICE := ""
localICEEndpoint := ""
remoteICEEndpoint := ""
relayServerAddress := ""
connType := ""
peersConnected := 0
lastHandshake := time.Time{}
transferReceived := int64(0)
transferSent := int64(0)
for _, pbPeerState := range peers {
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
if skipDetailByFilters(pbPeerState, isPeerConnected) {
continue
@ -360,6 +362,7 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
if pbPeerState.Relayed {
connType = "Relayed"
}
relayServerAddress = pbPeerState.GetRelayAddress()
lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local()
transferReceived = pbPeerState.GetBytesRx()
transferSent = pbPeerState.GetBytesTx()
@ -372,7 +375,6 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
Status: pbPeerState.GetConnStatus(),
LastStatusUpdate: timeLocal,
ConnType: connType,
Direct: pbPeerState.GetDirect(),
IceCandidateType: iceCandidateType{
Local: localICE,
Remote: remoteICE,
@ -381,6 +383,7 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
Local: localICEEndpoint,
Remote: remoteICEEndpoint,
},
RelayAddress: relayServerAddress,
FQDN: pbPeerState.GetFqdn(),
LastWireguardHandshake: lastHandshake,
TransferReceived: transferReceived,
@ -641,9 +644,9 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
" Status: %s\n"+
" -- detail --\n"+
" Connection type: %s\n"+
" Direct: %t\n"+
" ICE candidate (Local/Remote): %s/%s\n"+
" ICE candidate endpoints (Local/Remote): %s/%s\n"+
" Relay server address: %s\n"+
" Last connection update: %s\n"+
" Last WireGuard handshake: %s\n"+
" Transfer status (received/sent) %s/%s\n"+
@ -655,11 +658,11 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
peerState.PubKey,
peerState.Status,
peerState.ConnType,
peerState.Direct,
localICE,
remoteICE,
localICEEndpoint,
remoteICEEndpoint,
peerState.RelayAddress,
timeAgo(peerState.LastStatusUpdate),
timeAgo(peerState.LastWireguardHandshake),
toIEC(peerState.TransferReceived),
@ -802,6 +805,9 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
}
peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress)
for i, route := range peer.Routes {
peer.Routes[i] = a.AnonymizeIPString(route)
}

View File

@ -37,7 +37,6 @@ var resp = &proto.StatusResponse{
ConnStatus: "Connected",
ConnStatusUpdate: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 1, 0, time.UTC)),
Relayed: false,
Direct: true,
LocalIceCandidateType: "",
RemoteIceCandidateType: "",
LocalIceCandidateEndpoint: "",
@ -57,7 +56,6 @@ var resp = &proto.StatusResponse{
ConnStatus: "Connected",
ConnStatusUpdate: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 2, 0, time.UTC)),
Relayed: true,
Direct: false,
LocalIceCandidateType: "relay",
RemoteIceCandidateType: "prflx",
LocalIceCandidateEndpoint: "10.0.0.1:10001",
@ -137,7 +135,6 @@ var overview = statusOutputOverview{
Status: "Connected",
LastStatusUpdate: time.Date(2001, 1, 1, 1, 1, 1, 0, time.UTC),
ConnType: "P2P",
Direct: true,
IceCandidateType: iceCandidateType{
Local: "",
Remote: "",
@ -161,7 +158,6 @@ var overview = statusOutputOverview{
Status: "Connected",
LastStatusUpdate: time.Date(2002, 2, 2, 2, 2, 2, 0, time.UTC),
ConnType: "Relayed",
Direct: false,
IceCandidateType: iceCandidateType{
Local: "relay",
Remote: "prflx",
@ -283,7 +279,6 @@ func TestParsingToJSON(t *testing.T) {
"status": "Connected",
"lastStatusUpdate": "2001-01-01T01:01:01Z",
"connectionType": "P2P",
"direct": true,
"iceCandidateType": {
"local": "",
"remote": ""
@ -292,6 +287,7 @@ func TestParsingToJSON(t *testing.T) {
"local": "",
"remote": ""
},
"relayAddress": "",
"lastWireguardHandshake": "2001-01-01T01:01:02Z",
"transferReceived": 200,
"transferSent": 100,
@ -308,7 +304,6 @@ func TestParsingToJSON(t *testing.T) {
"status": "Connected",
"lastStatusUpdate": "2002-02-02T02:02:02Z",
"connectionType": "Relayed",
"direct": false,
"iceCandidateType": {
"local": "relay",
"remote": "prflx"
@ -317,6 +312,7 @@ func TestParsingToJSON(t *testing.T) {
"local": "10.0.0.1:10001",
"remote": "10.0.10.1:10002"
},
"relayAddress": "",
"lastWireguardHandshake": "2002-02-02T02:02:03Z",
"transferReceived": 2000,
"transferSent": 1000,
@ -408,13 +404,13 @@ func TestParsingToYAML(t *testing.T) {
status: Connected
lastStatusUpdate: 2001-01-01T01:01:01Z
connectionType: P2P
direct: true
iceCandidateType:
local: ""
remote: ""
iceCandidateEndpoint:
local: ""
remote: ""
relayAddress: ""
lastWireguardHandshake: 2001-01-01T01:01:02Z
transferReceived: 200
transferSent: 100
@ -428,13 +424,13 @@ func TestParsingToYAML(t *testing.T) {
status: Connected
lastStatusUpdate: 2002-02-02T02:02:02Z
connectionType: Relayed
direct: false
iceCandidateType:
local: relay
remote: prflx
iceCandidateEndpoint:
local: 10.0.0.1:10001
remote: 10.0.10.1:10002
relayAddress: ""
lastWireguardHandshake: 2002-02-02T02:02:03Z
transferReceived: 2000
transferSent: 1000
@ -505,9 +501,9 @@ func TestParsingToDetail(t *testing.T) {
Status: Connected
-- detail --
Connection type: P2P
Direct: true
ICE candidate (Local/Remote): -/-
ICE candidate endpoints (Local/Remote): -/-
Relay server address:
Last connection update: %s
Last WireGuard handshake: %s
Transfer status (received/sent) 200 B/100 B
@ -521,9 +517,9 @@ func TestParsingToDetail(t *testing.T) {
Status: Connected
-- detail --
Connection type: Relayed
Direct: false
ICE candidate (Local/Remote): relay/prflx
ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002
Relay server address:
Last connection update: %s
Last WireGuard handshake: %s
Transfer status (received/sent) 2.0 KiB/1000 B

View File

@ -57,7 +57,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
t.Fatal(err)
}
s := grpc.NewServer()
srv, err := sig.NewServer(otel.Meter(""))
srv, err := sig.NewServer(context.Background(), otel.Meter(""))
require.NoError(t, err)
sigProto.RegisterSignalExchangeServer(s, srv)
@ -98,8 +98,9 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
if err != nil {
t.Fatal(err)
}
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil)
if err != nil {
t.Fatal(err)
}

View File

@ -15,11 +15,11 @@ import (
gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/util"
)
@ -168,7 +168,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
ctx, cancel = context.WithCancel(ctx)
SetupCloseHandler(ctx, cancel)
connectClient := internal.NewConnectClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
r := peer.NewRecorder(config.ManagementURL.String())
r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r)
return connectClient.Run()
}

View File

@ -8,8 +8,8 @@ import (
)
func formatError(es []error) string {
if len(es) == 0 {
return fmt.Sprintf("0 error occurred:\n\t* %s", es[0])
if len(es) == 1 {
return fmt.Sprintf("1 error occurred:\n\t* %s", es[0])
}
points := make([]string, len(es))

View File

@ -1,11 +1,13 @@
package firewall
import "github.com/netbirdio/netbird/iface"
import (
"github.com/netbirdio/netbird/client/iface/device"
)
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
Name() string
Address() iface.WGAddress
Address() device.WGAddress
IsUserspaceBind() bool
SetFilter(iface.PacketFilter) error
SetFilter(device.PacketFilter) error
}

View File

@ -19,24 +19,22 @@ const (
// rules chains contains the effective ACL rules
chainNameInputRules = "NETBIRD-ACL-INPUT"
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
postRoutingMark = "0x000007e4"
)
type aclManager struct {
iptablesClient *iptables.IPTables
wgIface iFaceMapper
routeingFwChainName string
routingFwChainName string
entries map[string][][]string
ipsetStore *ipsetStore
}
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routeingFwChainName string) (*aclManager, error) {
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
m := &aclManager{
iptablesClient: iptablesClient,
wgIface: wgIface,
routeingFwChainName: routeingFwChainName,
routingFwChainName: routingFwChainName,
entries: make(map[string][][]string),
ipsetStore: newIpsetStore(),
@ -61,7 +59,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, route
return m, nil
}
func (m *aclManager) AddFiltering(
func (m *aclManager) AddPeerFiltering(
ip net.IP,
protocol firewall.Protocol,
sPort *firewall.Port,
@ -127,7 +125,7 @@ func (m *aclManager) AddFiltering(
return nil, fmt.Errorf("rule already exists")
}
if err := m.iptablesClient.Insert("filter", chain, 1, specs...); err != nil {
if err := m.iptablesClient.Append("filter", chain, specs...); err != nil {
return nil, err
}
@ -139,28 +137,16 @@ func (m *aclManager) AddFiltering(
chain: chain,
}
if !shouldAddToPrerouting(protocol, dPort, direction) {
return []firewall.Rule{rule}, nil
}
rulePrerouting, err := m.addPreroutingFilter(ipsetName, string(protocol), dPortVal, ip)
if err != nil {
return []firewall.Rule{rule}, err
}
return []firewall.Rule{rule, rulePrerouting}, nil
}
// DeleteRule from the firewall by rule definition
func (m *aclManager) DeleteRule(rule firewall.Rule) error {
// DeletePeerRule from the firewall by rule definition
func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
r, ok := rule.(*Rule)
if !ok {
return fmt.Errorf("invalid rule type")
}
if r.chain == "PREROUTING" {
goto DELETERULE
}
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
// delete IP from ruleset IPs list and ipset
if _, ok := ipsetList.ips[r.ip]; ok {
@ -185,14 +171,7 @@ func (m *aclManager) DeleteRule(rule firewall.Rule) error {
}
}
DELETERULE:
var table string
if r.chain == "PREROUTING" {
table = "mangle"
} else {
table = "filter"
}
err := m.iptablesClient.Delete(table, r.chain, r.specs...)
err := m.iptablesClient.Delete(tableName, r.chain, r.specs...)
if err != nil {
log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err)
}
@ -203,44 +182,6 @@ func (m *aclManager) Reset() error {
return m.cleanChains()
}
func (m *aclManager) addPreroutingFilter(ipsetName string, protocol string, port string, ip net.IP) (*Rule, error) {
var src []string
if ipsetName != "" {
src = []string{"-m", "set", "--set", ipsetName, "src"}
} else {
src = []string{"-s", ip.String()}
}
specs := []string{
"-d", m.wgIface.Address().IP.String(),
"-p", protocol,
"--dport", port,
"-j", "MARK", "--set-mark", postRoutingMark,
}
specs = append(src, specs...)
ok, err := m.iptablesClient.Exists("mangle", "PREROUTING", specs...)
if err != nil {
return nil, fmt.Errorf("failed to check rule: %w", err)
}
if ok {
return nil, fmt.Errorf("rule already exists")
}
if err := m.iptablesClient.Insert("mangle", "PREROUTING", 1, specs...); err != nil {
return nil, err
}
rule := &Rule{
ruleID: uuid.New().String(),
specs: specs,
ipsetName: ipsetName,
ip: ip.String(),
chain: "PREROUTING",
}
return rule, nil
}
// todo write less destructive cleanup mechanism
func (m *aclManager) cleanChains() error {
ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules)
@ -291,25 +232,6 @@ func (m *aclManager) cleanChains() error {
}
}
ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
if err != nil {
log.Debugf("failed to list chains: %s", err)
return err
}
if ok {
for _, rule := range m.entries["PREROUTING"] {
err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
err = m.iptablesClient.ClearChain("mangle", "PREROUTING")
if err != nil {
log.Debugf("failed to clear %s chain: %s", "PREROUTING", err)
return err
}
}
for _, ipsetName := range m.ipsetStore.ipsetNames() {
if err := ipset.Flush(ipsetName); err != nil {
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
@ -338,58 +260,39 @@ func (m *aclManager) createDefaultChains() error {
for chainName, rules := range m.entries {
for _, rule := range rules {
if chainName == "FORWARD" {
// position 2 because we add it after router's, jump rule
if err := m.iptablesClient.InsertUnique(tableName, "FORWARD", 2, rule...); err != nil {
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
log.Debugf("failed to create input chain jump rule: %s", err)
return err
}
} else {
if err := m.iptablesClient.AppendUnique(tableName, chainName, rule...); err != nil {
log.Debugf("failed to create input chain jump rule: %s", err)
return err
}
}
}
}
return nil
}
// seedInitialEntries adds default rules to the entries map, rules are inserted on pos 1, hence the order is reversed.
// We want to make sure our traffic is not dropped by existing rules.
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
// The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule.
func (m *aclManager) seedInitialEntries() {
m.appendToEntries("INPUT",
[]string{"-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("INPUT",
[]string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("INPUT",
[]string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameInputRules})
established := getConntrackEstablished()
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("OUTPUT",
[]string{"-o", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("OUTPUT",
[]string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("OUTPUT",
[]string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameOutputRules})
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", chainNameOutputRules})
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("OUTPUT", append([]string{"-o", m.wgIface.Name()}, established...))
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
m.appendToEntries("FORWARD",
[]string{"-o", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
m.appendToEntries("FORWARD",
[]string{"-i", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", m.routeingFwChainName})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routeingFwChainName})
m.appendToEntries("PREROUTING",
[]string{"-t", "mangle", "-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().IP.String(), "-m", "mark", "--mark", postRoutingMark})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName})
m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...))
}
func (m *aclManager) appendToEntries(chainName string, spec []string) {
@ -456,18 +359,3 @@ func transformIPsetName(ipsetName string, sPort, dPort string) string {
return ipsetName
}
}
func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool {
if proto == "all" {
return false
}
if direction != firewall.RuleDirectionIN {
return false
}
if dPort == nil {
return false
}
return true
}

View File

@ -4,13 +4,14 @@ import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"github.com/coreos/go-iptables/iptables"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/client/iface"
)
// Manager of iptables firewall
@ -21,7 +22,7 @@ type Manager struct {
ipv4Client *iptables.IPTables
aclMgr *aclManager
router *routerManager
router *router
}
// iFaceMapper defines subset methods of interface required for manager
@ -43,12 +44,12 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
ipv4Client: iptablesClient,
}
m.router, err = newRouterManager(context, iptablesClient)
m.router, err = newRouter(context, iptablesClient, wgIface)
if err != nil {
log.Debugf("failed to initialize route related chains: %s", err)
return nil, err
}
m.aclMgr, err = newAclManager(iptablesClient, wgIface, m.router.RouteingFwChainName())
m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD)
if err != nil {
log.Debugf("failed to initialize ACL manager: %s", err)
return nil, err
@ -57,10 +58,10 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
return m, nil
}
// AddFiltering rule to the firewall
// AddPeerFiltering adds a rule to the firewall
//
// Comment will be ignored because some system this feature is not supported
func (m *Manager) AddFiltering(
func (m *Manager) AddPeerFiltering(
ip net.IP,
protocol firewall.Protocol,
sPort *firewall.Port,
@ -73,33 +74,62 @@ func (m *Manager) AddFiltering(
m.mutex.Lock()
defer m.mutex.Unlock()
return m.aclMgr.AddFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
}
// DeleteRule from the firewall by rule definition
func (m *Manager) DeleteRule(rule firewall.Rule) error {
func (m *Manager) AddRouteFiltering(
sources [] netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.aclMgr.DeleteRule(rule)
if !destination.Addr().Is4() {
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
}
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.aclMgr.DeletePeerRule(rule)
}
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.DeleteRouteRule(rule)
}
func (m *Manager) IsServerRouteSupported() bool {
return true
}
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.InsertRoutingRules(pair)
return m.router.AddNatRule(pair)
}
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveRoutingRules(pair)
return m.router.RemoveNatRule(pair)
}
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
return firewall.SetLegacyManagement(m.router, isLegacy)
}
// Reset firewall to the default state
@ -125,7 +155,7 @@ func (m *Manager) AllowNetbird() error {
return nil
}
_, err := m.AddFiltering(
_, err := m.AddPeerFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
@ -138,7 +168,7 @@ func (m *Manager) AllowNetbird() error {
if err != nil {
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
}
_, err = m.AddFiltering(
_, err = m.AddPeerFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
@ -153,3 +183,7 @@ func (m *Manager) AllowNetbird() error {
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
}

View File

@ -11,9 +11,24 @@ import (
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/client/iface"
)
var ifaceMock = &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
// iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct {
NameFunc func() string
@ -40,23 +55,8 @@ func TestIptablesManager(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err)
mock := &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
// just check on the local interface
manager, err := Create(context.Background(), mock)
manager, err := Create(context.Background(), ifaceMock)
require.NoError(t, err)
time.Sleep(time.Second)
@ -72,7 +72,7 @@ func TestIptablesManager(t *testing.T) {
t.Run("add first rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
rule1, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule")
for _, r := range rule1 {
@ -87,7 +87,7 @@ func TestIptablesManager(t *testing.T) {
port := &fw.Port{
Values: []int{8043: 8046},
}
rule2, err = manager.AddFiltering(
rule2, err = manager.AddPeerFiltering(
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
require.NoError(t, err, "failed to add rule")
@ -99,7 +99,7 @@ func TestIptablesManager(t *testing.T) {
t.Run("delete first rule", func(t *testing.T) {
for _, r := range rule1 {
err := manager.DeleteRule(r)
err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...)
@ -108,7 +108,7 @@ func TestIptablesManager(t *testing.T) {
t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 {
err := manager.DeleteRule(r)
err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
}
@ -119,7 +119,7 @@ func TestIptablesManager(t *testing.T) {
// add second rule
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{Values: []int{5353}}
_, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
require.NoError(t, err, "failed to add rule")
err = manager.Reset()
@ -170,7 +170,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
t.Run("add first rule with set", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddFiltering(
rule1, err = manager.AddPeerFiltering(
ip, "tcp", nil, port, fw.RuleDirectionOUT,
fw.ActionAccept, "default", "accept HTTP traffic",
)
@ -189,7 +189,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
port := &fw.Port{
Values: []int{443},
}
rule2, err = manager.AddFiltering(
rule2, err = manager.AddPeerFiltering(
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
"default", "accept HTTPS traffic from ports range",
)
@ -202,7 +202,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
t.Run("delete first rule", func(t *testing.T) {
for _, r := range rule1 {
err := manager.DeleteRule(r)
err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index")
@ -211,7 +211,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 {
err := manager.DeleteRule(r)
err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
@ -269,9 +269,9 @@ func TestIptablesCreatePerformance(t *testing.T) {
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
}
require.NoError(t, err, "failed to add rule")

View File

@ -5,16 +5,22 @@ package iptables
import (
"context"
"fmt"
"net/netip"
"strconv"
"strings"
"github.com/coreos/go-iptables/iptables"
"github.com/hashicorp/go-multierror"
"github.com/nadoo/ipset"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
)
const (
Ipv4Forwarding = "netbird-rt-forwarding"
ipv4Nat = "netbird-rt-nat"
)
@ -22,351 +28,455 @@ const (
const (
tableFilter = "filter"
tableNat = "nat"
chainFORWARD = "FORWARD"
chainPOSTROUTING = "POSTROUTING"
chainRTNAT = "NETBIRD-RT-NAT"
chainRTFWD = "NETBIRD-RT-FWD"
routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE"
matchSet = "--match-set"
)
type routerManager struct {
type routeFilteringRuleParams struct {
Sources []netip.Prefix
Destination netip.Prefix
Proto firewall.Protocol
SPort *firewall.Port
DPort *firewall.Port
Direction firewall.RuleDirection
Action firewall.Action
SetName string
}
type router struct {
ctx context.Context
stop context.CancelFunc
iptablesClient *iptables.IPTables
rules map[string][]string
ipsetCounter *refcounter.Counter[string, []netip.Prefix, struct{}]
wgIface iFaceMapper
legacyManagement bool
}
func newRouterManager(parentCtx context.Context, iptablesClient *iptables.IPTables) (*routerManager, error) {
func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
ctx, cancel := context.WithCancel(parentCtx)
m := &routerManager{
r := &router{
ctx: ctx,
stop: cancel,
iptablesClient: iptablesClient,
rules: make(map[string][]string),
wgIface: wgIface,
}
err := m.cleanUpDefaultForwardRules()
r.ipsetCounter = refcounter.New(
r.createIpSet,
func(name string, _ struct{}) error {
return r.deleteIpSet(name)
},
)
if err := ipset.Init(); err != nil {
return nil, fmt.Errorf("init ipset: %w", err)
}
err := r.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to cleanup routing rules: %s", err)
log.Errorf("cleanup routing rules: %s", err)
return nil, err
}
err = m.createContainers()
err = r.createContainers()
if err != nil {
log.Errorf("failed to create containers for route: %s", err)
log.Errorf("create containers for route: %s", err)
}
return m, err
return r, err
}
// InsertRoutingRules inserts an iptables rule pair to the forwarding chain and if enabled, to the nat chain
func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error {
err := i.insertRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, pair)
if err != nil {
return err
func (r *router) AddRouteFiltering(
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if _, ok := r.rules[string(ruleKey)]; ok {
return ruleKey, nil
}
err = i.insertRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, firewall.GetInPair(pair))
if err != nil {
return err
var setName string
if len(sources) > 1 {
setName = firewall.GenerateSetName(sources)
if _, err := r.ipsetCounter.Increment(setName, sources); err != nil {
return nil, fmt.Errorf("create or get ipset: %w", err)
}
}
params := routeFilteringRuleParams{
Sources: sources,
Destination: destination,
Proto: proto,
SPort: sPort,
DPort: dPort,
Action: action,
SetName: setName,
}
rule := genRouteFilteringRuleSpec(params)
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
return nil, fmt.Errorf("add route rule: %v", err)
}
r.rules[string(ruleKey)] = rule
return ruleKey, nil
}
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
ruleKey := rule.GetRuleID()
if rule, exists := r.rules[ruleKey]; exists {
setName := r.findSetNameInRule(rule)
if err := r.iptablesClient.Delete(tableFilter, chainRTFWD, rule...); err != nil {
return fmt.Errorf("delete route rule: %v", err)
}
delete(r.rules, ruleKey)
if setName != "" {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
return fmt.Errorf("failed to remove ipset: %w", err)
}
}
} else {
log.Debugf("route rule %s not found", ruleKey)
}
return nil
}
func (r *router) findSetNameInRule(rule []string) string {
for i, arg := range rule {
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
return rule[i+3]
}
}
return ""
}
func (r *router) createIpSet(setName string, sources []netip.Prefix) (struct{}, error) {
if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
return struct{}{}, fmt.Errorf("create set %s: %w", setName, err)
}
for _, prefix := range sources {
if err := ipset.AddPrefix(setName, prefix); err != nil {
return struct{}{}, fmt.Errorf("add element to set %s: %w", setName, err)
}
}
return struct{}{}, nil
}
func (r *router) deleteIpSet(setName string) error {
if err := ipset.Destroy(setName); err != nil {
return fmt.Errorf("destroy set %s: %w", setName, err)
}
return nil
}
// AddNatRule inserts an iptables rule pair into the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error {
if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil {
return fmt.Errorf("add legacy routing rule: %w", err)
}
}
if !pair.Masquerade {
return nil
}
err = i.addNATRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
if err != nil {
return err
if err := r.addNatRule(pair); err != nil {
return fmt.Errorf("add nat rule: %w", err)
}
err = i.addNATRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
if err != nil {
return err
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("add inverse nat rule: %w", err)
}
return nil
}
// insertRoutingRule inserts an iptables rule
func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
var err error
ruleKey := firewall.GenKey(keyFormat, pair.ID)
rule := genRuleSpec(jump, pair.Source, pair.Destination)
existingRule, found := i.rules[ruleKey]
if found {
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err != nil {
return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
}
delete(i.rules, ruleKey)
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err)
}
err = i.iptablesClient.Insert(table, chain, 1, rule...)
if err != nil {
return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse nat rule: %w", err)
}
i.rules[ruleKey] = rule
return nil
}
// RemoveRoutingRules removes an iptables rule pair from forwarding and nat chains
func (i *routerManager) RemoveRoutingRules(pair firewall.RouterPair) error {
err := i.removeRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, pair)
if err != nil {
return err
}
err = i.removeRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, firewall.GetInPair(pair))
if err != nil {
return err
}
if !pair.Masquerade {
return nil
}
err = i.removeRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, pair)
if err != nil {
return err
}
err = i.removeRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, firewall.GetInPair(pair))
if err != nil {
return err
if err := r.removeLegacyRouteRule(pair); err != nil {
return fmt.Errorf("remove legacy routing rule: %w", err)
}
return nil
}
func (i *routerManager) removeRoutingRule(keyFormat, table, chain string, pair firewall.RouterPair) error {
var err error
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
ruleKey := firewall.GenKey(keyFormat, pair.ID)
existingRule, found := i.rules[ruleKey]
if found {
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err != nil {
return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
}
}
delete(i.rules, ruleKey)
return nil
}
func (i *routerManager) RouteingFwChainName() string {
return chainRTFWD
}
func (i *routerManager) Reset() error {
err := i.cleanUpDefaultForwardRules()
if err != nil {
if err := r.removeLegacyRouteRule(pair); err != nil {
return err
}
i.rules = make(map[string][]string)
rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
return fmt.Errorf("add legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
r.rules[ruleKey] = rule
return nil
}
func (i *routerManager) cleanUpDefaultForwardRules() error {
err := i.cleanJumpRules()
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
delete(r.rules, ruleKey)
} else {
log.Debugf("legacy forwarding rule %s not found", ruleKey)
}
return nil
}
// GetLegacyManagement returns the current legacy management mode
func (r *router) GetLegacyManagement() bool {
return r.legacyManagement
}
// SetLegacyManagement sets the route manager to use legacy management mode
func (r *router) SetLegacyManagement(isLegacy bool) {
r.legacyManagement = isLegacy
}
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
func (r *router) RemoveAllLegacyRouteRules() error {
var merr *multierror.Error
for k, rule := range r.rules {
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
continue
}
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) Reset() error {
var merr *multierror.Error
if err := r.cleanUpDefaultForwardRules(); err != nil {
merr = multierror.Append(merr, err)
}
r.rules = make(map[string][]string)
if err := r.ipsetCounter.Flush(); err != nil {
merr = multierror.Append(merr, err)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) cleanUpDefaultForwardRules() error {
err := r.cleanJumpRules()
if err != nil {
return err
}
log.Debug("flushing routing related tables")
ok, err := i.iptablesClient.ChainExists(tableFilter, chainRTFWD)
for _, chain := range []string{chainRTFWD, chainRTNAT} {
table := tableFilter
if chain == chainRTNAT {
table = tableNat
}
ok, err := r.iptablesClient.ChainExists(table, chain)
if err != nil {
log.Errorf("failed check chain %s,error: %v", chainRTFWD, err)
log.Errorf("failed check chain %s, error: %v", chain, err)
return err
} else if ok {
err = i.iptablesClient.ClearAndDeleteChain(tableFilter, chainRTFWD)
err = r.iptablesClient.ClearAndDeleteChain(table, chain)
if err != nil {
log.Errorf("failed cleaning chain %s,error: %v", chainRTFWD, err)
log.Errorf("failed cleaning chain %s, error: %v", chain, err)
return err
}
}
ok, err = i.iptablesClient.ChainExists(tableNat, chainRTNAT)
if err != nil {
log.Errorf("failed check chain %s,error: %v", chainRTNAT, err)
return err
} else if ok {
err = i.iptablesClient.ClearAndDeleteChain(tableNat, chainRTNAT)
if err != nil {
log.Errorf("failed cleaning chain %s,error: %v", chainRTNAT, err)
return err
}
}
return nil
}
func (i *routerManager) createContainers() error {
if i.rules[Ipv4Forwarding] != nil {
return nil
}
errMSGFormat := "failed creating chain %s,error: %v"
err := i.createChain(tableFilter, chainRTFWD)
if err != nil {
return fmt.Errorf(errMSGFormat, chainRTFWD, err)
}
err = i.createChain(tableNat, chainRTNAT)
if err != nil {
return fmt.Errorf(errMSGFormat, chainRTNAT, err)
}
err = i.addJumpRules()
if err != nil {
return fmt.Errorf("error while creating jump rules: %v", err)
}
return nil
}
// addJumpRules create jump rules to send packets to NetBird chains
func (i *routerManager) addJumpRules() error {
rule := []string{"-j", chainRTFWD}
err := i.iptablesClient.Insert(tableFilter, chainFORWARD, 1, rule...)
if err != nil {
return err
func (r *router) createContainers() error {
for _, chain := range []string{chainRTFWD, chainRTNAT} {
if err := r.createAndSetupChain(chain); err != nil {
return fmt.Errorf("create chain %s: %v", chain, err)
}
}
i.rules[Ipv4Forwarding] = rule
rule = []string{"-j", chainRTNAT}
err = i.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...)
if err != nil {
return err
if err := r.insertEstablishedRule(chainRTFWD); err != nil {
return fmt.Errorf("insert established rule: %v", err)
}
return r.addJumpRules()
}
func (r *router) createAndSetupChain(chain string) error {
table := r.getTableForChain(chain)
if err := r.iptablesClient.NewChain(table, chain); err != nil {
return fmt.Errorf("failed creating chain %s, error: %v", chain, err)
}
i.rules[ipv4Nat] = rule
return nil
}
// cleanJumpRules cleans jump rules that was sending packets to NetBird chains
func (i *routerManager) cleanJumpRules() error {
var err error
errMSGFormat := "failed cleaning rule from chain %s,err: %v"
rule, found := i.rules[Ipv4Forwarding]
func (r *router) getTableForChain(chain string) string {
if chain == chainRTNAT {
return tableNat
}
return tableFilter
}
func (r *router) insertEstablishedRule(chain string) error {
establishedRule := getConntrackEstablished()
err := r.iptablesClient.Insert(tableFilter, chain, 1, establishedRule...)
if err != nil {
return fmt.Errorf("failed to insert established rule: %v", err)
}
ruleKey := "established-" + chain
r.rules[ruleKey] = establishedRule
return nil
}
func (r *router) addJumpRules() error {
rule := []string{"-j", chainRTNAT}
err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...)
if err != nil {
return err
}
r.rules[ipv4Nat] = rule
return nil
}
func (r *router) cleanJumpRules() error {
rule, found := r.rules[ipv4Nat]
if found {
err = i.iptablesClient.DeleteIfExists(tableFilter, chainFORWARD, rule...)
err := r.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...)
if err != nil {
return fmt.Errorf(errMSGFormat, chainFORWARD, err)
}
}
rule, found = i.rules[ipv4Nat]
if found {
err = i.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...)
if err != nil {
return fmt.Errorf(errMSGFormat, chainPOSTROUTING, err)
return fmt.Errorf("failed cleaning rule from chain %s, err: %v", chainPOSTROUTING, err)
}
}
rules, err := i.iptablesClient.List("nat", "POSTROUTING")
if err != nil {
return fmt.Errorf("failed to list rules: %s", err)
}
for _, ruleString := range rules {
if !strings.Contains(ruleString, "NETBIRD") {
continue
}
rule := strings.Fields(ruleString)
err := i.iptablesClient.DeleteIfExists("nat", "POSTROUTING", rule[2:]...)
if err != nil {
return fmt.Errorf("failed to delete postrouting jump rule: %s", err)
}
}
rules, err = i.iptablesClient.List(tableFilter, "FORWARD")
if err != nil {
return fmt.Errorf("failed to list rules in FORWARD chain: %s", err)
}
for _, ruleString := range rules {
if !strings.Contains(ruleString, "NETBIRD") {
continue
}
rule := strings.Fields(ruleString)
err := i.iptablesClient.DeleteIfExists(tableFilter, "FORWARD", rule[2:]...)
if err != nil {
return fmt.Errorf("failed to delete FORWARD jump rule: %s", err)
}
}
return nil
}
func (i *routerManager) createChain(table, newChain string) error {
chains, err := i.iptablesClient.ListChains(table)
if err != nil {
return fmt.Errorf("couldn't get %s table chains, error: %v", table, err)
}
func (r *router) addNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
shouldCreateChain := true
for _, chain := range chains {
if chain == newChain {
shouldCreateChain = false
}
}
if shouldCreateChain {
err = i.iptablesClient.NewChain(table, newChain)
if err != nil {
return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err)
}
// Add the loopback return rule to the NAT chain
loopbackRule := []string{"-o", "lo", "-j", "RETURN"}
err = i.iptablesClient.Insert(table, newChain, 1, loopbackRule...)
if err != nil {
return fmt.Errorf("failed to add loopback return rule to %s: %v", chainRTNAT, err)
}
err = i.iptablesClient.Append(table, newChain, "-j", "RETURN")
if err != nil {
return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err)
}
}
return nil
}
// addNATRule appends an iptables rule pair to the nat chain
func (i *routerManager) addNATRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(keyFormat, pair.ID)
rule := genRuleSpec(jump, pair.Source, pair.Destination)
existingRule, found := i.rules[ruleKey]
if found {
err := i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err != nil {
if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil {
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
}
delete(i.rules, ruleKey)
delete(r.rules, ruleKey)
}
// inserting after loopback ignore rule
err := i.iptablesClient.Insert(table, chain, 2, rule...)
if err != nil {
rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, r.wgIface.Name(), pair.Inverse)
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule...); err != nil {
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
}
i.rules[ruleKey] = rule
r.rules[ruleKey] = rule
return nil
}
// genRuleSpec generates rule specification
func genRuleSpec(jump, source, destination string) []string {
return []string{"-s", source, "-d", destination, "-j", jump}
func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil {
return fmt.Errorf("error while removing existing nat rule for %s: %v", pair.Destination, err)
}
func getIptablesRuleType(table string) string {
ruleType := "forwarding"
if table == tableNat {
ruleType = "nat"
delete(r.rules, ruleKey)
} else {
log.Debugf("nat rule %s not found", ruleKey)
}
return ruleType
return nil
}
func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string {
intdir := "-i"
if inverse {
intdir = "-o"
}
return []string{intdir, intf, "-s", source.String(), "-d", destination.String(), "-j", jump}
}
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
var rule []string
if params.SetName != "" {
rule = append(rule, "-m", "set", matchSet, params.SetName, "src")
} else if len(params.Sources) > 0 {
source := params.Sources[0]
rule = append(rule, "-s", source.String())
}
rule = append(rule, "-d", params.Destination.String())
if params.Proto != firewall.ProtocolALL {
rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
rule = append(rule, applyPort("--sport", params.SPort)...)
rule = append(rule, applyPort("--dport", params.DPort)...)
}
rule = append(rule, "-j", actionToStr(params.Action))
return rule
}
func applyPort(flag string, port *firewall.Port) []string {
if port == nil {
return nil
}
if port.IsRange && len(port.Values) == 2 {
return []string{flag, fmt.Sprintf("%d:%d", port.Values[0], port.Values[1])}
}
if len(port.Values) > 1 {
portList := make([]string, len(port.Values))
for i, p := range port.Values {
portList[i] = strconv.Itoa(p)
}
return []string{"-m", "multiport", flag, strings.Join(portList, ",")}
}
return []string{flag, strconv.Itoa(port.Values[0])}
}

View File

@ -4,11 +4,13 @@ package iptables
import (
"context"
"net/netip"
"os/exec"
"testing"
"github.com/coreos/go-iptables/iptables"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
@ -28,7 +30,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client")
manager, err := newRouterManager(context.TODO(), iptablesClient)
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
require.NoError(t, err, "should return a valid iptables manager")
defer func() {
@ -37,26 +39,22 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
require.Len(t, manager.rules, 2, "should have created rules map")
exists, err := manager.iptablesClient.Exists(tableFilter, chainFORWARD, manager.rules[Ipv4Forwarding]...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainFORWARD)
require.True(t, exists, "forwarding rule should exist")
exists, err = manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...)
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
require.True(t, exists, "postrouting rule should exist")
pair := firewall.RouterPair{
ID: "abc",
Source: "100.100.100.1/32",
Destination: "100.100.100.0/24",
Source: netip.MustParsePrefix("100.100.100.1/32"),
Destination: netip.MustParsePrefix("100.100.100.0/24"),
Masquerade: true,
}
forward4Rule := genRuleSpec(routingFinalForwardJump, pair.Source, pair.Destination)
forward4Rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...)
require.NoError(t, err, "inserting rule should not return error")
nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination)
nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, ifaceMock.Name(), false)
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
require.NoError(t, err, "inserting rule should not return error")
@ -65,7 +63,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
require.NoError(t, err, "shouldn't return error")
}
func TestIptablesManager_InsertRoutingRules(t *testing.T) {
func TestIptablesManager_AddNatRule(t *testing.T) {
if !isIptablesSupported() {
t.SkipNow()
@ -76,7 +74,7 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client")
manager, err := newRouterManager(context.TODO(), iptablesClient)
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
require.NoError(t, err, "shouldn't return error")
defer func() {
@ -86,35 +84,13 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
}
}()
err = manager.InsertRoutingRules(testCase.InputPair)
err = manager.AddNatRule(testCase.InputPair)
require.NoError(t, err, "forwarding pair should be inserted")
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination)
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false)
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.True(t, exists, "forwarding rule should exist")
foundRule, found := manager.rules[forwardRuleKey]
require.True(t, found, "forwarding rule should exist in the manager map")
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.True(t, exists, "income forwarding rule should exist")
foundRule, found = manager.rules[inForwardRuleKey]
require.True(t, found, "income forwarding rule should exist in the manager map")
require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination)
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
if testCase.InputPair.Masquerade {
require.True(t, exists, "nat rule should be created")
@ -127,8 +103,8 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
require.False(t, foundNat, "nat rule should not exist in the map")
}
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true)
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
@ -146,7 +122,7 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
}
}
func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
func TestIptablesManager_RemoveNatRule(t *testing.T) {
if !isIptablesSupported() {
t.SkipNow()
@ -156,7 +132,7 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
t.Run(testCase.Name, func(t *testing.T) {
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
manager, err := newRouterManager(context.TODO(), iptablesClient)
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
require.NoError(t, err, "shouldn't return error")
defer func() {
_ = manager.Reset()
@ -164,26 +140,14 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
require.NoError(t, err, "shouldn't return error")
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination)
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, forwardRule...)
require.NoError(t, err, "inserting rule should not return error")
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, inForwardRule...)
require.NoError(t, err, "inserting rule should not return error")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination)
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
require.NoError(t, err, "inserting rule should not return error")
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
require.NoError(t, err, "inserting rule should not return error")
@ -191,28 +155,14 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
err = manager.Reset()
require.NoError(t, err, "shouldn't return error")
err = manager.RemoveRoutingRules(testCase.InputPair)
err = manager.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error")
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.False(t, exists, "forwarding rule should not exist")
_, found := manager.rules[forwardRuleKey]
require.False(t, found, "forwarding rule should exist in the manager map")
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.False(t, exists, "income forwarding rule should not exist")
_, found = manager.rules[inForwardRuleKey]
require.False(t, found, "income forwarding rule should exist in the manager map")
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
require.False(t, exists, "nat rule should not exist")
_, found = manager.rules[natRuleKey]
_, found := manager.rules[natRuleKey]
require.False(t, found, "nat rule should exist in the manager map")
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
@ -221,7 +171,175 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
_, found = manager.rules[inNatRuleKey]
require.False(t, found, "income nat rule should exist in the manager map")
})
}
}
func TestRouter_AddRouteFiltering(t *testing.T) {
if !isIptablesSupported() {
t.Skip("iptables not supported on this system")
}
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "Failed to create iptables client")
r, err := newRouter(context.Background(), iptablesClient, ifaceMock)
require.NoError(t, err, "Failed to create router manager")
defer func() {
err := r.Reset()
require.NoError(t, err, "Failed to reset router")
}()
tests := []struct {
name string
sources []netip.Prefix
destination netip.Prefix
proto firewall.Protocol
sPort *firewall.Port
dPort *firewall.Port
direction firewall.RuleDirection
action firewall.Action
expectSet bool
}{
{
name: "Basic TCP rule with single source",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolTCP,
sPort: nil,
dPort: &firewall.Port{Values: []int{80}},
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "UDP rule with multiple sources",
sources: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"),
netip.MustParsePrefix("192.168.0.0/16"),
},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolUDP,
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionDrop,
expectSet: true,
},
{
name: "All protocols rule",
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
destination: netip.MustParsePrefix("0.0.0.0/0"),
proto: firewall.ProtocolALL,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "ICMP rule",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolICMP,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "TCP rule with multiple source ports",
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
destination: netip.MustParsePrefix("192.168.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "UDP rule with single IP and port range",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolUDP,
sPort: nil,
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
expectSet: false,
},
{
name: "TCP rule with source and destination ports",
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
destination: netip.MustParsePrefix("172.16.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
dPort: &firewall.Port{Values: []int{22}},
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "Drop all incoming traffic",
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
destination: netip.MustParsePrefix("192.168.0.0/24"),
proto: firewall.ProtocolALL,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
expectSet: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed")
// Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.GetRuleID()]
assert.True(t, ok, "Rule not found in internal map")
// Log the internal rule
t.Logf("Internal rule: %v", rule)
// Check if the rule exists in iptables
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, rule...)
assert.NoError(t, err, "Failed to check rule existence")
assert.True(t, exists, "Rule not found in iptables")
// Verify rule content
params := routeFilteringRuleParams{
Sources: tt.sources,
Destination: tt.destination,
Proto: tt.proto,
SPort: tt.sPort,
DPort: tt.dPort,
Action: tt.action,
SetName: "",
}
expectedRule := genRouteFilteringRuleSpec(params)
if tt.expectSet {
setName := firewall.GenerateSetName(tt.sources)
params.SetName = setName
expectedRule = genRouteFilteringRuleSpec(params)
// Check if the set was created
_, exists := r.ipsetCounter.Get(setName)
assert.True(t, exists, "IPSet not created")
}
assert.Equal(t, expectedRule, rule, "Rule content mismatch")
// Clean up
err = r.DeleteRouteRule(ruleKey)
require.NoError(t, err, "Failed to delete rule")
})
}
}

View File

@ -1,15 +1,21 @@
package manager
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"net"
"net/netip"
"sort"
"strings"
log "github.com/sirupsen/logrus"
)
const (
NatFormat = "netbird-nat-%s"
ForwardingFormat = "netbird-fwd-%s"
InNatFormat = "netbird-nat-in-%s"
InForwardingFormat = "netbird-fwd-in-%s"
ForwardingFormatPrefix = "netbird-fwd-"
ForwardingFormat = "netbird-fwd-%s-%t"
NatFormat = "netbird-nat-%s-%t"
)
// Rule abstraction should be implemented by each firewall manager
@ -49,11 +55,11 @@ type Manager interface {
// AllowNetbird allows netbird interface traffic
AllowNetbird() error
// AddFiltering rule to the firewall
// AddPeerFiltering adds a rule to the firewall
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
AddFiltering(
AddPeerFiltering(
ip net.IP,
proto Protocol,
sPort *Port,
@ -64,17 +70,25 @@ type Manager interface {
comment string,
) ([]Rule, error)
// DeleteRule from the firewall by rule definition
DeleteRule(rule Rule) error
// DeletePeerRule from the firewall by rule definition
DeletePeerRule(rule Rule) error
// IsServerRouteSupported returns true if the firewall supports server side routing operations
IsServerRouteSupported() bool
// InsertRoutingRules inserts a routing firewall rule
InsertRoutingRules(pair RouterPair) error
AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto Protocol, sPort *Port, dPort *Port, action Action) (Rule, error)
// RemoveRoutingRules removes a routing firewall rule
RemoveRoutingRules(pair RouterPair) error
// DeleteRouteRule deletes a routing rule
DeleteRouteRule(rule Rule) error
// AddNatRule inserts a routing NAT rule
AddNatRule(pair RouterPair) error
// RemoveNatRule removes a routing NAT rule
RemoveNatRule(pair RouterPair) error
// SetLegacyManagement sets the legacy management mode
SetLegacyManagement(legacy bool) error
// Reset firewall to the default state
Reset() error
@ -83,6 +97,89 @@ type Manager interface {
Flush() error
}
func GenKey(format string, input string) string {
return fmt.Sprintf(format, input)
func GenKey(format string, pair RouterPair) string {
return fmt.Sprintf(format, pair.ID, pair.Inverse)
}
// LegacyManager defines the interface for legacy management operations
type LegacyManager interface {
RemoveAllLegacyRouteRules() error
GetLegacyManagement() bool
SetLegacyManagement(bool)
}
// SetLegacyManagement sets the route manager to use legacy management
func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
oldLegacy := router.GetLegacyManagement()
if oldLegacy != isLegacy {
router.SetLegacyManagement(isLegacy)
log.Debugf("Set legacy management to %v", isLegacy)
}
// client reconnected to a newer mgmt, we need to clean up the legacy rules
if !isLegacy && oldLegacy {
if err := router.RemoveAllLegacyRouteRules(); err != nil {
return fmt.Errorf("remove legacy routing rules: %v", err)
}
log.Debugf("Legacy routing rules removed")
}
return nil
}
// GenerateSetName generates a unique name for an ipset based on the given sources.
func GenerateSetName(sources []netip.Prefix) string {
// sort for consistent naming
sortPrefixes(sources)
var sourcesStr strings.Builder
for _, src := range sources {
sourcesStr.WriteString(src.String())
}
hash := sha256.Sum256([]byte(sourcesStr.String()))
shortHash := hex.EncodeToString(hash[:])[:8]
return fmt.Sprintf("nb-%s", shortHash)
}
// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix
func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
if len(prefixes) == 0 {
return prefixes
}
merged := []netip.Prefix{prefixes[0]}
for _, prefix := range prefixes[1:] {
last := merged[len(merged)-1]
if last.Contains(prefix.Addr()) {
// If the current prefix is contained within the last merged prefix, skip it
continue
}
if prefix.Contains(last.Addr()) {
// If the current prefix contains the last merged prefix, replace it
merged[len(merged)-1] = prefix
} else {
// Otherwise, add the current prefix to the merged list
merged = append(merged, prefix)
}
}
return merged
}
// sortPrefixes sorts the given slice of netip.Prefix in place.
// It sorts first by IP address, then by prefix length (most specific to least specific).
func sortPrefixes(prefixes []netip.Prefix) {
sort.Slice(prefixes, func(i, j int) bool {
addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr())
if addrCmp != 0 {
return addrCmp < 0
}
// If IP addresses are the same, compare prefix lengths (longer prefixes first)
return prefixes[i].Bits() > prefixes[j].Bits()
})
}

View File

@ -0,0 +1,192 @@
package manager_test
import (
"net/netip"
"reflect"
"regexp"
"testing"
"github.com/netbirdio/netbird/client/firewall/manager"
)
func TestGenerateSetName(t *testing.T) {
t.Run("Different orders result in same hash", func(t *testing.T) {
prefixes1 := []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
}
prefixes2 := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("192.168.1.0/24"),
}
result1 := manager.GenerateSetName(prefixes1)
result2 := manager.GenerateSetName(prefixes2)
if result1 != result2 {
t.Errorf("Different orders produced different hashes: %s != %s", result1, result2)
}
})
t.Run("Result format is correct", func(t *testing.T) {
prefixes := []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
}
result := manager.GenerateSetName(prefixes)
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result)
if err != nil {
t.Fatalf("Error matching regex: %v", err)
}
if !matched {
t.Errorf("Result format is incorrect: %s", result)
}
})
t.Run("Empty input produces consistent result", func(t *testing.T) {
result1 := manager.GenerateSetName([]netip.Prefix{})
result2 := manager.GenerateSetName([]netip.Prefix{})
if result1 != result2 {
t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2)
}
})
t.Run("IPv4 and IPv6 mixing", func(t *testing.T) {
prefixes1 := []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("2001:db8::/32"),
}
prefixes2 := []netip.Prefix{
netip.MustParsePrefix("2001:db8::/32"),
netip.MustParsePrefix("192.168.1.0/24"),
}
result1 := manager.GenerateSetName(prefixes1)
result2 := manager.GenerateSetName(prefixes2)
if result1 != result2 {
t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2)
}
})
}
func TestMergeIPRanges(t *testing.T) {
tests := []struct {
name string
input []netip.Prefix
expected []netip.Prefix
}{
{
name: "Empty input",
input: []netip.Prefix{},
expected: []netip.Prefix{},
},
{
name: "Single range",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
},
{
name: "Two non-overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
},
},
{
name: "One range containing another",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("192.168.1.0/24"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
},
},
{
name: "One range containing another (different order)",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.0.0/16"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
},
},
{
name: "Overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.1.128/25"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
},
{
name: "Overlapping ranges (different order)",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.128/25"),
netip.MustParsePrefix("192.168.1.0/24"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
},
{
name: "Multiple overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.2.0/24"),
netip.MustParsePrefix("192.168.1.128/25"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
},
},
{
name: "Partially overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/23"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.2.0/25"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/23"),
netip.MustParsePrefix("192.168.2.0/25"),
},
},
{
name: "IPv6 ranges",
input: []netip.Prefix{
netip.MustParsePrefix("2001:db8::/32"),
netip.MustParsePrefix("2001:db8:1::/48"),
netip.MustParsePrefix("2001:db8:2::/48"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("2001:db8::/32"),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := manager.MergeIPRanges(tt.input)
if !reflect.DeepEqual(result, tt.expected) {
t.Errorf("MergeIPRanges() = %v, want %v", result, tt.expected)
}
})
}
}

View File

@ -1,18 +1,26 @@
package manager
import (
"net/netip"
"github.com/netbirdio/netbird/route"
)
type RouterPair struct {
ID string
Source string
Destination string
ID route.ID
Source netip.Prefix
Destination netip.Prefix
Masquerade bool
Inverse bool
}
func GetInPair(pair RouterPair) RouterPair {
func GetInversePair(pair RouterPair) RouterPair {
return RouterPair{
ID: pair.ID,
// invert Source/Destination
Source: pair.Destination,
Destination: pair.Source,
Masquerade: pair.Masquerade,
Inverse: true,
}
}

View File

@ -16,7 +16,7 @@ import (
"golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/client/iface"
)
const (
@ -33,9 +33,10 @@ const (
allowNetbirdInputRuleID = "allow Netbird incoming traffic"
)
const flushError = "flush: %w"
var (
anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
postroutingMark = []byte{0xe4, 0x7, 0x0, 0x00}
)
type AclManager struct {
@ -48,7 +49,6 @@ type AclManager struct {
chainInputRules *nftables.Chain
chainOutputRules *nftables.Chain
chainFwFilter *nftables.Chain
chainPrerouting *nftables.Chain
ipsetStore *ipsetStore
rules map[string]*Rule
@ -64,7 +64,7 @@ type iFaceMapper interface {
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainName string) (*AclManager, error) {
// sConn is used for creating sets and adding/removing elements from them
// it's differ then rConn (which does create new conn for each flush operation)
// and is permanent. Using same connection for booth type of operations
// and is permanent. Using same connection for both type of operations
// overloads netlink with high amount of rules ( > 10000)
sConn, err := nftables.New(nftables.AsLasting())
if err != nil {
@ -90,11 +90,11 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainNa
return m, nil
}
// AddFiltering rule to the firewall
// AddPeerFiltering rule to the firewall
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
func (m *AclManager) AddFiltering(
func (m *AclManager) AddPeerFiltering(
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
@ -120,20 +120,11 @@ func (m *AclManager) AddFiltering(
}
newRules = append(newRules, ioRule)
if !shouldAddToPrerouting(proto, dPort, direction) {
return newRules, nil
}
preroutingRule, err := m.addPreroutingFiltering(ipset, proto, dPort, ip)
if err != nil {
return newRules, err
}
newRules = append(newRules, preroutingRule)
return newRules, nil
}
// DeleteRule from the firewall by rule definition
func (m *AclManager) DeleteRule(rule firewall.Rule) error {
// DeletePeerRule from the firewall by rule definition
func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
r, ok := rule.(*Rule)
if !ok {
return fmt.Errorf("invalid rule type")
@ -199,8 +190,7 @@ func (m *AclManager) DeleteRule(rule firewall.Rule) error {
return nil
}
// createDefaultAllowRules In case if the USP firewall manager can use the native firewall manager we must to create allow rules for
// input and output chains
// createDefaultAllowRules creates default allow rules for the input and output chains
func (m *AclManager) createDefaultAllowRules() error {
expIn := []expr.Any{
&expr.Payload{
@ -214,13 +204,13 @@ func (m *AclManager) createDefaultAllowRules() error {
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: []byte{0x00, 0x00, 0x00, 0x00},
Xor: zeroXor,
Mask: []byte{0, 0, 0, 0},
Xor: []byte{0, 0, 0, 0},
},
// net address
&expr.Cmp{
Register: 1,
Data: []byte{0x00, 0x00, 0x00, 0x00},
Data: []byte{0, 0, 0, 0},
},
&expr.Verdict{
Kind: expr.VerdictAccept,
@ -246,13 +236,13 @@ func (m *AclManager) createDefaultAllowRules() error {
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: []byte{0x00, 0x00, 0x00, 0x00},
Xor: zeroXor,
Mask: []byte{0, 0, 0, 0},
Xor: []byte{0, 0, 0, 0},
},
// net address
&expr.Cmp{
Register: 1,
Data: []byte{0x00, 0x00, 0x00, 0x00},
Data: []byte{0, 0, 0, 0},
},
&expr.Verdict{
Kind: expr.VerdictAccept,
@ -266,10 +256,8 @@ func (m *AclManager) createDefaultAllowRules() error {
Exprs: expOut,
})
err := m.rConn.Flush()
if err != nil {
log.Debugf("failed to create default allow rules: %s", err)
return err
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
@ -290,15 +278,11 @@ func (m *AclManager) Flush() error {
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
}
if err := m.refreshRuleHandles(m.chainPrerouting); err != nil {
log.Errorf("failed to refresh rule handles IPv4 prerouting chain: %v", err)
}
return nil
}
func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, ipset *nftables.Set, comment string) (*Rule, error) {
ruleId := generateRuleId(ip, sPort, dPort, direction, action, ipset)
ruleId := generatePeerRuleId(ip, sPort, dPort, direction, action, ipset)
if r, ok := m.rules[ruleId]; ok {
return &Rule{
r.nftRule,
@ -308,18 +292,7 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
}, nil
}
ifaceKey := expr.MetaKeyIIFNAME
if direction == firewall.RuleDirectionOUT {
ifaceKey = expr.MetaKeyOIFNAME
}
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
}
var expressions []expr.Any
if proto != firewall.ProtocolALL {
expressions = append(expressions, &expr.Payload{
@ -329,21 +302,15 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
Len: uint32(1),
})
var protoData []byte
switch proto {
case firewall.ProtocolTCP:
protoData = []byte{unix.IPPROTO_TCP}
case firewall.ProtocolUDP:
protoData = []byte{unix.IPPROTO_UDP}
case firewall.ProtocolICMP:
protoData = []byte{unix.IPPROTO_ICMP}
default:
return nil, fmt.Errorf("unsupported protocol: %s", proto)
protoData, err := protoToInt(proto)
if err != nil {
return nil, fmt.Errorf("convert protocol to number: %v", err)
}
expressions = append(expressions, &expr.Cmp{
Register: 1,
Op: expr.CmpOpEq,
Data: protoData,
Data: []byte{protoData},
})
}
@ -432,10 +399,9 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
} else {
chain = m.chainOutputRules
}
nftRule := m.rConn.InsertRule(&nftables.Rule{
nftRule := m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: chain,
Position: 0,
Exprs: expressions,
UserData: userData,
})
@ -453,139 +419,13 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
return rule, nil
}
func (m *AclManager) addPreroutingFiltering(ipset *nftables.Set, proto firewall.Protocol, port *firewall.Port, ip net.IP) (*Rule, error) {
var protoData []byte
switch proto {
case firewall.ProtocolTCP:
protoData = []byte{unix.IPPROTO_TCP}
case firewall.ProtocolUDP:
protoData = []byte{unix.IPPROTO_UDP}
case firewall.ProtocolICMP:
protoData = []byte{unix.IPPROTO_ICMP}
default:
return nil, fmt.Errorf("unsupported protocol: %s", proto)
}
ruleId := generateRuleIdForMangle(ipset, ip, proto, port)
if r, ok := m.rules[ruleId]; ok {
return &Rule{
r.nftRule,
r.nftSet,
r.ruleID,
ip,
}, nil
}
var ipExpression expr.Any
// add individual IP for match if no ipset defined
rawIP := ip.To4()
if ipset == nil {
ipExpression = &expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: rawIP,
}
} else {
ipExpression = &expr.Lookup{
SourceRegister: 1,
SetName: ipset.Name,
SetID: ipset.ID,
}
}
expressions := []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
ipExpression,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: m.wgIface.Address().IP.To4(),
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: uint32(9),
Len: uint32(1),
},
&expr.Cmp{
Register: 1,
Op: expr.CmpOpEq,
Data: protoData,
},
}
if port != nil {
expressions = append(expressions,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: encodePort(*port),
},
)
}
expressions = append(expressions,
&expr.Immediate{
Register: 1,
Data: postroutingMark,
},
&expr.Meta{
Key: expr.MetaKeyMARK,
SourceRegister: true,
Register: 1,
},
)
nftRule := m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainPrerouting,
Position: 0,
Exprs: expressions,
UserData: []byte(ruleId),
})
if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf("flush insert rule: %v", err)
}
rule := &Rule{
nftRule: nftRule,
nftSet: ipset,
ruleID: ruleId,
ip: ip,
}
m.rules[ruleId] = rule
if ipset != nil {
m.ipsetStore.AddReferenceToIpset(ipset.Name)
}
return rule, nil
}
func (m *AclManager) createDefaultChains() (err error) {
// chainNameInputRules
chain := m.createChain(chainNameInputRules)
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
return err
return fmt.Errorf(flushError, err)
}
m.chainInputRules = chain
@ -601,9 +441,6 @@ func (m *AclManager) createDefaultChains() (err error) {
// netbird-acl-input-filter
// type filter hook input priority filter; policy accept;
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
//netbird-acl-input-filter iifname "wt0" ip saddr 100.72.0.0/16 ip daddr != 100.72.0.0/16 accept
m.addRouteAllowRule(chain, expr.MetaKeyIIFNAME)
m.addFwdAllow(chain, expr.MetaKeyIIFNAME)
m.addJumpRule(chain, m.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules
m.addDropExpressions(chain, expr.MetaKeyIIFNAME)
err = m.rConn.Flush()
@ -615,7 +452,6 @@ func (m *AclManager) createDefaultChains() (err error) {
// netbird-acl-output-filter
// type filter hook output priority filter; policy accept;
chain = m.createFilterChainWithHook(chainNameOutputFilter, nftables.ChainHookOutput)
m.addRouteAllowRule(chain, expr.MetaKeyOIFNAME)
m.addFwdAllow(chain, expr.MetaKeyOIFNAME)
m.addJumpRule(chain, m.chainOutputRules.Name, expr.MetaKeyOIFNAME) // to netbird-acl-output-rules
m.addDropExpressions(chain, expr.MetaKeyOIFNAME)
@ -627,24 +463,15 @@ func (m *AclManager) createDefaultChains() (err error) {
// netbird-acl-forward-filter
m.chainFwFilter = m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
m.addJumpRulesToRtForward() // to
m.addMarkAccept()
m.addJumpRuleToInputChain() // to netbird-acl-input-rules
m.addJumpRulesToRtForward() // to netbird-rt-fwd
m.addDropExpressions(m.chainFwFilter, expr.MetaKeyIIFNAME)
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err)
return err
return fmt.Errorf(flushError, err)
}
// netbird-acl-output-filter
// type filter hook output priority filter; policy accept;
m.chainPrerouting = m.createPreroutingMangle()
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", m.chainPrerouting.Name, err)
return err
}
return nil
}
@ -667,59 +494,6 @@ func (m *AclManager) addJumpRulesToRtForward() {
Chain: m.chainFwFilter,
Exprs: expressions,
})
expressions = []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: m.routeingFwChainName,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainFwFilter,
Exprs: expressions,
})
}
func (m *AclManager) addMarkAccept() {
// oifname "wt0" meta mark 0x000007e4 accept
// iifname "wt0" meta mark 0x000007e4 accept
ifaces := []expr.MetaKey{expr.MetaKeyIIFNAME, expr.MetaKeyOIFNAME}
for _, iface := range ifaces {
expressions := []expr.Any{
&expr.Meta{Key: iface, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: postroutingMark,
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainFwFilter,
Exprs: expressions,
})
}
}
func (m *AclManager) createChain(name string) *nftables.Chain {
@ -729,6 +503,9 @@ func (m *AclManager) createChain(name string) *nftables.Chain {
}
chain = m.rConn.AddChain(chain)
insertReturnTrafficRule(m.rConn, m.workTable, chain)
return chain
}
@ -746,74 +523,6 @@ func (m *AclManager) createFilterChainWithHook(name string, hookNum nftables.Cha
return m.rConn.AddChain(chain)
}
func (m *AclManager) createPreroutingMangle() *nftables.Chain {
polAccept := nftables.ChainPolicyAccept
chain := &nftables.Chain{
Name: "netbird-acl-prerouting-filter",
Table: m.workTable,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
Policy: &polAccept,
}
chain = m.rConn.AddChain(chain)
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
expressions := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: m.wgIface.Address().IP.To4(),
},
&expr.Immediate{
Register: 1,
Data: postroutingMark,
},
&expr.Meta{
Key: expr.MetaKeyMARK,
SourceRegister: true,
Register: 1,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: chain,
Exprs: expressions,
})
return chain
}
func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any {
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
@ -832,101 +541,9 @@ func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.Met
return nil
}
func (m *AclManager) addJumpRuleToInputChain() {
expressions := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: m.chainInputRules.Name,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainFwFilter,
Exprs: expressions,
})
}
func (m *AclManager) addRouteAllowRule(chain *nftables.Chain, netIfName expr.MetaKey) {
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
var srcOp, dstOp expr.CmpOp
if netIfName == expr.MetaKeyIIFNAME {
srcOp = expr.CmpOpEq
dstOp = expr.CmpOpNeq
} else {
srcOp = expr.CmpOpNeq
dstOp = expr.CmpOpEq
}
expressions := []expr.Any{
&expr.Meta{Key: netIfName, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: srcOp,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: dstOp,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: expressions,
})
}
func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
var srcOp, dstOp expr.CmpOp
if iifname == expr.MetaKeyIIFNAME {
srcOp = expr.CmpOpNeq
dstOp = expr.CmpOpEq
} else {
srcOp = expr.CmpOpEq
dstOp = expr.CmpOpNeq
}
dstOp := expr.CmpOpNeq
expressions := []expr.Any{
&expr.Meta{Key: iifname, Register: 1},
&expr.Cmp{
@ -934,24 +551,6 @@ func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: srcOp,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
@ -982,7 +581,6 @@ func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
}
func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{
@ -990,47 +588,12 @@ func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: to,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: chain.Table,
Chain: chain,
@ -1132,7 +695,7 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error {
return nil
}
func generateRuleId(
func generatePeerRuleId(
ip net.IP,
sPort *firewall.Port,
dPort *firewall.Port,
@ -1155,33 +718,6 @@ func generateRuleId(
}
return "set:" + ipset.Name + rulesetID
}
func generateRuleIdForMangle(ipset *nftables.Set, ip net.IP, proto firewall.Protocol, port *firewall.Port) string {
// case of icmp port is empty
var p string
if port != nil {
p = port.String()
}
if ipset != nil {
return fmt.Sprintf("p:set:%s:%s:%v", ipset.Name, proto, p)
} else {
return fmt.Sprintf("p:ip:%s:%s:%v", ip.String(), proto, p)
}
}
func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool {
if proto == "all" {
return false
}
if direction != firewall.RuleDirectionIN {
return false
}
if dPort == nil && proto != firewall.ProtocolICMP {
return false
}
return true
}
func encodePort(port firewall.Port) []byte {
bs := make([]byte, 2)
@ -1191,6 +727,19 @@ func encodePort(port firewall.Port) []byte {
func ifname(n string) []byte {
b := make([]byte, 16)
copy(b, []byte(n+"\x00"))
copy(b, n+"\x00")
return b
}
func protoToInt(protocol firewall.Protocol) (uint8, error) {
switch protocol {
case firewall.ProtocolTCP:
return unix.IPPROTO_TCP, nil
case firewall.ProtocolUDP:
return unix.IPPROTO_UDP, nil
case firewall.ProtocolICMP:
return unix.IPPROTO_ICMP, nil
}
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
}

View File

@ -5,9 +5,11 @@ import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
@ -15,8 +17,11 @@ import (
)
const (
// tableName is the name of the table that is used for filtering by the Netbird client
tableName = "netbird"
// tableNameNetbird is the name of the table that is used for filtering by the Netbird client
tableNameNetbird = "netbird"
tableNameFilter = "filter"
chainNameInput = "INPUT"
)
// Manager of iptables firewall
@ -41,12 +46,12 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
return nil, err
}
m.router, err = newRouter(context, workTable)
m.router, err = newRouter(context, workTable, wgIface)
if err != nil {
return nil, err
}
m.aclManager, err = newAclManager(workTable, wgIface, m.router.RouteingFwChainName())
m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw)
if err != nil {
return nil, err
}
@ -54,11 +59,11 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
return m, nil
}
// AddFiltering rule to the firewall
// AddPeerFiltering rule to the firewall
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
func (m *Manager) AddFiltering(
func (m *Manager) AddPeerFiltering(
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
@ -76,33 +81,52 @@ func (m *Manager) AddFiltering(
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
}
return m.aclManager.AddFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment)
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment)
}
// DeleteRule from the firewall by rule definition
func (m *Manager) DeleteRule(rule firewall.Rule) error {
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.aclManager.DeleteRule(rule)
if !destination.Addr().Is4() {
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
}
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.aclManager.DeletePeerRule(rule)
}
// DeleteRouteRule deletes a routing rule
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.DeleteRouteRule(rule)
}
func (m *Manager) IsServerRouteSupported() bool {
return true
}
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddRoutingRules(pair)
return m.router.AddNatRule(pair)
}
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveRoutingRules(pair)
return m.router.RemoveNatRule(pair)
}
// AllowNetbird allows netbird interface traffic
@ -126,7 +150,7 @@ func (m *Manager) AllowNetbird() error {
var chain *nftables.Chain
for _, c := range chains {
if c.Table.Name == "filter" && c.Name == "INPUT" {
if c.Table.Name == tableNameFilter && c.Name == chainNameForward {
chain = c
break
}
@ -157,6 +181,27 @@ func (m *Manager) AllowNetbird() error {
return nil
}
// SetLegacyManagement sets the route manager to use legacy management
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
oldLegacy := m.router.legacyManagement
if oldLegacy != isLegacy {
m.router.legacyManagement = isLegacy
log.Debugf("Set legacy management to %v", isLegacy)
}
// client reconnected to a newer mgmt, we need to cleanup the legacy rules
if !isLegacy && oldLegacy {
if err := m.router.RemoveAllLegacyRouteRules(); err != nil {
return fmt.Errorf("remove legacy routing rules: %v", err)
}
log.Debugf("Legacy routing rules removed")
}
return nil
}
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
@ -185,14 +230,16 @@ func (m *Manager) Reset() error {
}
}
m.router.ResetForwardRules()
if err := m.router.Reset(); err != nil {
return fmt.Errorf("reset forward rules: %v", err)
}
tables, err := m.rConn.ListTables()
if err != nil {
return fmt.Errorf("list of tables: %w", err)
}
for _, t := range tables {
if t.Name == tableName {
if t.Name == tableNameNetbird {
m.rConn.DelTable(t)
}
}
@ -218,12 +265,12 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) {
}
for _, t := range tables {
if t.Name == tableName {
if t.Name == tableNameNetbird {
m.rConn.DelTable(t)
}
}
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
table := m.rConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4})
err = m.rConn.Flush()
return table, err
}
@ -239,9 +286,7 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
&expr.Verdict{},
},
UserData: []byte(allowNetbirdInputRuleID),
}
@ -251,7 +296,7 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
ifName := ifname(m.wgIface.Name())
for _, rule := range existedRules {
if rule.Table.Name == "filter" && rule.Chain.Name == "INPUT" {
if rule.Table.Name == tableNameFilter && rule.Chain.Name == chainNameInput {
if len(rule.Exprs) < 4 {
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
continue
@ -265,3 +310,33 @@ func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftable
}
return nil
}
func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) {
rule := &nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
}
conn.InsertRule(rule)
}

View File

@ -9,14 +9,30 @@ import (
"time"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/client/iface"
)
var ifaceMock = &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
// iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct {
NameFunc func() string
@ -40,23 +56,9 @@ func (i *iFaceMock) Address() iface.WGAddress {
func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestNftablesManager(t *testing.T) {
mock := &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
// just check on the local interface
manager, err := Create(context.Background(), mock)
manager, err := Create(context.Background(), ifaceMock)
require.NoError(t, err)
time.Sleep(time.Second * 3)
@ -70,7 +72,7 @@ func TestNftablesManager(t *testing.T) {
testClient := &nftables.Conn{}
rule, err := manager.AddFiltering(
rule, err := manager.AddPeerFiltering(
ip,
fw.ProtocolTCP,
nil,
@ -88,17 +90,34 @@ func TestNftablesManager(t *testing.T) {
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
require.NoError(t, err, "failed to get rules")
require.Len(t, rules, 1, "expected 1 rules")
require.Len(t, rules, 2, "expected 2 rules")
expectedExprs1 := []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
require.ElementsMatch(t, rules[0].Exprs, expectedExprs1, "expected the same expressions")
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
expectedExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname("lo"),
},
expectedExprs2 := []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
@ -134,10 +153,10 @@ func TestNftablesManager(t *testing.T) {
},
&expr.Verdict{Kind: expr.VerdictDrop},
}
require.ElementsMatch(t, rules[0].Exprs, expectedExprs, "expected the same expressions")
require.ElementsMatch(t, rules[1].Exprs, expectedExprs2, "expected the same expressions")
for _, r := range rule {
err = manager.DeleteRule(r)
err = manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
}
@ -146,7 +165,8 @@ func TestNftablesManager(t *testing.T) {
rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
require.NoError(t, err, "failed to get rules")
require.Len(t, rules, 0, "expected 0 rules after deletion")
// established rule remains
require.Len(t, rules, 1, "expected 1 rules after deletion")
err = manager.Reset()
require.NoError(t, err, "failed to reset")
@ -187,9 +207,9 @@ func TestNFtablesCreatePerformance(t *testing.T) {
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
}
require.NoError(t, err, "failed to add rule")

View File

@ -1,431 +0,0 @@
package nftables
import (
"bytes"
"context"
"errors"
"fmt"
"net"
"net/netip"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/manager"
)
const (
chainNameRouteingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-nat"
userDataAcceptForwardRuleSrc = "frwacceptsrc"
userDataAcceptForwardRuleDst = "frwacceptdst"
loopbackInterface = "lo\x00"
)
// some presets for building nftable rules
var (
zeroXor = binaryutil.NativeEndian.PutUint32(0)
exprCounterAccept = []expr.Any{
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
)
type router struct {
ctx context.Context
stop context.CancelFunc
conn *nftables.Conn
workTable *nftables.Table
filterTable *nftables.Table
chains map[string]*nftables.Chain
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
rules map[string]*nftables.Rule
isDefaultFwdRulesEnabled bool
}
func newRouter(parentCtx context.Context, workTable *nftables.Table) (*router, error) {
ctx, cancel := context.WithCancel(parentCtx)
r := &router{
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{},
workTable: workTable,
chains: make(map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
}
var err error
r.filterTable, err = r.loadFilterTable()
if err != nil {
if errors.Is(err, errFilterTableNotFound) {
log.Warnf("table 'filter' not found for forward rules")
} else {
return nil, err
}
}
err = r.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
err = r.createContainers()
if err != nil {
log.Errorf("failed to create containers for route: %s", err)
}
return r, err
}
func (r *router) RouteingFwChainName() string {
return chainNameRouteingFw
}
// ResetForwardRules cleans existing nftables default forward rules from the system
func (r *router) ResetForwardRules() {
err := r.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to reset forward rules: %s", err)
}
}
func (r *router) loadFilterTable() (*nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
}
for _, table := range tables {
if table.Name == "filter" {
return table, nil
}
}
return nil, errFilterTableNotFound
}
func (r *router) createContainers() error {
r.chains[chainNameRouteingFw] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRouteingFw,
Table: r.workTable,
})
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingNat,
Table: r.workTable,
Hooknum: nftables.ChainHookPostrouting,
Priority: nftables.ChainPriorityNATSource - 1,
Type: nftables.ChainTypeNAT,
})
// Add RETURN rule for loopback interface
loRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte(loopbackInterface),
},
&expr.Verdict{Kind: expr.VerdictReturn},
},
}
r.conn.InsertRule(loRule)
err := r.refreshRulesMap()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
err = r.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: unable to initialize table: %v", err)
}
return nil
}
// AddRoutingRules appends a nftable rule pair to the forwarding chain and if enabled, to the nat chain
func (r *router) AddRoutingRules(pair manager.RouterPair) error {
err := r.refreshRulesMap()
if err != nil {
return err
}
err = r.addRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
if err != nil {
return err
}
err = r.addRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
if err != nil {
return err
}
if pair.Masquerade {
err = r.addRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
if err != nil {
return err
}
err = r.addRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
if err != nil {
return err
}
}
if r.filterTable != nil && !r.isDefaultFwdRulesEnabled {
log.Debugf("add default accept forward rule")
r.acceptForwardRule(pair.Source)
}
err = r.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.Destination, err)
}
return nil
}
// addRoutingRule inserts a nftable rule to the conn client flush queue
func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
var expression []expr.Any
if isNat {
expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) // nolint:gocritic
} else {
expression = append(sourceExp, append(destExp, exprCounterAccept...)...) // nolint:gocritic
}
ruleKey := manager.GenKey(format, pair.ID)
_, exists := r.rules[ruleKey]
if exists {
err := r.removeRoutingRule(format, pair)
if err != nil {
return err
}
}
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainName],
Exprs: expression,
UserData: []byte(ruleKey),
})
return nil
}
func (r *router) acceptForwardRule(sourceNetwork string) {
src := generateCIDRMatcherExpressions(true, sourceNetwork)
dst := generateCIDRMatcherExpressions(false, "0.0.0.0/0")
var exprs []expr.Any
exprs = append(src, append(dst, &expr.Verdict{ // nolint:gocritic
Kind: expr.VerdictAccept,
})...)
rule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: exprs,
UserData: []byte(userDataAcceptForwardRuleSrc),
}
r.conn.AddRule(rule)
src = generateCIDRMatcherExpressions(true, "0.0.0.0/0")
dst = generateCIDRMatcherExpressions(false, sourceNetwork)
exprs = append(src, append(dst, &expr.Verdict{ //nolint:gocritic
Kind: expr.VerdictAccept,
})...)
rule = &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: exprs,
UserData: []byte(userDataAcceptForwardRuleDst),
}
r.conn.AddRule(rule)
r.isDefaultFwdRulesEnabled = true
}
// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains
func (r *router) RemoveRoutingRules(pair manager.RouterPair) error {
err := r.refreshRulesMap()
if err != nil {
return err
}
err = r.removeRoutingRule(manager.ForwardingFormat, pair)
if err != nil {
return err
}
err = r.removeRoutingRule(manager.InForwardingFormat, manager.GetInPair(pair))
if err != nil {
return err
}
err = r.removeRoutingRule(manager.NatFormat, pair)
if err != nil {
return err
}
err = r.removeRoutingRule(manager.InNatFormat, manager.GetInPair(pair))
if err != nil {
return err
}
if len(r.rules) == 0 {
err := r.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
}
err = r.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
}
log.Debugf("nftables: removed rules for %s", pair.Destination)
return nil
}
// removeRoutingRule add a nftable rule to the removal queue and delete from rules map
func (r *router) removeRoutingRule(format string, pair manager.RouterPair) error {
ruleKey := manager.GenKey(format, pair.ID)
rule, found := r.rules[ruleKey]
if found {
ruleType := "forwarding"
if rule.Chain.Type == nftables.ChainTypeNAT {
ruleType = "nat"
}
err := r.conn.DelRule(rule)
if err != nil {
return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.Destination, err)
}
log.Debugf("nftables: removing %s rule for %s", ruleType, pair.Destination)
delete(r.rules, ruleKey)
}
return nil
}
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
// duplicates and to get missing attributes that we don't have when adding new rules
func (r *router) refreshRulesMap() error {
for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("nftables: unable to list rules: %v", err)
}
for _, rule := range rules {
if len(rule.UserData) > 0 {
r.rules[string(rule.UserData)] = rule
}
}
}
return nil
}
func (r *router) cleanUpDefaultForwardRules() error {
if r.filterTable == nil {
r.isDefaultFwdRulesEnabled = false
return nil
}
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return err
}
var rules []*nftables.Rule
for _, chain := range chains {
if chain.Table.Name != r.filterTable.Name {
continue
}
if chain.Name != "FORWARD" {
continue
}
rules, err = r.conn.GetRules(r.filterTable, chain)
if err != nil {
return err
}
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleSrc)) || bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleDst)) {
err := r.conn.DelRule(rule)
if err != nil {
return err
}
}
}
r.isDefaultFwdRulesEnabled = false
return r.conn.Flush()
}
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
func generateCIDRMatcherExpressions(source bool, cidr string) []expr.Any {
ip, network, _ := net.ParseCIDR(cidr)
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
var offSet uint32
if source {
offSet = 12 // src offset
} else {
offSet = 16 // dst offset
}
return []expr.Any{
// fetch src add
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offSet,
Len: 4,
},
// net mask
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: 4,
Mask: network.Mask,
Xor: zeroXor,
},
// net address
&expr.Cmp{
Register: 1,
Data: add.AsSlice(),
},
}
}

View File

@ -0,0 +1,798 @@
package nftables
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"net"
"net/netip"
"strings"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
)
const (
chainNameRoutingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-nat"
chainNameForward = "FORWARD"
userDataAcceptForwardRuleIif = "frwacceptiif"
userDataAcceptForwardRuleOif = "frwacceptoif"
)
const refreshRulesMapError = "refresh rules map: %w"
var (
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
)
type router struct {
ctx context.Context
stop context.CancelFunc
conn *nftables.Conn
workTable *nftables.Table
filterTable *nftables.Table
chains map[string]*nftables.Chain
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
rules map[string]*nftables.Rule
ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set]
wgIface iFaceMapper
legacyManagement bool
}
func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
ctx, cancel := context.WithCancel(parentCtx)
r := &router{
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{},
workTable: workTable,
chains: make(map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
wgIface: wgIface,
}
r.ipsetCounter = refcounter.New(
r.createIpSet,
r.deleteIpSet,
)
var err error
r.filterTable, err = r.loadFilterTable()
if err != nil {
if errors.Is(err, errFilterTableNotFound) {
log.Warnf("table 'filter' not found for forward rules")
} else {
return nil, err
}
}
err = r.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
err = r.createContainers()
if err != nil {
log.Errorf("failed to create containers for route: %s", err)
}
return r, err
}
// Reset cleans existing nftables default forward rules from the system
func (r *router) Reset() error {
// clear without deleting the ipsets, the nf table will be deleted by the caller
r.ipsetCounter.Clear()
return r.cleanUpDefaultForwardRules()
}
func (r *router) cleanUpDefaultForwardRules() error {
if r.filterTable == nil {
return nil
}
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return fmt.Errorf("list chains: %v", err)
}
for _, chain := range chains {
if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward {
continue
}
rules, err := r.conn.GetRules(r.filterTable, chain)
if err != nil {
return fmt.Errorf("get rules: %v", err)
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule: %v", err)
}
}
}
}
return r.conn.Flush()
}
func (r *router) loadFilterTable() (*nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
}
for _, table := range tables {
if table.Name == "filter" {
return table, nil
}
}
return nil, errFilterTableNotFound
}
func (r *router) createContainers() error {
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingFw,
Table: r.workTable,
})
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingNat,
Table: r.workTable,
Hooknum: nftables.ChainHookPostrouting,
Priority: nftables.ChainPriorityNATSource - 1,
Type: nftables.ChainTypeNAT,
})
r.acceptForwardRules()
err := r.refreshRulesMap()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
err = r.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: unable to initialize table: %v", err)
}
return nil
}
// AddRouteFiltering appends a nftables rule to the routing chain
func (r *router) AddRouteFiltering(
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if _, ok := r.rules[string(ruleKey)]; ok {
return ruleKey, nil
}
chain := r.chains[chainNameRoutingFw]
var exprs []expr.Any
switch {
case len(sources) == 1 && sources[0].Bits() == 0:
// If it's 0.0.0.0/0, we don't need to add any source matching
case len(sources) == 1:
// If there's only one source, we can use it directly
exprs = append(exprs, generateCIDRMatcherExpressions(true, sources[0])...)
default:
// If there are multiple sources, create or get an ipset
var err error
exprs, err = r.getIpSetExprs(sources, exprs)
if err != nil {
return nil, fmt.Errorf("get ipset expressions: %w", err)
}
}
// Handle destination
exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...)
// Handle protocol
if proto != firewall.ProtocolALL {
protoNum, err := protoToInt(proto)
if err != nil {
return nil, fmt.Errorf("convert protocol to number: %w", err)
}
exprs = append(exprs, &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1})
exprs = append(exprs, &expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
})
exprs = append(exprs, applyPort(sPort, true)...)
exprs = append(exprs, applyPort(dPort, false)...)
}
exprs = append(exprs, &expr.Counter{})
var verdict expr.VerdictKind
if action == firewall.ActionAccept {
verdict = expr.VerdictAccept
} else {
verdict = expr.VerdictDrop
}
exprs = append(exprs, &expr.Verdict{Kind: verdict})
rule := &nftables.Rule{
Table: r.workTable,
Chain: chain,
Exprs: exprs,
UserData: []byte(ruleKey),
}
r.rules[string(ruleKey)] = r.conn.AddRule(rule)
return ruleKey, r.conn.Flush()
}
func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) {
setName := firewall.GenerateSetName(sources)
ref, err := r.ipsetCounter.Increment(setName, sources)
if err != nil {
return nil, fmt.Errorf("create or get ipset for sources: %w", err)
}
exprs = append(exprs,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Lookup{
SourceRegister: 1,
SetName: ref.Out.Name,
SetID: ref.Out.ID,
},
)
return exprs, nil
}
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
ruleKey := rule.GetRuleID()
nftRule, exists := r.rules[ruleKey]
if !exists {
log.Debugf("route rule %s not found", ruleKey)
return nil
}
setName := r.findSetNameInRule(nftRule)
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
return fmt.Errorf("delete: %w", err)
}
if setName != "" {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
return fmt.Errorf("decrement ipset reference: %w", err)
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) {
// overlapping prefixes will result in an error, so we need to merge them
sources = firewall.MergeIPRanges(sources)
set := &nftables.Set{
Name: setName,
Table: r.workTable,
// required for prefixes
Interval: true,
KeyType: nftables.TypeIPAddr,
}
var elements []nftables.SetElement
for _, prefix := range sources {
// TODO: Implement IPv6 support
if prefix.Addr().Is6() {
log.Printf("Skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
continue
}
// nftables needs half-open intervals [firstIP, lastIP) for prefixes
// e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc
firstIP := prefix.Addr()
lastIP := calculateLastIP(prefix).Next()
elements = append(elements,
// the nft tool also adds a line like this, see https://github.com/google/nftables/issues/247
// nftables.SetElement{Key: []byte{0, 0, 0, 0}, IntervalEnd: true},
nftables.SetElement{Key: firstIP.AsSlice()},
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
)
}
if err := r.conn.AddSet(set, elements); err != nil {
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
}
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err)
}
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
return set, nil
}
// calculateLastIP determines the last IP in a given prefix.
func calculateLastIP(prefix netip.Prefix) netip.Addr {
hostMask := ^uint32(0) >> prefix.Masked().Bits()
lastIP := uint32FromNetipAddr(prefix.Addr()) | hostMask
return netip.AddrFrom4(uint32ToBytes(lastIP))
}
// Utility function to convert netip.Addr to uint32.
func uint32FromNetipAddr(addr netip.Addr) uint32 {
b := addr.As4()
return binary.BigEndian.Uint32(b[:])
}
// Utility function to convert uint32 to a netip-compatible byte slice.
func uint32ToBytes(ip uint32) [4]byte {
var b [4]byte
binary.BigEndian.PutUint32(b[:], ip)
return b
}
func (r *router) deleteIpSet(setName string, set *nftables.Set) error {
r.conn.DelSet(set)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
log.Debugf("Deleted unused ipset %s", setName)
return nil
}
func (r *router) findSetNameInRule(rule *nftables.Rule) string {
for _, e := range rule.Exprs {
if lookup, ok := e.(*expr.Lookup); ok {
return lookup.SetName
}
}
return ""
}
func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule %s: %w", ruleKey, err)
}
delete(r.rules, ruleKey)
log.Debugf("removed route rule %s", ruleKey)
return nil
}
// AddNatRule appends a nftables rule pair to the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil {
return fmt.Errorf("add legacy routing rule: %w", err)
}
}
if pair.Masquerade {
if err := r.addNatRule(pair); err != nil {
return fmt.Errorf("add nat rule: %w", err)
}
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("add inverse nat rule: %w", err)
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("nftables: insert rules for %s: %v", pair.Destination, err)
}
return nil
}
// addNatRule inserts a nftables rule to the conn client flush queue
func (r *router) addNatRule(pair firewall.RouterPair) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
dir := expr.MetaKeyIIFNAME
if pair.Inverse {
dir = expr.MetaKeyOIFNAME
}
intf := ifname(r.wgIface.Name())
exprs := []expr.Any{
&expr.Meta{
Key: dir,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: intf,
},
}
exprs = append(exprs, sourceExp...)
exprs = append(exprs, destExp...)
exprs = append(exprs,
&expr.Counter{}, &expr.Masq{},
)
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
if _, exists := r.rules[ruleKey]; exists {
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove routing rule: %w", err)
}
}
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: exprs,
UserData: []byte(ruleKey),
})
return nil
}
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
exprs := []expr.Any{
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
if _, exists := r.rules[ruleKey]; exists {
if err := r.removeLegacyRouteRule(pair); err != nil {
return fmt.Errorf("remove legacy routing rule: %w", err)
}
}
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Exprs: expression,
UserData: []byte(ruleKey),
})
return nil
}
// removeLegacyRouteRule removes a legacy routing rule for mgmt servers pre route acls
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
log.Debugf("nftables: removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
} else {
log.Debugf("nftables: legacy forwarding rule %s not found", ruleKey)
}
return nil
}
// GetLegacyManagement returns the route manager's legacy management mode
func (r *router) GetLegacyManagement() bool {
return r.legacyManagement
}
// SetLegacyManagement sets the route manager to use legacy management mode
func (r *router) SetLegacyManagement(isLegacy bool) {
r.legacyManagement = isLegacy
}
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
func (r *router) RemoveAllLegacyRouteRules() error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
var merr *multierror.Error
for k, rule := range r.rules {
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
continue
}
if err := r.conn.DelRule(rule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
// acceptForwardRules adds iif/oif rules in the filter table/forward chain to make sure
// that our traffic is not dropped by existing rules there.
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
func (r *router) acceptForwardRules() {
if r.filterTable == nil {
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
return
}
intf := ifname(r.wgIface.Name())
// Rule for incoming interface (iif) with counter
iifRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: intf,
},
&expr.Counter{},
&expr.Verdict{Kind: expr.VerdictAccept},
},
UserData: []byte(userDataAcceptForwardRuleIif),
}
r.conn.InsertRule(iifRule)
// Rule for outgoing interface (oif) with counter
oifRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: intf,
},
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 2,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 2,
Data: []byte{0, 0, 0, 0},
},
&expr.Counter{},
&expr.Verdict{Kind: expr.VerdictAccept},
},
UserData: []byte(userDataAcceptForwardRuleOif),
}
r.conn.InsertRule(oifRule)
}
// RemoveNatRule removes a nftables rule pair from nat chains
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err)
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse nat rule: %w", err)
}
if err := r.removeLegacyRouteRule(pair); err != nil {
return fmt.Errorf("remove legacy routing rule: %w", err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
}
log.Debugf("nftables: removed rules for %s", pair.Destination)
return nil
}
// removeNatRule adds a nftables rule to the removal queue and deletes it from the rules map
func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
err := r.conn.DelRule(rule)
if err != nil {
return fmt.Errorf("remove nat rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
log.Debugf("nftables: removed nat rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
} else {
log.Debugf("nftables: nat rule %s not found", ruleKey)
}
return nil
}
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
// duplicates and to get missing attributes that we don't have when adding new rules
func (r *router) refreshRulesMap() error {
for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("nftables: unable to list rules: %v", err)
}
for _, rule := range rules {
if len(rule.UserData) > 0 {
r.rules[string(rule.UserData)] = rule
}
}
}
return nil
}
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any {
var offset uint32
if source {
offset = 12 // src offset
} else {
offset = 16 // dst offset
}
ones := prefix.Bits()
// 0.0.0.0/0 doesn't need extra expressions
if ones == 0 {
return nil
}
mask := net.CIDRMask(ones, 32)
return []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offset,
Len: 4,
},
// netmask
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: 4,
Mask: mask,
Xor: []byte{0, 0, 0, 0},
},
// net address
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: prefix.Masked().Addr().AsSlice(),
},
}
}
func applyPort(port *firewall.Port, isSource bool) []expr.Any {
if port == nil {
return nil
}
var exprs []expr.Any
offset := uint32(2) // Default offset for destination port
if isSource {
offset = 0 // Offset for source port
}
exprs = append(exprs, &expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: offset,
Len: 2,
})
if port.IsRange && len(port.Values) == 2 {
// Handle port range
exprs = append(exprs,
&expr.Cmp{
Op: expr.CmpOpGte,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[0])),
},
&expr.Cmp{
Op: expr.CmpOpLte,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[1])),
},
)
} else {
// Handle single port or multiple ports
for i, p := range port.Values {
if i > 0 {
// Add a bitwise OR operation between port checks
exprs = append(exprs, &expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: []byte{0x00, 0x00, 0xff, 0xff},
Xor: []byte{0x00, 0x00, 0x00, 0x00},
})
}
exprs = append(exprs, &expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(p)),
})
}
}
return exprs
}

View File

@ -4,11 +4,15 @@ package nftables
import (
"context"
"encoding/binary"
"net/netip"
"os/exec"
"testing"
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
"github.com/google/nftables/expr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
@ -24,56 +28,50 @@ const (
NFTABLES
)
func TestNftablesManager_InsertRoutingRules(t *testing.T) {
func TestNftablesManager_AddNatRule(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this OS")
}
table, err := createWorkTable()
if err != nil {
t.Fatal(err)
}
require.NoError(t, err, "Failed to create work table")
defer deleteWorkTable()
for _, testCase := range test.InsertRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
manager, err := newRouter(context.TODO(), table)
manager, err := newRouter(context.TODO(), table, ifaceMock)
require.NoError(t, err, "failed to create router")
nftablesTestingClient := &nftables.Conn{}
defer manager.ResetForwardRules()
defer func(manager *router) {
require.NoError(t, manager.Reset(), "failed to reset rules")
}(manager)
require.NoError(t, err, "shouldn't return error")
err = manager.AddRoutingRules(testCase.InputPair)
defer func() {
_ = manager.RemoveRoutingRules(testCase.InputPair)
}()
require.NoError(t, err, "forwarding pair should be inserted")
err = manager.AddNatRule(testCase.InputPair)
require.NoError(t, err, "pair should be inserted")
defer func(manager *router, pair firewall.RouterPair) {
require.NoError(t, manager.RemoveNatRule(pair), "failed to remove rule")
}(manager, testCase.InputPair)
if testCase.InputPair.Masquerade {
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
fwdRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
testingExpression = append(testingExpression,
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(ifaceMock.Name()),
},
)
found := 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == fwdRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "forwarding rule elements should match")
found = 1
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
if testCase.InputPair.Masquerade {
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
found := 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
@ -88,27 +86,20 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
require.Equal(t, 1, found, "should find at least 1 rule to test")
}
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
testingExpression = append(sourceExp, destExp...) //nolint:gocritic
inFwdRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
found = 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match")
found = 1
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
if testCase.InputPair.Masquerade {
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
testingExpression = append(testingExpression,
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(ifaceMock.Name()),
},
)
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
found := 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
@ -122,45 +113,37 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
}
})
}
}
func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
func TestNftablesManager_RemoveNatRule(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this OS")
}
table, err := createWorkTable()
if err != nil {
t.Fatal(err)
}
require.NoError(t, err, "Failed to create work table")
defer deleteWorkTable()
for _, testCase := range test.RemoveRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
manager, err := newRouter(context.TODO(), table)
manager, err := newRouter(context.TODO(), table, ifaceMock)
require.NoError(t, err, "failed to create router")
nftablesTestingClient := &nftables.Conn{}
defer manager.ResetForwardRules()
defer func(manager *router) {
require.NoError(t, manager.Reset(), "failed to reset rules")
}(manager)
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRouteingFw],
Exprs: forwardExp,
UserData: []byte(forwardRuleKey),
})
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
@ -169,20 +152,11 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
UserData: []byte(natRuleKey),
})
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRouteingFw],
Exprs: forwardExp,
UserData: []byte(inForwardRuleKey),
})
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInversePair(testCase.InputPair).Source)
destExp = generateCIDRMatcherExpressions(false, firewall.GetInversePair(testCase.InputPair).Destination)
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
@ -194,9 +168,10 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
err = nftablesTestingClient.Flush()
require.NoError(t, err, "shouldn't return error")
manager.ResetForwardRules()
err = manager.Reset()
require.NoError(t, err, "shouldn't return error")
err = manager.RemoveRoutingRules(testCase.InputPair)
err = manager.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error")
for _, chain := range manager.chains {
@ -204,9 +179,7 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 {
require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist")
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist")
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
}
}
@ -215,6 +188,468 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
}
}
func TestRouter_AddRouteFiltering(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTable()
require.NoError(t, err, "Failed to create work table")
defer deleteWorkTable()
r, err := newRouter(context.Background(), workTable, ifaceMock)
require.NoError(t, err, "Failed to create router")
defer func(r *router) {
require.NoError(t, r.Reset(), "Failed to reset rules")
}(r)
tests := []struct {
name string
sources []netip.Prefix
destination netip.Prefix
proto firewall.Protocol
sPort *firewall.Port
dPort *firewall.Port
direction firewall.RuleDirection
action firewall.Action
expectSet bool
}{
{
name: "Basic TCP rule with single source",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolTCP,
sPort: nil,
dPort: &firewall.Port{Values: []int{80}},
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "UDP rule with multiple sources",
sources: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"),
netip.MustParsePrefix("192.168.0.0/16"),
},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolUDP,
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionDrop,
expectSet: true,
},
{
name: "All protocols rule",
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
destination: netip.MustParsePrefix("0.0.0.0/0"),
proto: firewall.ProtocolALL,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "ICMP rule",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolICMP,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "TCP rule with multiple source ports",
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
destination: netip.MustParsePrefix("192.168.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "UDP rule with single IP and port range",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolUDP,
sPort: nil,
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
expectSet: false,
},
{
name: "TCP rule with source and destination ports",
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
destination: netip.MustParsePrefix("172.16.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
dPort: &firewall.Port{Values: []int{22}},
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "Drop all incoming traffic",
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
destination: netip.MustParsePrefix("192.168.0.0/24"),
proto: firewall.ProtocolALL,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
expectSet: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed")
// Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.GetRuleID()]
assert.True(t, ok, "Rule not found in internal map")
t.Log("Internal rule expressions:")
for i, expr := range rule.Exprs {
t.Logf(" [%d] %T: %+v", i, expr, expr)
}
// Verify internal rule content
verifyRule(t, rule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet)
// Check if the rule exists in nftables and verify its content
rules, err := r.conn.GetRules(r.workTable, r.chains[chainNameRoutingFw])
require.NoError(t, err, "Failed to get rules from nftables")
var nftRule *nftables.Rule
for _, rule := range rules {
if string(rule.UserData) == ruleKey.GetRuleID() {
nftRule = rule
break
}
}
require.NotNil(t, nftRule, "Rule not found in nftables")
t.Log("Actual nftables rule expressions:")
for i, expr := range nftRule.Exprs {
t.Logf(" [%d] %T: %+v", i, expr, expr)
}
// Verify actual nftables rule content
verifyRule(t, nftRule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet)
// Clean up
err = r.DeleteRouteRule(ruleKey)
require.NoError(t, err, "Failed to delete rule")
})
}
}
func TestNftablesCreateIpSet(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTable()
require.NoError(t, err, "Failed to create work table")
defer deleteWorkTable()
r, err := newRouter(context.Background(), workTable, ifaceMock)
require.NoError(t, err, "Failed to create router")
defer func() {
require.NoError(t, r.Reset(), "Failed to reset router")
}()
tests := []struct {
name string
sources []netip.Prefix
expected []netip.Prefix
}{
{
name: "Single IP",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
},
{
name: "Multiple IPs",
sources: []netip.Prefix{
netip.MustParsePrefix("192.168.1.1/32"),
netip.MustParsePrefix("10.0.0.1/32"),
netip.MustParsePrefix("172.16.0.1/32"),
},
},
{
name: "Single Subnet",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")},
},
{
name: "Multiple Subnets with Various Prefix Lengths",
sources: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("172.16.0.0/16"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("203.0.113.0/26"),
},
},
{
name: "Mix of Single IPs and Subnets in Different Positions",
sources: []netip.Prefix{
netip.MustParsePrefix("192.168.1.1/32"),
netip.MustParsePrefix("10.0.0.0/16"),
netip.MustParsePrefix("172.16.0.1/32"),
netip.MustParsePrefix("203.0.113.0/24"),
},
},
{
name: "Overlapping IPs/Subnets",
sources: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("10.0.0.0/16"),
netip.MustParsePrefix("10.0.0.1/32"),
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.1.1/32"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("192.168.0.0/16"),
},
},
}
// Add this helper function inside TestNftablesCreateIpSet
printNftSets := func() {
cmd := exec.Command("nft", "list", "sets")
output, err := cmd.CombinedOutput()
if err != nil {
t.Logf("Failed to run 'nft list sets': %v", err)
} else {
t.Logf("Current nft sets:\n%s", output)
}
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
setName := firewall.GenerateSetName(tt.sources)
set, err := r.createIpSet(setName, tt.sources)
if err != nil {
t.Logf("Failed to create IP set: %v", err)
printNftSets()
require.NoError(t, err, "Failed to create IP set")
}
require.NotNil(t, set, "Created set is nil")
// Verify set properties
assert.Equal(t, setName, set.Name, "Set name mismatch")
assert.Equal(t, r.workTable, set.Table, "Set table mismatch")
assert.True(t, set.Interval, "Set interval property should be true")
assert.Equal(t, nftables.TypeIPAddr, set.KeyType, "Set key type mismatch")
// Fetch the created set from nftables
fetchedSet, err := r.conn.GetSetByName(r.workTable, setName)
require.NoError(t, err, "Failed to fetch created set")
require.NotNil(t, fetchedSet, "Fetched set is nil")
// Verify set elements
elements, err := r.conn.GetSetElements(fetchedSet)
require.NoError(t, err, "Failed to get set elements")
// Count the number of unique prefixes (excluding interval end markers)
uniquePrefixes := make(map[string]bool)
for _, elem := range elements {
if !elem.IntervalEnd {
ip := netip.AddrFrom4(*(*[4]byte)(elem.Key))
uniquePrefixes[ip.String()] = true
}
}
// Check against expected merged prefixes
expectedCount := len(tt.expected)
if expectedCount == 0 {
expectedCount = len(tt.sources)
}
assert.Equal(t, expectedCount, len(uniquePrefixes), "Number of unique prefixes in set doesn't match expected")
// Verify each expected prefix is in the set
for _, expected := range tt.expected {
found := false
for _, elem := range elements {
if !elem.IntervalEnd {
ip := netip.AddrFrom4(*(*[4]byte)(elem.Key))
if expected.Contains(ip) {
found = true
break
}
}
}
assert.True(t, found, "Expected prefix %s not found in set", expected)
}
r.conn.DelSet(set)
if err := r.conn.Flush(); err != nil {
t.Logf("Failed to delete set: %v", err)
printNftSets()
}
require.NoError(t, err, "Failed to delete set")
})
}
}
func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) {
t.Helper()
assert.NotNil(t, rule, "Rule should not be nil")
// Verify sources and destination
if expectSet {
assert.True(t, containsSetLookup(rule.Exprs), "Rule should contain set lookup for multiple sources")
} else if len(sources) == 1 && sources[0].Bits() != 0 {
if direction == firewall.RuleDirectionIN {
assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], true), "Rule should contain source CIDR matcher for %s", sources[0])
} else {
assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], false), "Rule should contain destination CIDR matcher for %s", sources[0])
}
}
if direction == firewall.RuleDirectionIN {
assert.True(t, containsCIDRMatcher(rule.Exprs, destination, false), "Rule should contain destination CIDR matcher for %s", destination)
} else {
assert.True(t, containsCIDRMatcher(rule.Exprs, destination, true), "Rule should contain source CIDR matcher for %s", destination)
}
// Verify protocol
if proto != firewall.ProtocolALL {
assert.True(t, containsProtocol(rule.Exprs, proto), "Rule should contain protocol matcher for %s", proto)
}
// Verify ports
if sPort != nil {
assert.True(t, containsPort(rule.Exprs, sPort, true), "Rule should contain source port matcher for %v", sPort)
}
if dPort != nil {
assert.True(t, containsPort(rule.Exprs, dPort, false), "Rule should contain destination port matcher for %v", dPort)
}
// Verify action
assert.True(t, containsAction(rule.Exprs, action), "Rule should contain correct action: %s", action)
}
func containsSetLookup(exprs []expr.Any) bool {
for _, e := range exprs {
if _, ok := e.(*expr.Lookup); ok {
return true
}
}
return false
}
func containsCIDRMatcher(exprs []expr.Any, prefix netip.Prefix, isSource bool) bool {
var offset uint32
if isSource {
offset = 12 // src offset
} else {
offset = 16 // dst offset
}
var payloadFound, bitwiseFound, cmpFound bool
for _, e := range exprs {
switch ex := e.(type) {
case *expr.Payload:
if ex.Base == expr.PayloadBaseNetworkHeader && ex.Offset == offset && ex.Len == 4 {
payloadFound = true
}
case *expr.Bitwise:
if ex.Len == 4 && len(ex.Mask) == 4 && len(ex.Xor) == 4 {
bitwiseFound = true
}
case *expr.Cmp:
if ex.Op == expr.CmpOpEq && len(ex.Data) == 4 {
cmpFound = true
}
}
}
return (payloadFound && bitwiseFound && cmpFound) || prefix.Bits() == 0
}
func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
var offset uint32 = 2 // Default offset for destination port
if isSource {
offset = 0 // Offset for source port
}
var payloadFound, portMatchFound bool
for _, e := range exprs {
switch ex := e.(type) {
case *expr.Payload:
if ex.Base == expr.PayloadBaseTransportHeader && ex.Offset == offset && ex.Len == 2 {
payloadFound = true
}
case *expr.Cmp:
if port.IsRange {
if ex.Op == expr.CmpOpGte || ex.Op == expr.CmpOpLte {
portMatchFound = true
}
} else {
if ex.Op == expr.CmpOpEq && len(ex.Data) == 2 {
portValue := binary.BigEndian.Uint16(ex.Data)
for _, p := range port.Values {
if uint16(p) == portValue {
portMatchFound = true
break
}
}
}
}
}
if payloadFound && portMatchFound {
return true
}
}
return false
}
func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool {
var metaFound, cmpFound bool
expectedProto, _ := protoToInt(proto)
for _, e := range exprs {
switch ex := e.(type) {
case *expr.Meta:
if ex.Key == expr.MetaKeyL4PROTO {
metaFound = true
}
case *expr.Cmp:
if ex.Op == expr.CmpOpEq && len(ex.Data) == 1 && ex.Data[0] == expectedProto {
cmpFound = true
}
}
}
return metaFound && cmpFound
}
func containsAction(exprs []expr.Any, action firewall.Action) bool {
for _, e := range exprs {
if verdict, ok := e.(*expr.Verdict); ok {
switch action {
case firewall.ActionAccept:
return verdict.Kind == expr.VerdictAccept
case firewall.ActionDrop:
return verdict.Kind == expr.VerdictDrop
}
}
}
return false
}
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
func check() int {
nf := nftables.Conn{}
@ -250,12 +685,12 @@ func createWorkTable() (*nftables.Table, error) {
}
for _, t := range tables {
if t.Name == tableName {
if t.Name == tableNameNetbird {
sConn.DelTable(t)
}
}
table := sConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
table := sConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4})
err = sConn.Flush()
return table, err
@ -273,7 +708,7 @@ func deleteWorkTable() {
}
for _, t := range tables {
if t.Name == tableName {
if t.Name == tableNameNetbird {
sConn.DelTable(t)
}
}

View File

@ -1,8 +1,10 @@
//go:build !android
package test
import firewall "github.com/netbirdio/netbird/client/firewall/manager"
import (
"net/netip"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
var (
InsertRuleTestCases = []struct {
@ -13,8 +15,8 @@ var (
Name: "Insert Forwarding IPV4 Rule",
InputPair: firewall.RouterPair{
ID: "zxa",
Source: "100.100.100.1/32",
Destination: "100.100.200.0/24",
Source: netip.MustParsePrefix("100.100.100.1/32"),
Destination: netip.MustParsePrefix("100.100.200.0/24"),
Masquerade: false,
},
},
@ -22,8 +24,8 @@ var (
Name: "Insert Forwarding And Nat IPV4 Rules",
InputPair: firewall.RouterPair{
ID: "zxa",
Source: "100.100.100.1/32",
Destination: "100.100.200.0/24",
Source: netip.MustParsePrefix("100.100.100.1/32"),
Destination: netip.MustParsePrefix("100.100.200.0/24"),
Masquerade: true,
},
},
@ -38,8 +40,8 @@ var (
Name: "Remove Forwarding And Nat IPV4 Rules",
InputPair: firewall.RouterPair{
ID: "zxa",
Source: "100.100.100.1/32",
Destination: "100.100.200.0/24",
Source: netip.MustParsePrefix("100.100.100.1/32"),
Destination: netip.MustParsePrefix("100.100.200.0/24"),
Masquerade: true,
},
},

View File

@ -3,6 +3,7 @@ package uspfilter
import (
"fmt"
"net"
"net/netip"
"sync"
"github.com/google/gopacket"
@ -11,7 +12,8 @@ import (
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
const layerTypeAll = 0
@ -22,7 +24,7 @@ var (
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
SetFilter(iface.PacketFilter) error
SetFilter(device.PacketFilter) error
Address() iface.WGAddress
}
@ -103,26 +105,26 @@ func (m *Manager) IsServerRouteSupported() bool {
}
}
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
if m.nativeFirewall == nil {
return errRouteNotSupported
}
return m.nativeFirewall.InsertRoutingRules(pair)
return m.nativeFirewall.AddNatRule(pair)
}
// RemoveRoutingRules removes a routing firewall rule
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
// RemoveNatRule removes a routing firewall rule
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
if m.nativeFirewall == nil {
return errRouteNotSupported
}
return m.nativeFirewall.RemoveRoutingRules(pair)
return m.nativeFirewall.RemoveNatRule(pair)
}
// AddFiltering rule to the firewall
// AddPeerFiltering rule to the firewall
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
func (m *Manager) AddFiltering(
func (m *Manager) AddPeerFiltering(
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
@ -188,8 +190,22 @@ func (m *Manager) AddFiltering(
return []firewall.Rule{&r}, nil
}
// DeleteRule from the firewall by rule definition
func (m *Manager) DeleteRule(rule firewall.Rule) error {
func (m *Manager) AddRouteFiltering(sources [] netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action ) (firewall.Rule, error) {
if m.nativeFirewall == nil {
return nil, errRouteNotSupported
}
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
}
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
if m.nativeFirewall == nil {
return errRouteNotSupported
}
return m.nativeFirewall.DeleteRouteRule(rule)
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@ -215,6 +231,11 @@ func (m *Manager) DeleteRule(rule firewall.Rule) error {
return nil
}
// SetLegacyManagement doesn't need to be implemented for this manager
func (m *Manager) SetLegacyManagement(_ bool) error {
return nil
}
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
@ -395,7 +416,7 @@ func (m *Manager) RemovePacketHook(hookID string) error {
for _, r := range arr {
if r.id == hookID {
rule := r
return m.DeleteRule(&rule)
return m.DeletePeerRule(&rule)
}
}
}
@ -403,7 +424,7 @@ func (m *Manager) RemovePacketHook(hookID string) error {
for _, r := range arr {
if r.id == hookID {
rule := r
return m.DeleteRule(&rule)
return m.DeletePeerRule(&rule)
}
}
}

View File

@ -11,15 +11,16 @@ import (
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
type IFaceMock struct {
SetFilterFunc func(iface.PacketFilter) error
SetFilterFunc func(device.PacketFilter) error
AddressFunc func() iface.WGAddress
}
func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
if i.SetFilterFunc == nil {
return fmt.Errorf("not implemented")
}
@ -35,7 +36,7 @@ func (i *IFaceMock) Address() iface.WGAddress {
func TestManagerCreate(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(iface.PacketFilter) error { return nil },
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock)
@ -49,10 +50,10 @@ func TestManagerCreate(t *testing.T) {
}
}
func TestManagerAddFiltering(t *testing.T) {
func TestManagerAddPeerFiltering(t *testing.T) {
isSetFilterCalled := false
ifaceMock := &IFaceMock{
SetFilterFunc: func(iface.PacketFilter) error {
SetFilterFunc: func(device.PacketFilter) error {
isSetFilterCalled = true
return nil
},
@ -71,7 +72,7 @@ func TestManagerAddFiltering(t *testing.T) {
action := fw.ActionDrop
comment := "Test rule"
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@ -90,7 +91,7 @@ func TestManagerAddFiltering(t *testing.T) {
func TestManagerDeleteRule(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(iface.PacketFilter) error { return nil },
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock)
@ -106,7 +107,7 @@ func TestManagerDeleteRule(t *testing.T) {
action := fw.ActionDrop
comment := "Test rule"
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@ -119,14 +120,14 @@ func TestManagerDeleteRule(t *testing.T) {
action = fw.ActionDrop
comment = "Test rule 2"
rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
}
for _, r := range rule {
err = m.DeleteRule(r)
err = m.DeletePeerRule(r)
if err != nil {
t.Errorf("failed to delete rule: %v", err)
return
@ -140,7 +141,7 @@ func TestManagerDeleteRule(t *testing.T) {
}
for _, r := range rule2 {
err = m.DeleteRule(r)
err = m.DeletePeerRule(r)
if err != nil {
t.Errorf("failed to delete rule: %v", err)
return
@ -236,7 +237,7 @@ func TestAddUDPPacketHook(t *testing.T) {
func TestManagerReset(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(iface.PacketFilter) error { return nil },
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock)
@ -252,7 +253,7 @@ func TestManagerReset(t *testing.T) {
action := fw.ActionDrop
comment := "Test rule"
_, err = m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
_, err = m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@ -271,7 +272,7 @@ func TestManagerReset(t *testing.T) {
func TestNotMatchByIP(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(iface.PacketFilter) error { return nil },
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock)
@ -290,7 +291,7 @@ func TestNotMatchByIP(t *testing.T) {
action := fw.ActionAccept
comment := "Test rule"
_, err = m.AddFiltering(ip, proto, nil, nil, direction, action, "", comment)
_, err = m.AddPeerFiltering(ip, proto, nil, nil, direction, action, "", comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@ -339,7 +340,7 @@ func TestNotMatchByIP(t *testing.T) {
func TestRemovePacketHook(t *testing.T) {
// creating mock iface
iface := &IFaceMock{
SetFilterFunc: func(iface.PacketFilter) error { return nil },
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
// creating manager instance
@ -388,7 +389,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface
ifaceMock := &IFaceMock{
SetFilterFunc: func(iface.PacketFilter) error { return nil },
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock)
require.NoError(t, err)
@ -406,9 +407,9 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
}
require.NoError(t, err, "failed to add rule")

View File

@ -0,0 +1,5 @@
package configurer
import "errors"
var ErrPeerNotFound = errors.New("peer not found")

View File

@ -1,6 +1,6 @@
//go:build (linux && !android) || freebsd
package iface
package configurer
import (
"fmt"
@ -12,18 +12,17 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type wgKernelConfigurer struct {
type KernelConfigurer struct {
deviceName string
}
func newWGConfigurer(deviceName string) wgConfigurer {
wgc := &wgKernelConfigurer{
func NewKernelConfigurer(deviceName string) *KernelConfigurer {
return &KernelConfigurer{
deviceName: deviceName,
}
return wgc
}
func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) error {
func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error {
log.Debugf("adding Wireguard private key")
key, err := wgtypes.ParseKey(privateKey)
if err != nil {
@ -44,7 +43,7 @@ func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) err
return nil
}
func (c *wgKernelConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
// parse allowed ips
_, ipNet, err := net.ParseCIDR(allowedIps)
if err != nil {
@ -57,7 +56,8 @@ func (c *wgKernelConfigurer) updatePeer(peerKey string, allowedIps string, keepA
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: true,
ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: []net.IPNet{*ipNet},
PersistentKeepaliveInterval: &keepAlive,
Endpoint: endpoint,
@ -74,7 +74,7 @@ func (c *wgKernelConfigurer) updatePeer(peerKey string, allowedIps string, keepA
return nil
}
func (c *wgKernelConfigurer) removePeer(peerKey string) error {
func (c *KernelConfigurer) RemovePeer(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
@ -95,7 +95,7 @@ func (c *wgKernelConfigurer) removePeer(peerKey string) error {
return nil
}
func (c *wgKernelConfigurer) addAllowedIP(peerKey string, allowedIP string) error {
func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return err
@ -122,7 +122,7 @@ func (c *wgKernelConfigurer) addAllowedIP(peerKey string, allowedIP string) erro
return nil
}
func (c *wgKernelConfigurer) removeAllowedIP(peerKey string, allowedIP string) error {
func (c *KernelConfigurer) RemoveAllowedIP(peerKey string, allowedIP string) error {
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return fmt.Errorf("parse allowed IP: %w", err)
@ -164,7 +164,7 @@ func (c *wgKernelConfigurer) removeAllowedIP(peerKey string, allowedIP string) e
return nil
}
func (c *wgKernelConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) {
func (c *KernelConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) {
wg, err := wgctrl.New()
if err != nil {
return wgtypes.Peer{}, fmt.Errorf("wgctl: %w", err)
@ -188,7 +188,7 @@ func (c *wgKernelConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer
return wgtypes.Peer{}, ErrPeerNotFound
}
func (c *wgKernelConfigurer) configure(config wgtypes.Config) error {
func (c *KernelConfigurer) configure(config wgtypes.Config) error {
wg, err := wgctrl.New()
if err != nil {
return err
@ -204,10 +204,10 @@ func (c *wgKernelConfigurer) configure(config wgtypes.Config) error {
return wg.ConfigureDevice(c.deviceName, config)
}
func (c *wgKernelConfigurer) close() {
func (c *KernelConfigurer) Close() {
}
func (c *wgKernelConfigurer) getStats(peerKey string) (WGStats, error) {
func (c *KernelConfigurer) GetStats(peerKey string) (WGStats, error) {
peer, err := c.getPeer(c.deviceName, peerKey)
if err != nil {
return WGStats{}, fmt.Errorf("get wireguard stats: %w", err)

View File

@ -1,6 +1,6 @@
//go:build linux || windows || freebsd
package iface
package configurer
// WgInterfaceDefault is a default interface name of Wiretrustee
const WgInterfaceDefault = "wt0"

View File

@ -1,6 +1,6 @@
//go:build darwin
package iface
package configurer
// WgInterfaceDefault is a default interface name of Wiretrustee
const WgInterfaceDefault = "utun100"

View File

@ -1,6 +1,6 @@
//go:build !windows
package iface
package configurer
import (
"net"

View File

@ -1,4 +1,4 @@
package iface
package configurer
import (
"net"

View File

@ -1,4 +1,4 @@
package iface
package configurer
import (
"encoding/hex"
@ -19,15 +19,15 @@ import (
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
type wgUSPConfigurer struct {
type WGUSPConfigurer struct {
device *device.Device
deviceName string
uapiListener net.Listener
}
func newWGUSPConfigurer(device *device.Device, deviceName string) wgConfigurer {
wgCfg := &wgUSPConfigurer{
func NewUSPConfigurer(device *device.Device, deviceName string) *WGUSPConfigurer {
wgCfg := &WGUSPConfigurer{
device: device,
deviceName: deviceName,
}
@ -35,7 +35,7 @@ func newWGUSPConfigurer(device *device.Device, deviceName string) wgConfigurer {
return wgCfg
}
func (c *wgUSPConfigurer) configureInterface(privateKey string, port int) error {
func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error {
log.Debugf("adding Wireguard private key")
key, err := wgtypes.ParseKey(privateKey)
if err != nil {
@ -52,7 +52,7 @@ func (c *wgUSPConfigurer) configureInterface(privateKey string, port int) error
return c.device.IpcSet(toWgUserspaceString(config))
}
func (c *wgUSPConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
// parse allowed ips
_, ipNet, err := net.ParseCIDR(allowedIps)
if err != nil {
@ -65,7 +65,8 @@ func (c *wgUSPConfigurer) updatePeer(peerKey string, allowedIps string, keepAliv
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: true,
ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: []net.IPNet{*ipNet},
PersistentKeepaliveInterval: &keepAlive,
PresharedKey: preSharedKey,
@ -79,7 +80,7 @@ func (c *wgUSPConfigurer) updatePeer(peerKey string, allowedIps string, keepAliv
return c.device.IpcSet(toWgUserspaceString(config))
}
func (c *wgUSPConfigurer) removePeer(peerKey string) error {
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
@ -96,7 +97,7 @@ func (c *wgUSPConfigurer) removePeer(peerKey string) error {
return c.device.IpcSet(toWgUserspaceString(config))
}
func (c *wgUSPConfigurer) addAllowedIP(peerKey string, allowedIP string) error {
func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return err
@ -120,7 +121,7 @@ func (c *wgUSPConfigurer) addAllowedIP(peerKey string, allowedIP string) error {
return c.device.IpcSet(toWgUserspaceString(config))
}
func (c *wgUSPConfigurer) removeAllowedIP(peerKey string, ip string) error {
func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
ipc, err := c.device.IpcGet()
if err != nil {
return err
@ -184,7 +185,7 @@ func (c *wgUSPConfigurer) removeAllowedIP(peerKey string, ip string) error {
}
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
func (t *wgUSPConfigurer) startUAPI() {
func (t *WGUSPConfigurer) startUAPI() {
var err error
t.uapiListener, err = openUAPI(t.deviceName)
if err != nil {
@ -206,7 +207,7 @@ func (t *wgUSPConfigurer) startUAPI() {
}(t.uapiListener)
}
func (t *wgUSPConfigurer) close() {
func (t *WGUSPConfigurer) Close() {
if t.uapiListener != nil {
err := t.uapiListener.Close()
if err != nil {
@ -222,7 +223,7 @@ func (t *wgUSPConfigurer) close() {
}
}
func (t *wgUSPConfigurer) getStats(peerKey string) (WGStats, error) {
func (t *WGUSPConfigurer) GetStats(peerKey string) (WGStats, error) {
ipc, err := t.device.IpcGet()
if err != nil {
return WGStats{}, fmt.Errorf("ipc get: %w", err)

View File

@ -1,4 +1,4 @@
package iface
package configurer
import (
"encoding/hex"

View File

@ -0,0 +1,9 @@
package configurer
import "time"
type WGStats struct {
LastHandshake time.Time
TxBytes int64
RxBytes int64
}

18
client/iface/device.go Normal file
View File

@ -0,0 +1,18 @@
//go:build !android
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
)
type WGTunDevice interface {
Create() (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address WGAddress) error
WgAddress() WGAddress
DeviceName() string
Close() error
FilteredDevice() *device.FilteredDevice
}

View File

@ -1,4 +1,4 @@
package iface
package device
// TunAdapter is an interface for create tun device from external service
type TunAdapter interface {

View File

@ -1,18 +1,18 @@
package iface
package device
import (
"fmt"
"net"
)
// WGAddress Wireguard parsed address
// WGAddress WireGuard parsed address
type WGAddress struct {
IP net.IP
Network *net.IPNet
}
// parseWGAddress parse a string ("1.2.3.4/24") address to WG Address
func parseWGAddress(address string) (WGAddress, error) {
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
func ParseWGAddress(address string) (WGAddress, error) {
ip, network, err := net.ParseCIDR(address)
if err != nil {
return WGAddress{}, err

View File

@ -1,4 +1,4 @@
package iface
package device
type MobileIFaceArguments struct {
TunAdapter TunAdapter // only for Android

View File

@ -1,7 +1,6 @@
//go:build android
// +build android
package iface
package device
import (
"strings"
@ -12,11 +11,12 @@ import (
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"github.com/netbirdio/netbird/iface/bind"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
)
// ignore the wgTunDevice interface on Android because the creation of the tun device is different on this platform
type wgTunDevice struct {
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
type WGTunDevice struct {
address WGAddress
port int
key string
@ -26,13 +26,13 @@ type wgTunDevice struct {
name string
device *device.Device
wrapper *DeviceWrapper
filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault
configurer wgConfigurer
configurer WGConfigurer
}
func newTunDevice(address WGAddress, port int, key string, mtu int, transportNet transport.Net, tunAdapter TunAdapter, filterFn bind.FilterFn) wgTunDevice {
return wgTunDevice{
func NewTunDevice(address WGAddress, port int, key string, mtu int, transportNet transport.Net, tunAdapter TunAdapter, filterFn bind.FilterFn) *WGTunDevice {
return &WGTunDevice{
address: address,
port: port,
key: key,
@ -42,7 +42,7 @@ func newTunDevice(address WGAddress, port int, key string, mtu int, transportNet
}
}
func (t *wgTunDevice) Create(routes []string, dns string, searchDomains []string) (wgConfigurer, error) {
func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) {
log.Info("create tun interface")
routesString := routesToString(routes)
@ -61,24 +61,24 @@ func (t *wgTunDevice) Create(routes []string, dns string, searchDomains []string
return nil, err
}
t.name = name
t.wrapper = newDeviceWrapper(tunDevice)
t.filteredDevice = newDeviceFilter(tunDevice)
log.Debugf("attaching to interface %v", name)
t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
// without this property mobile devices can discover remote endpoints if the configured one was wrong.
// this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
t.configurer = newWGUSPConfigurer(t.device, t.name)
err = t.configurer.configureInterface(t.key, t.port)
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()
t.configurer.close()
t.configurer.Close()
return nil, err
}
return t.configurer, nil
}
func (t *wgTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err
@ -93,14 +93,14 @@ func (t *wgTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *wgTunDevice) UpdateAddr(addr WGAddress) error {
func (t *WGTunDevice) UpdateAddr(addr WGAddress) error {
// todo implement
return nil
}
func (t *wgTunDevice) Close() error {
func (t *WGTunDevice) Close() error {
if t.configurer != nil {
t.configurer.close()
t.configurer.Close()
}
if t.device != nil {
@ -115,20 +115,20 @@ func (t *wgTunDevice) Close() error {
return nil
}
func (t *wgTunDevice) Device() *device.Device {
func (t *WGTunDevice) Device() *device.Device {
return t.device
}
func (t *wgTunDevice) DeviceName() string {
func (t *WGTunDevice) DeviceName() string {
return t.name
}
func (t *wgTunDevice) WgAddress() WGAddress {
func (t *WGTunDevice) WgAddress() WGAddress {
return t.address
}
func (t *wgTunDevice) Wrapper() *DeviceWrapper {
return t.wrapper
func (t *WGTunDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice
}
func routesToString(routes []string) string {

View File

@ -1,8 +1,9 @@
//go:build !ios
package iface
package device
import (
"fmt"
"os/exec"
"github.com/pion/transport/v3"
@ -10,10 +11,11 @@ import (
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"github.com/netbirdio/netbird/iface/bind"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
)
type tunDevice struct {
type TunDevice struct {
name string
address WGAddress
port int
@ -22,13 +24,13 @@ type tunDevice struct {
iceBind *bind.ICEBind
device *device.Device
wrapper *DeviceWrapper
filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault
configurer wgConfigurer
configurer WGConfigurer
}
func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice {
return &tunDevice{
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *TunDevice {
return &TunDevice{
name: name,
address: address,
port: port,
@ -38,16 +40,16 @@ func newTunDevice(name string, address WGAddress, port int, key string, mtu int,
}
}
func (t *tunDevice) Create() (wgConfigurer, error) {
func (t *TunDevice) Create() (WGConfigurer, error) {
tunDevice, err := tun.CreateTUN(t.name, t.mtu)
if err != nil {
return nil, err
return nil, fmt.Errorf("error creating tun device: %s", err)
}
t.wrapper = newDeviceWrapper(tunDevice)
t.filteredDevice = newDeviceFilter(tunDevice)
// We need to create a wireguard-go device and listen to configuration requests
t.device = device.NewDevice(
t.wrapper,
t.filteredDevice,
t.iceBind,
device.NewLogger(wgLogLevel(), "[netbird] "),
)
@ -55,20 +57,20 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
err = t.assignAddr()
if err != nil {
t.device.Close()
return nil, err
return nil, fmt.Errorf("error assigning ip: %s", err)
}
t.configurer = newWGUSPConfigurer(t.device, t.name)
err = t.configurer.configureInterface(t.key, t.port)
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()
t.configurer.close()
return nil, err
t.configurer.Close()
return nil, fmt.Errorf("error configuring interface: %s", err)
}
return t.configurer, nil
}
func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err
@ -83,14 +85,14 @@ func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *tunDevice) UpdateAddr(address WGAddress) error {
func (t *TunDevice) UpdateAddr(address WGAddress) error {
t.address = address
return t.assignAddr()
}
func (t *tunDevice) Close() error {
func (t *TunDevice) Close() error {
if t.configurer != nil {
t.configurer.close()
t.configurer.Close()
}
if t.device != nil {
@ -104,20 +106,20 @@ func (t *tunDevice) Close() error {
return nil
}
func (t *tunDevice) WgAddress() WGAddress {
func (t *TunDevice) WgAddress() WGAddress {
return t.address
}
func (t *tunDevice) DeviceName() string {
func (t *TunDevice) DeviceName() string {
return t.name
}
func (t *tunDevice) Wrapper() *DeviceWrapper {
return t.wrapper
func (t *TunDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice
}
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
func (t *tunDevice) assignAddr() error {
func (t *TunDevice) assignAddr() error {
cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String())
if out, err := cmd.CombinedOutput(); err != nil {
log.Errorf("adding address command '%v' failed with output: %s", cmd.String(), out)

View File

@ -1,4 +1,4 @@
package iface
package device
import (
"net"
@ -28,22 +28,23 @@ type PacketFilter interface {
SetNetwork(*net.IPNet)
}
// DeviceWrapper to override Read or Write of packets
type DeviceWrapper struct {
// FilteredDevice to override Read or Write of packets
type FilteredDevice struct {
tun.Device
filter PacketFilter
mutex sync.RWMutex
}
// newDeviceWrapper constructor function
func newDeviceWrapper(device tun.Device) *DeviceWrapper {
return &DeviceWrapper{
// newDeviceFilter constructor function
func newDeviceFilter(device tun.Device) *FilteredDevice {
return &FilteredDevice{
Device: device,
}
}
// Read wraps read method with filtering feature
func (d *DeviceWrapper) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
return 0, err
}
@ -68,7 +69,7 @@ func (d *DeviceWrapper) Read(bufs [][]byte, sizes []int, offset int) (n int, err
}
// Write wraps write method with filtering feature
func (d *DeviceWrapper) Write(bufs [][]byte, offset int) (int, error) {
func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
d.mutex.RLock()
filter := d.filter
d.mutex.RUnlock()
@ -92,7 +93,7 @@ func (d *DeviceWrapper) Write(bufs [][]byte, offset int) (int, error) {
}
// SetFilter sets packet filter to device
func (d *DeviceWrapper) SetFilter(filter PacketFilter) {
func (d *FilteredDevice) SetFilter(filter PacketFilter) {
d.mutex.Lock()
d.filter = filter
d.mutex.Unlock()

View File

@ -1,4 +1,4 @@
package iface
package device
import (
"net"
@ -7,7 +7,8 @@ import (
"github.com/golang/mock/gomock"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
mocks "github.com/netbirdio/netbird/iface/mocks"
mocks "github.com/netbirdio/netbird/client/iface/mocks"
)
func TestDeviceWrapperRead(t *testing.T) {
@ -51,7 +52,7 @@ func TestDeviceWrapperRead(t *testing.T) {
return 1, nil
})
wrapped := newDeviceWrapper(tun)
wrapped := newDeviceFilter(tun)
bufs := [][]byte{{}}
sizes := []int{0}
@ -99,7 +100,7 @@ func TestDeviceWrapperRead(t *testing.T) {
tun := mocks.NewMockDevice(ctrl)
tun.EXPECT().Write(mockBufs, 0).Return(1, nil)
wrapped := newDeviceWrapper(tun)
wrapped := newDeviceFilter(tun)
bufs := [][]byte{buffer.Bytes()}
@ -147,7 +148,7 @@ func TestDeviceWrapperRead(t *testing.T) {
filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropIncoming(gomock.Any()).Return(true)
wrapped := newDeviceWrapper(tun)
wrapped := newDeviceFilter(tun)
wrapped.filter = filter
bufs := [][]byte{buffer.Bytes()}
@ -202,7 +203,7 @@ func TestDeviceWrapperRead(t *testing.T) {
filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropOutgoing(gomock.Any()).Return(true)
wrapped := newDeviceWrapper(tun)
wrapped := newDeviceFilter(tun)
wrapped.filter = filter
bufs := [][]byte{{}}

View File

@ -1,7 +1,7 @@
//go:build ios
// +build ios
package iface
package device
import (
"os"
@ -12,10 +12,11 @@ import (
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"github.com/netbirdio/netbird/iface/bind"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
)
type tunDevice struct {
type TunDevice struct {
name string
address WGAddress
port int
@ -24,13 +25,13 @@ type tunDevice struct {
tunFd int
device *device.Device
wrapper *DeviceWrapper
filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault
configurer wgConfigurer
configurer WGConfigurer
}
func newTunDevice(name string, address WGAddress, port int, key string, transportNet transport.Net, tunFd int, filterFn bind.FilterFn) *tunDevice {
return &tunDevice{
func NewTunDevice(name string, address WGAddress, port int, key string, transportNet transport.Net, tunFd int, filterFn bind.FilterFn) *TunDevice {
return &TunDevice{
name: name,
address: address,
port: port,
@ -40,7 +41,7 @@ func newTunDevice(name string, address WGAddress, port int, key string, transpor
}
}
func (t *tunDevice) Create() (wgConfigurer, error) {
func (t *TunDevice) Create() (WGConfigurer, error) {
log.Infof("create tun interface")
dupTunFd, err := unix.Dup(t.tunFd)
@ -62,24 +63,24 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
return nil, err
}
t.wrapper = newDeviceWrapper(tunDevice)
t.filteredDevice = newDeviceFilter(tunDevice)
log.Debug("Attaching to interface")
t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
// without this property mobile devices can discover remote endpoints if the configured one was wrong.
// this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
t.configurer = newWGUSPConfigurer(t.device, t.name)
err = t.configurer.configureInterface(t.key, t.port)
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()
t.configurer.close()
t.configurer.Close()
return nil, err
}
return t.configurer, nil
}
func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err
@ -94,17 +95,17 @@ func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *tunDevice) Device() *device.Device {
func (t *TunDevice) Device() *device.Device {
return t.device
}
func (t *tunDevice) DeviceName() string {
func (t *TunDevice) DeviceName() string {
return t.name
}
func (t *tunDevice) Close() error {
func (t *TunDevice) Close() error {
if t.configurer != nil {
t.configurer.close()
t.configurer.Close()
}
if t.device != nil {
@ -119,15 +120,15 @@ func (t *tunDevice) Close() error {
return nil
}
func (t *tunDevice) WgAddress() WGAddress {
func (t *TunDevice) WgAddress() WGAddress {
return t.address
}
func (t *tunDevice) UpdateAddr(addr WGAddress) error {
func (t *TunDevice) UpdateAddr(addr WGAddress) error {
// todo implement
return nil
}
func (t *tunDevice) Wrapper() *DeviceWrapper {
return t.wrapper
func (t *TunDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice
}

View File

@ -1,6 +1,6 @@
//go:build (linux && !android) || freebsd
package iface
package device
import (
"context"
@ -10,11 +10,12 @@ import (
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/iface/bind"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/sharedsock"
)
type tunKernelDevice struct {
type TunKernelDevice struct {
name string
address WGAddress
wgPort int
@ -31,11 +32,11 @@ type tunKernelDevice struct {
filterFn bind.FilterFn
}
func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) wgTunDevice {
func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
checkUser()
ctx, cancel := context.WithCancel(context.Background())
return &tunKernelDevice{
return &TunKernelDevice{
ctx: ctx,
ctxCancel: cancel,
name: name,
@ -47,7 +48,7 @@ func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu in
}
}
func (t *tunKernelDevice) Create() (wgConfigurer, error) {
func (t *TunKernelDevice) Create() (WGConfigurer, error) {
link := newWGLink(t.name)
if err := link.recreate(); err != nil {
@ -67,16 +68,16 @@ func (t *tunKernelDevice) Create() (wgConfigurer, error) {
return nil, fmt.Errorf("set mtu: %w", err)
}
configurer := newWGConfigurer(t.name)
configurer := configurer.NewKernelConfigurer(t.name)
if err := configurer.configureInterface(t.key, t.wgPort); err != nil {
return nil, err
if err := configurer.ConfigureInterface(t.key, t.wgPort); err != nil {
return nil, fmt.Errorf("error configuring interface: %s", err)
}
return configurer, nil
}
func (t *tunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
if t.udpMux != nil {
return t.udpMux, nil
}
@ -111,12 +112,12 @@ func (t *tunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return t.udpMux, nil
}
func (t *tunKernelDevice) UpdateAddr(address WGAddress) error {
func (t *TunKernelDevice) UpdateAddr(address WGAddress) error {
t.address = address
return t.assignAddr()
}
func (t *tunKernelDevice) Close() error {
func (t *TunKernelDevice) Close() error {
if t.link == nil {
return nil
}
@ -144,19 +145,19 @@ func (t *tunKernelDevice) Close() error {
return closErr
}
func (t *tunKernelDevice) WgAddress() WGAddress {
func (t *TunKernelDevice) WgAddress() WGAddress {
return t.address
}
func (t *tunKernelDevice) DeviceName() string {
func (t *TunKernelDevice) DeviceName() string {
return t.name
}
func (t *tunKernelDevice) Wrapper() *DeviceWrapper {
func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
return nil
}
// assignAddr Adds IP address to the tunnel interface
func (t *tunKernelDevice) assignAddr() error {
func (t *TunKernelDevice) assignAddr() error {
return t.link.assignAddr(t.address)
}

View File

@ -1,7 +1,7 @@
//go:build !android
// +build !android
package iface
package device
import (
"fmt"
@ -10,11 +10,12 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/iface/bind"
"github.com/netbirdio/netbird/iface/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/netstack"
)
type tunNetstackDevice struct {
type TunNetstackDevice struct {
name string
address WGAddress
port int
@ -24,14 +25,14 @@ type tunNetstackDevice struct {
iceBind *bind.ICEBind
device *device.Device
wrapper *DeviceWrapper
filteredDevice *FilteredDevice
nsTun *netstack.NetStackTun
udpMux *bind.UniversalUDPMuxDefault
configurer wgConfigurer
configurer WGConfigurer
}
func newTunNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string, filterFn bind.FilterFn) wgTunDevice {
return &tunNetstackDevice{
func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string, filterFn bind.FilterFn) *TunNetstackDevice {
return &TunNetstackDevice{
name: name,
address: address,
port: wgPort,
@ -42,33 +43,33 @@ func newTunNetstackDevice(name string, address WGAddress, wgPort int, key string
}
}
func (t *tunNetstackDevice) Create() (wgConfigurer, error) {
func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
log.Info("create netstack tun interface")
t.nsTun = netstack.NewNetStackTun(t.listenAddress, t.address.IP.String(), t.mtu)
tunIface, err := t.nsTun.Create()
if err != nil {
return nil, err
return nil, fmt.Errorf("error creating tun device: %s", err)
}
t.wrapper = newDeviceWrapper(tunIface)
t.filteredDevice = newDeviceFilter(tunIface)
t.device = device.NewDevice(
t.wrapper,
t.filteredDevice,
t.iceBind,
device.NewLogger(wgLogLevel(), "[netbird] "),
)
t.configurer = newWGUSPConfigurer(t.device, t.name)
err = t.configurer.configureInterface(t.key, t.port)
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
_ = tunIface.Close()
return nil, err
return nil, fmt.Errorf("error configuring interface: %s", err)
}
log.Debugf("device has been created: %s", t.name)
return t.configurer, nil
}
func (t *tunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
if t.device == nil {
return nil, fmt.Errorf("device is not ready yet")
}
@ -87,13 +88,13 @@ func (t *tunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *tunNetstackDevice) UpdateAddr(WGAddress) error {
func (t *TunNetstackDevice) UpdateAddr(WGAddress) error {
return nil
}
func (t *tunNetstackDevice) Close() error {
func (t *TunNetstackDevice) Close() error {
if t.configurer != nil {
t.configurer.close()
t.configurer.Close()
}
if t.device != nil {
@ -106,14 +107,14 @@ func (t *tunNetstackDevice) Close() error {
return nil
}
func (t *tunNetstackDevice) WgAddress() WGAddress {
func (t *TunNetstackDevice) WgAddress() WGAddress {
return t.address
}
func (t *tunNetstackDevice) DeviceName() string {
func (t *TunNetstackDevice) DeviceName() string {
return t.name
}
func (t *tunNetstackDevice) Wrapper() *DeviceWrapper {
return t.wrapper
func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice
}

View File

@ -1,6 +1,6 @@
//go:build (linux && !android) || freebsd
package iface
package device
import (
"fmt"
@ -12,10 +12,11 @@ import (
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"github.com/netbirdio/netbird/iface/bind"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
)
type tunUSPDevice struct {
type USPDevice struct {
name string
address WGAddress
port int
@ -24,38 +25,37 @@ type tunUSPDevice struct {
iceBind *bind.ICEBind
device *device.Device
wrapper *DeviceWrapper
filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault
configurer wgConfigurer
configurer WGConfigurer
}
func newTunUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice {
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *USPDevice {
log.Infof("using userspace bind mode")
checkUser()
return &tunUSPDevice{
return &USPDevice{
name: name,
address: address,
port: port,
key: key,
mtu: mtu,
iceBind: bind.NewICEBind(transportNet, filterFn),
}
iceBind: bind.NewICEBind(transportNet, filterFn)}
}
func (t *tunUSPDevice) Create() (wgConfigurer, error) {
func (t *USPDevice) Create() (WGConfigurer, error) {
log.Info("create tun interface")
tunIface, err := tun.CreateTUN(t.name, t.mtu)
if err != nil {
log.Debugf("failed to create tun unterface (%s, %d): %s", t.name, t.mtu, err)
return nil, err
log.Debugf("failed to create tun interface (%s, %d): %s", t.name, t.mtu, err)
return nil, fmt.Errorf("error creating tun device: %s", err)
}
t.wrapper = newDeviceWrapper(tunIface)
t.filteredDevice = newDeviceFilter(tunIface)
// We need to create a wireguard-go device and listen to configuration requests
t.device = device.NewDevice(
t.wrapper,
t.filteredDevice,
t.iceBind,
device.NewLogger(wgLogLevel(), "[netbird] "),
)
@ -63,20 +63,20 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) {
err = t.assignAddr()
if err != nil {
t.device.Close()
return nil, err
return nil, fmt.Errorf("error assigning ip: %s", err)
}
t.configurer = newWGUSPConfigurer(t.device, t.name)
err = t.configurer.configureInterface(t.key, t.port)
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()
t.configurer.close()
return nil, err
t.configurer.Close()
return nil, fmt.Errorf("error configuring interface: %s", err)
}
return t.configurer, nil
}
func (t *tunUSPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
if t.device == nil {
return nil, fmt.Errorf("device is not ready yet")
}
@ -96,14 +96,14 @@ func (t *tunUSPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *tunUSPDevice) UpdateAddr(address WGAddress) error {
func (t *USPDevice) UpdateAddr(address WGAddress) error {
t.address = address
return t.assignAddr()
}
func (t *tunUSPDevice) Close() error {
func (t *USPDevice) Close() error {
if t.configurer != nil {
t.configurer.close()
t.configurer.Close()
}
if t.device != nil {
@ -116,20 +116,20 @@ func (t *tunUSPDevice) Close() error {
return nil
}
func (t *tunUSPDevice) WgAddress() WGAddress {
func (t *USPDevice) WgAddress() WGAddress {
return t.address
}
func (t *tunUSPDevice) DeviceName() string {
func (t *USPDevice) DeviceName() string {
return t.name
}
func (t *tunUSPDevice) Wrapper() *DeviceWrapper {
return t.wrapper
func (t *USPDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice
}
// assignAddr Adds IP address to the tunnel interface
func (t *tunUSPDevice) assignAddr() error {
func (t *USPDevice) assignAddr() error {
link := newWGLink(t.name)
return link.assignAddr(t.address)

View File

@ -1,4 +1,4 @@
package iface
package device
import (
"fmt"
@ -11,10 +11,13 @@ import (
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"github.com/netbirdio/netbird/iface/bind"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
)
type tunDevice struct {
const defaultWindowsGUIDSTring = "{f2f29e61-d91f-4d76-8151-119b20c4bdeb}"
type TunDevice struct {
name string
address WGAddress
port int
@ -24,13 +27,13 @@ type tunDevice struct {
device *device.Device
nativeTunDevice *tun.NativeTun
wrapper *DeviceWrapper
filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault
configurer wgConfigurer
configurer WGConfigurer
}
func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice {
return &tunDevice{
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *TunDevice {
return &TunDevice{
name: name,
address: address,
port: port,
@ -40,18 +43,31 @@ func newTunDevice(name string, address WGAddress, port int, key string, mtu int,
}
}
func (t *tunDevice) Create() (wgConfigurer, error) {
log.Info("create tun interface")
tunDevice, err := tun.CreateTUN(t.name, t.mtu)
func getGUID() (windows.GUID, error) {
guidString := defaultWindowsGUIDSTring
if CustomWindowsGUIDString != "" {
guidString = CustomWindowsGUIDString
}
return windows.GUIDFromString(guidString)
}
func (t *TunDevice) Create() (WGConfigurer, error) {
guid, err := getGUID()
if err != nil {
log.Errorf("failed to get GUID: %s", err)
return nil, err
}
log.Info("create tun interface")
tunDevice, err := tun.CreateTUNWithRequestedGUID(t.name, &guid, t.mtu)
if err != nil {
return nil, fmt.Errorf("error creating tun device: %s", err)
}
t.nativeTunDevice = tunDevice.(*tun.NativeTun)
t.wrapper = newDeviceWrapper(tunDevice)
t.filteredDevice = newDeviceFilter(tunDevice)
// We need to create a wireguard-go device and listen to configuration requests
t.device = device.NewDevice(
t.wrapper,
t.filteredDevice,
t.iceBind,
device.NewLogger(wgLogLevel(), "[netbird] "),
)
@ -74,20 +90,20 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
err = t.assignAddr()
if err != nil {
t.device.Close()
return nil, err
return nil, fmt.Errorf("error assigning ip: %s", err)
}
t.configurer = newWGUSPConfigurer(t.device, t.name)
err = t.configurer.configureInterface(t.key, t.port)
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()
t.configurer.close()
return nil, err
t.configurer.Close()
return nil, fmt.Errorf("error configuring interface: %s", err)
}
return t.configurer, nil
}
func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err
@ -102,14 +118,14 @@ func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *tunDevice) UpdateAddr(address WGAddress) error {
func (t *TunDevice) UpdateAddr(address WGAddress) error {
t.address = address
return t.assignAddr()
}
func (t *tunDevice) Close() error {
func (t *TunDevice) Close() error {
if t.configurer != nil {
t.configurer.close()
t.configurer.Close()
}
if t.device != nil {
@ -123,19 +139,19 @@ func (t *tunDevice) Close() error {
}
return nil
}
func (t *tunDevice) WgAddress() WGAddress {
func (t *TunDevice) WgAddress() WGAddress {
return t.address
}
func (t *tunDevice) DeviceName() string {
func (t *TunDevice) DeviceName() string {
return t.name
}
func (t *tunDevice) Wrapper() *DeviceWrapper {
return t.wrapper
func (t *TunDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice
}
func (t *tunDevice) getInterfaceGUIDString() (string, error) {
func (t *TunDevice) GetInterfaceGUIDString() (string, error) {
if t.nativeTunDevice == nil {
return "", fmt.Errorf("interface has not been initialized yet")
}
@ -149,7 +165,7 @@ func (t *tunDevice) getInterfaceGUIDString() (string, error) {
}
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
func (t *tunDevice) assignAddr() error {
func (t *TunDevice) assignAddr() error {
luid := winipcfg.LUID(t.nativeTunDevice.LUID())
log.Debugf("adding address %s to interface: %s", t.address.IP, t.name)
return luid.SetIPAddresses([]netip.Prefix{netip.MustParsePrefix(t.address.String())})

View File

@ -0,0 +1,20 @@
package device
import (
"net"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/configurer"
)
type WGConfigurer interface {
ConfigureInterface(privateKey string, port int) error
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP string) error
Close()
GetStats(peerKey string) (configurer.WGStats, error)
}

View File

@ -1,6 +1,6 @@
//go:build (!linux && !freebsd) || android
package iface
package device
// WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only)
func WireGuardModuleIsLoaded() bool {

View File

@ -1,4 +1,4 @@
package iface
package device
// WireGuardModuleIsLoaded check if kernel support wireguard
func WireGuardModuleIsLoaded() bool {
@ -10,8 +10,8 @@ func WireGuardModuleIsLoaded() bool {
return false
}
// tunModuleIsLoaded check if tun module exist, if is not attempt to load it
func tunModuleIsLoaded() bool {
// ModuleTunIsLoaded check if tun module exist, if is not attempt to load it
func ModuleTunIsLoaded() bool {
// Assume tun supported by freebsd kernel by default
// TODO: implement check for module loaded in kernel or build-it
return true

View File

@ -1,7 +1,7 @@
//go:build linux && !android
// Package iface provides wireguard network interface creation and management
package iface
package device
import (
"bufio"
@ -66,8 +66,8 @@ func getModuleRoot() string {
return filepath.Join(moduleLibDir, string(uname.Release[:i]))
}
// tunModuleIsLoaded check if tun module exist, if is not attempt to load it
func tunModuleIsLoaded() bool {
// ModuleTunIsLoaded check if tun module exist, if is not attempt to load it
func ModuleTunIsLoaded() bool {
_, err := os.Stat("/dev/net/tun")
if err == nil {
return true

View File

@ -1,4 +1,6 @@
package iface
//go:build linux && !android
package device
import (
"bufio"

View File

@ -1,10 +1,11 @@
package iface
package device
import (
"fmt"
"github.com/netbirdio/netbird/iface/freebsd"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/freebsd"
)
type wgLink struct {

View File

@ -1,6 +1,6 @@
//go:build linux && !android
package iface
package device
import (
"fmt"

View File

@ -1,4 +1,4 @@
package iface
package device
import (
"os"

View File

@ -0,0 +1,4 @@
package device
// CustomWindowsGUIDString is a custom GUID string for the interface
var CustomWindowsGUIDString string

View File

@ -0,0 +1,16 @@
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
)
type WGTunDevice interface {
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address WGAddress) error
WgAddress() WGAddress
DeviceName() string
Close() error
FilteredDevice() *device.FilteredDevice
}

View File

@ -9,28 +9,27 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/iface/bind"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
)
const (
DefaultMTU = 1280
DefaultWgPort = 51820
WgInterfaceDefault = configurer.WgInterfaceDefault
)
// WGIface represents a interface instance
type WGAddress = device.WGAddress
// WGIface represents an interface instance
type WGIface struct {
tun wgTunDevice
tun WGTunDevice
userspaceBind bool
mu sync.Mutex
configurer wgConfigurer
filter PacketFilter
}
type WGStats struct {
LastHandshake time.Time
TxBytes int64
RxBytes int64
configurer device.WGConfigurer
filter device.PacketFilter
}
// IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind
@ -44,7 +43,7 @@ func (w *WGIface) Name() string {
}
// Address returns the interface address
func (w *WGIface) Address() WGAddress {
func (w *WGIface) Address() device.WGAddress {
return w.tun.WgAddress()
}
@ -75,7 +74,7 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
w.mu.Lock()
defer w.mu.Unlock()
addr, err := parseWGAddress(newAddr)
addr, err := device.ParseWGAddress(newAddr)
if err != nil {
return err
}
@ -90,7 +89,7 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.D
defer w.mu.Unlock()
log.Debugf("updating interface %s peer %s, endpoint %s", w.tun.DeviceName(), peerKey, endpoint)
return w.configurer.updatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
}
// RemovePeer removes a Wireguard Peer from the interface iface
@ -99,7 +98,7 @@ func (w *WGIface) RemovePeer(peerKey string) error {
defer w.mu.Unlock()
log.Debugf("Removing peer %s from interface %s ", peerKey, w.tun.DeviceName())
return w.configurer.removePeer(peerKey)
return w.configurer.RemovePeer(peerKey)
}
// AddAllowedIP adds a prefix to the allowed IPs list of peer
@ -108,7 +107,7 @@ func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error {
defer w.mu.Unlock()
log.Debugf("Adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
return w.configurer.addAllowedIP(peerKey, allowedIP)
return w.configurer.AddAllowedIP(peerKey, allowedIP)
}
// RemoveAllowedIP removes a prefix from the allowed IPs list of peer
@ -117,34 +116,50 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
defer w.mu.Unlock()
log.Debugf("Removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
return w.configurer.removeAllowedIP(peerKey, allowedIP)
return w.configurer.RemoveAllowedIP(peerKey, allowedIP)
}
// Close closes the tunnel interface
func (w *WGIface) Close() error {
w.mu.Lock()
defer w.mu.Unlock()
return w.tun.Close()
err := w.tun.Close()
if err != nil {
return fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err)
}
err = w.waitUntilRemoved()
if err != nil {
log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
err = w.Destroy()
if err != nil {
return fmt.Errorf("failed to remove WireGuard interface %s: %w", w.Name(), err)
}
log.Infof("interface %s successfully removed", w.Name())
}
return nil
}
// SetFilter sets packet filters for the userspace implementation
func (w *WGIface) SetFilter(filter PacketFilter) error {
func (w *WGIface) SetFilter(filter device.PacketFilter) error {
w.mu.Lock()
defer w.mu.Unlock()
if w.tun.Wrapper() == nil {
if w.tun.FilteredDevice() == nil {
return fmt.Errorf("userspace packet filtering not handled on this device")
}
w.filter = filter
w.filter.SetNetwork(w.tun.WgAddress().Network)
w.tun.Wrapper().SetFilter(filter)
w.tun.FilteredDevice().SetFilter(filter)
return nil
}
// GetFilter returns packet filter used by interface if it uses userspace device implementation
func (w *WGIface) GetFilter() PacketFilter {
func (w *WGIface) GetFilter() device.PacketFilter {
w.mu.Lock()
defer w.mu.Unlock()
@ -152,14 +167,41 @@ func (w *WGIface) GetFilter() PacketFilter {
}
// GetDevice to interact with raw device (with filtering)
func (w *WGIface) GetDevice() *DeviceWrapper {
func (w *WGIface) GetDevice() *device.FilteredDevice {
w.mu.Lock()
defer w.mu.Unlock()
return w.tun.Wrapper()
return w.tun.FilteredDevice()
}
// GetStats returns the last handshake time, rx and tx bytes for the given peer
func (w *WGIface) GetStats(peerKey string) (WGStats, error) {
return w.configurer.getStats(peerKey)
func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) {
return w.configurer.GetStats(peerKey)
}
func (w *WGIface) waitUntilRemoved() error {
maxWaitTime := 5 * time.Second
timeout := time.NewTimer(maxWaitTime)
defer timeout.Stop()
for {
iface, err := net.InterfaceByName(w.Name())
if err != nil {
if _, ok := err.(*net.OpError); ok {
log.Infof("interface %s has been removed", w.Name())
return nil
}
log.Debugf("failed to get interface by name %s: %v", w.Name(), err)
} else if iface == nil {
log.Infof("interface %s has been removed", w.Name())
return nil
}
select {
case <-timeout.C:
return fmt.Errorf("timeout when waiting for interface %s to be removed", w.Name())
default:
time.Sleep(100 * time.Millisecond)
}
}
}

View File

@ -5,18 +5,19 @@ import (
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/iface/bind"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := parseWGAddress(address)
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{
tun: newTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter, filterFn),
tun: device.NewTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter, filterFn),
userspaceBind: true,
}
return wgIFace, nil

View File

@ -1,4 +1,4 @@
//go:build !android
//go:build (!android && !darwin) || ios
package iface

View File

@ -0,0 +1,67 @@
//go:build !ios
package iface
import (
"fmt"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, _ *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{
userspaceBind: true,
}
if netstack.IsEnabled() {
wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
return wgIFace, nil
}
wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn)
return wgIFace, nil
}
// CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("this function has not implemented on this platform")
}
// Create creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
// this function is different on Android
func (w *WGIface) Create() error {
w.mu.Lock()
defer w.mu.Unlock()
backOff := &backoff.ExponentialBackOff{
InitialInterval: 20 * time.Millisecond,
MaxElapsedTime: 500 * time.Millisecond,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}
operation := func() error {
cfgr, err := w.tun.Create()
if err != nil {
return err
}
w.configurer = cfgr
return nil
}
return backoff.Retry(operation, backOff)
}

View File

@ -0,0 +1,17 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
package iface
import (
"fmt"
"os/exec"
)
func (w *WGIface) Destroy() error {
out, err := exec.Command("ifconfig", w.Name(), "destroy").CombinedOutput()
if err != nil {
return fmt.Errorf("failed to remove interface %s: %w - %s", w.Name(), err, out)
}
return nil
}

View File

@ -0,0 +1,22 @@
//go:build linux && !android
package iface
import (
"fmt"
"github.com/vishvananda/netlink"
)
func (w *WGIface) Destroy() error {
link, err := netlink.LinkByName(w.Name())
if err != nil {
return fmt.Errorf("failed to get link by name %s: %w", w.Name(), err)
}
if err := netlink.LinkDel(link); err != nil {
return fmt.Errorf("failed to delete link %s: %w", w.Name(), err)
}
return nil
}

View File

@ -0,0 +1,9 @@
//go:build android || (ios && !darwin)
package iface
import "errors"
func (w *WGIface) Destroy() error {
return errors.New("not supported on mobile")
}

View File

@ -0,0 +1,32 @@
//go:build windows
package iface
import (
"fmt"
"os/exec"
log "github.com/sirupsen/logrus"
)
func (w *WGIface) Destroy() error {
netshCmd := GetSystem32Command("netsh")
out, err := exec.Command(netshCmd, "interface", "set", "interface", w.Name(), "admin=disable").CombinedOutput()
if err != nil {
return fmt.Errorf("failed to remove interface %s: %w - %s", w.Name(), err, out)
}
return nil
}
// GetSystem32Command checks if a command can be found in the system path and returns it. In case it can't find it
// in the path it will return the full path of a command assuming C:\windows\system32 as the base path.
func GetSystem32Command(command string) string {
_, err := exec.LookPath(command)
if err == nil {
return command
}
log.Tracef("Command %s not found in PATH, using C:\\windows\\system32\\%s.exe path", command, command)
return "C:\\windows\\system32\\" + command + ".exe"
}

View File

@ -7,17 +7,18 @@ import (
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/iface/bind"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := parseWGAddress(address)
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{
tun: newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, transportNet, args.TunFd, filterFn),
tun: device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, transportNet, args.TunFd, filterFn),
userspaceBind: true,
}
return wgIFace, nil

105
client/iface/iface_moc.go Normal file
View File

@ -0,0 +1,105 @@
package iface
import (
"net"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
)
type MockWGIface struct {
CreateFunc func() error
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
IsUserspaceBindFunc func() bool
NameFunc func() string
AddressFunc func() device.WGAddress
ToInterfaceFunc func() *net.Interface
UpFunc func() (*bind.UniversalUDPMuxDefault, error)
UpdateAddrFunc func(newAddr string) error
UpdatePeerFunc func(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeerFunc func(peerKey string) error
AddAllowedIPFunc func(peerKey string, allowedIP string) error
RemoveAllowedIPFunc func(peerKey string, allowedIP string) error
CloseFunc func() error
SetFilterFunc func(filter device.PacketFilter) error
GetFilterFunc func() device.PacketFilter
GetDeviceFunc func() *device.FilteredDevice
GetStatsFunc func(peerKey string) (configurer.WGStats, error)
GetInterfaceGUIDStringFunc func() (string, error)
}
func (m *MockWGIface) GetInterfaceGUIDString() (string, error) {
return m.GetInterfaceGUIDStringFunc()
}
func (m *MockWGIface) Create() error {
return m.CreateFunc()
}
func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains []string) error {
return m.CreateOnAndroidFunc(routeRange, ip, domains)
}
func (m *MockWGIface) IsUserspaceBind() bool {
return m.IsUserspaceBindFunc()
}
func (m *MockWGIface) Name() string {
return m.NameFunc()
}
func (m *MockWGIface) Address() device.WGAddress {
return m.AddressFunc()
}
func (m *MockWGIface) ToInterface() *net.Interface {
return m.ToInterfaceFunc()
}
func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
return m.UpFunc()
}
func (m *MockWGIface) UpdateAddr(newAddr string) error {
return m.UpdateAddrFunc(newAddr)
}
func (m *MockWGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
return m.UpdatePeerFunc(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
}
func (m *MockWGIface) RemovePeer(peerKey string) error {
return m.RemovePeerFunc(peerKey)
}
func (m *MockWGIface) AddAllowedIP(peerKey string, allowedIP string) error {
return m.AddAllowedIPFunc(peerKey, allowedIP)
}
func (m *MockWGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
return m.RemoveAllowedIPFunc(peerKey, allowedIP)
}
func (m *MockWGIface) Close() error {
return m.CloseFunc()
}
func (m *MockWGIface) SetFilter(filter device.PacketFilter) error {
return m.SetFilterFunc(filter)
}
func (m *MockWGIface) GetFilter() device.PacketFilter {
return m.GetFilterFunc()
}
func (m *MockWGIface) GetDevice() *device.FilteredDevice {
return m.GetDeviceFunc()
}
func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
return m.GetStatsFunc(peerKey)
}

View File

@ -4,14 +4,18 @@ import (
"fmt"
"net"
"net/netip"
"strings"
"testing"
"time"
"github.com/google/uuid"
"github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/device"
)
// keep darwin compatibility
@ -174,6 +178,72 @@ func Test_Close(t *testing.T) {
}
}
func TestRecreation(t *testing.T) {
for i := 0; i < 100; i++ {
t.Run(fmt.Sprintf("down-%d", i), func(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
wgIP := "10.99.99.2/32"
wgPort := 33100
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
}
for {
_, err = net.InterfaceByName(ifaceName)
if err != nil {
t.Logf("interface %s not found: err: %s", ifaceName, err)
break
}
t.Logf("interface %s found", ifaceName)
}
err = iface.Create()
if err != nil {
t.Fatal(err)
}
wg, err := wgctrl.New()
if err != nil {
t.Fatal(err)
}
defer func() {
err = wg.Close()
if err != nil {
t.Error(err)
}
}()
_, err = iface.Up()
if err != nil {
t.Fatal(err)
}
for {
_, err = net.InterfaceByName(ifaceName)
if err == nil {
t.Logf("interface %s found", ifaceName)
break
}
t.Logf("interface %s not found: err: %s", ifaceName, err)
}
start := time.Now()
err = iface.Close()
t.Logf("down time: %s", time.Since(start))
if err != nil {
t.Fatal(err)
}
})
}
}
func Test_ConfigureInterface(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3)
wgIP := "10.99.99.5/30"
@ -345,6 +415,9 @@ func Test_ConnectPeers(t *testing.T) {
t.Fatal(err)
}
guid := fmt.Sprintf("{%s}", uuid.New().String())
device.CustomWindowsGUIDString = strings.ToLower(guid)
iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, peer1wgPort, peer1Key.String(), DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
@ -364,6 +437,9 @@ func Test_ConnectPeers(t *testing.T) {
t.Fatal(err)
}
guid = fmt.Sprintf("{%s}", uuid.New().String())
device.CustomWindowsGUIDString = strings.ToLower(guid)
newNet, err = stdnet.NewNet()
if err != nil {
t.Fatal(err)

View File

@ -8,13 +8,14 @@ import (
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/iface/bind"
"github.com/netbirdio/netbird/iface/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := parseWGAddress(address)
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
@ -23,21 +24,21 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string,
// move the kernel/usp/netstack preference evaluation to upper layer
if netstack.IsEnabled() {
wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
wgIFace.userspaceBind = true
return wgIFace, nil
}
if WireGuardModuleIsLoaded() {
wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet)
if device.WireGuardModuleIsLoaded() {
wgIFace.tun = device.NewKernelDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet)
wgIFace.userspaceBind = false
return wgIFace, nil
}
if !tunModuleIsLoaded() {
if !device.ModuleTunIsLoaded() {
return nil, fmt.Errorf("couldn't check or load tun module")
}
wgIFace.tun = newTunUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, nil)
wgIFace.tun = device.NewUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, nil)
wgIFace.userspaceBind = true
return wgIFace, nil
}

View File

@ -5,13 +5,14 @@ import (
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/iface/bind"
"github.com/netbirdio/netbird/iface/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := parseWGAddress(address)
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
@ -21,11 +22,11 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string,
}
if netstack.IsEnabled() {
wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
return wgIFace, nil
}
wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn)
wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn)
return wgIFace, nil
}
@ -36,5 +37,5 @@ func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
// GetInterfaceGUIDString returns an interface GUID. This is useful on Windows only
func (w *WGIface) GetInterfaceGUIDString() (string, error) {
return w.tun.(*tunDevice).getInterfaceGUIDString()
return w.tun.(*device.TunDevice).GetInterfaceGUIDString()
}

View File

@ -0,0 +1,34 @@
//go:build !windows
package iface
import (
"net"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
)
type IWGIface interface {
Create() error
CreateOnAndroid(routeRange []string, ip string, domains []string) error
IsUserspaceBind() bool
Name() string
Address() device.WGAddress
ToInterface() *net.Interface
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(newAddr string) error
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP string) error
Close() error
SetFilter(filter device.PacketFilter) error
GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice
GetStats(peerKey string) (configurer.WGStats, error)
}

Some files were not shown because too many files have changed in this diff Show More