Merge branch 'main' into local/engine-restart

This commit is contained in:
Pascal Fischer 2023-09-21 16:43:03 +02:00
commit cdbe9c4eef
168 changed files with 3724 additions and 2522 deletions

View File

@ -0,0 +1,41 @@
name: Android build validation
on:
push:
branches:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
build:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v3
- name: Install Go
uses: actions/setup-go@v4
with:
go-version: "1.20.x"
- name: Setup Android SDK
uses: android-actions/setup-android@v2
- name: NDK Cache
id: ndk-cache
uses: actions/cache@v3
with:
path: /usr/local/lib/android/sdk/ndk
key: ndk-cache-23.1.7779620
- name: Setup NDK
run: /usr/local/lib/android/sdk/tools/bin/sdkmanager --install "ndk;23.1.7779620"
- name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda
- name: gomobile init
run: gomobile init
- name: build android nebtird lib
run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android
env:
CGO_ENABLED: 0
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620

View File

@ -15,14 +15,14 @@ jobs:
runs-on: macos-latest
steps:
- name: Install Go
uses: actions/setup-go@v2
uses: actions/setup-go@v4
with:
go-version: "1.20.x"
- name: Checkout code
uses: actions/checkout@v2
uses: actions/checkout@v3
- name: Cache Go modules
uses: actions/cache@v2
uses: actions/cache@v3
with:
path: ~/go/pkg/mod
key: macos-go-${{ hashFiles('**/go.sum') }}

View File

@ -18,13 +18,13 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Install Go
uses: actions/setup-go@v2
uses: actions/setup-go@v4
with:
go-version: "1.20.x"
- name: Cache Go modules
uses: actions/cache@v2
uses: actions/cache@v3
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@ -32,7 +32,7 @@ jobs:
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v2
uses: actions/checkout@v3
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib
@ -47,13 +47,13 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Install Go
uses: actions/setup-go@v2
uses: actions/setup-go@v4
with:
go-version: "1.20.x"
- name: Cache Go modules
uses: actions/cache@v2
uses: actions/cache@v3
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@ -61,7 +61,7 @@ jobs:
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v2
uses: actions/checkout@v3
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev

View File

@ -8,14 +8,13 @@ jobs:
name: lint
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Checkout code
uses: actions/checkout@v3
- name: Install Go
uses: actions/setup-go@v2
uses: actions/setup-go@v4
with:
go-version: "1.20.x"
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev
- name: golangci-lint
uses: golangci/golangci-lint-action@v2
with:
args: --timeout=6m
uses: golangci/golangci-lint-action@v3

View File

@ -0,0 +1,36 @@
name: Test installation
on:
push:
branches:
- main
pull_request:
paths:
- "release_files/install.sh"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
test-install-script:
strategy:
max-parallel: 2
matrix:
os: [ubuntu-latest, macos-latest]
skip_ui_mode: [true, false]
install_binary: [true, false]
runs-on: ${{ matrix.os }}
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: run install script
env:
SKIP_UI_APP: ${{ matrix.skip_ui_mode }}
USE_BIN_INSTALL: ${{ matrix.install_binary }}
GITHUB_TOKEN: ${{ secrets.RO_API_CALLER_TOKEN }}
run: |
[ "$SKIP_UI_APP" == "false" ] && export XDG_CURRENT_DESKTOP="none"
cat release_files/install.sh | sh -x
- name: check cli binary
run: command -v netbird

View File

@ -1,60 +0,0 @@
name: Test installation Darwin
on:
push:
branches:
- main
pull_request:
paths:
- "release_files/install.sh"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
install-cli-only:
runs-on: macos-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Rename brew package
if: ${{ matrix.check_bin_install }}
run: mv /opt/homebrew/bin/brew /opt/homebrew/bin/brew.bak
- name: Run install script
run: |
sh ./release_files/install.sh
env:
SKIP_UI_APP: true
- name: Run tests
run: |
if ! command -v netbird &> /dev/null; then
echo "Error: netbird is not installed"
exit 1
fi
install-all:
runs-on: macos-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Rename brew package
if: ${{ matrix.check_bin_install }}
run: mv /opt/homebrew/bin/brew /opt/homebrew/bin/brew.bak
- name: Run install script
run: |
sh ./release_files/install.sh
- name: Run tests
run: |
if ! command -v netbird &> /dev/null; then
echo "Error: netbird is not installed"
exit 1
fi
if [[ $(mdfind "kMDItemContentType == 'com.apple.application-bundle' && kMDItemFSName == '*NetBird UI.app'") ]]; then
echo "Error: NetBird UI is not installed"
exit 1
fi

View File

@ -1,38 +0,0 @@
name: Test installation Linux
on:
push:
branches:
- main
pull_request:
paths:
- "release_files/install.sh"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
install-cli-only:
runs-on: ubuntu-latest
strategy:
matrix:
check_bin_install: [true, false]
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Rename apt package
if: ${{ matrix.check_bin_install }}
run: |
sudo mv /usr/bin/apt /usr/bin/apt.bak
sudo mv /usr/bin/apt-get /usr/bin/apt-get.bak
- name: Run install script
run: |
sh ./release_files/install.sh
- name: Run tests
run: |
if ! command -v netbird &> /dev/null; then
echo "Error: netbird is not installed"
exit 1
fi

View File

@ -19,7 +19,7 @@ on:
- '**/Dockerfile.*'
env:
SIGN_PIPE_VER: "v0.0.8"
SIGN_PIPE_VER: "v0.0.9"
GORELEASER_VER: "v1.14.1"
concurrency:
@ -29,20 +29,24 @@ concurrency:
jobs:
release:
runs-on: ubuntu-latest
env:
flags: ""
steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
-
name: Checkout
uses: actions/checkout@v2
uses: actions/checkout@v3
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
-
name: Set up Go
uses: actions/setup-go@v2
uses: actions/setup-go@v4
with:
go-version: "1.20"
-
name: Cache Go modules
uses: actions/cache@v1
uses: actions/cache@v3
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@ -56,10 +60,10 @@ jobs:
run: git --no-pager diff --exit-code
-
name: Set up QEMU
uses: docker/setup-qemu-action@v1
uses: docker/setup-qemu-action@v2
-
name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1
uses: docker/setup-buildx-action@v2
-
name: Login to Docker hub
if: github.event_name != 'pull_request'
@ -82,10 +86,10 @@ jobs:
run: rsrc -arch 386 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_386.syso
-
name: Run GoReleaser
uses: goreleaser/goreleaser-action@v2
uses: goreleaser/goreleaser-action@v4
with:
version: ${{ env.GORELEASER_VER }}
args: release --rm-dist
args: release --rm-dist ${{ env.flags }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
@ -93,7 +97,7 @@ jobs:
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
-
name: upload non tags for debug purposes
uses: actions/upload-artifact@v2
uses: actions/upload-artifact@v3
with:
name: release
path: dist/
@ -102,17 +106,19 @@ jobs:
release_ui:
runs-on: ubuntu-latest
steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v2
uses: actions/checkout@v3
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Set up Go
uses: actions/setup-go@v2
uses: actions/setup-go@v4
with:
go-version: "1.20"
- name: Cache Go modules
uses: actions/cache@v1
uses: actions/cache@v3
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-ui-go-${{ hashFiles('**/go.sum') }}
@ -132,17 +138,17 @@ jobs:
- name: Generate windows rsrc
run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/ui/manifest.xml -o client/ui/resources_windows_amd64.syso
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v2
uses: goreleaser/goreleaser-action@v4
with:
version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui.yaml --rm-dist
args: release --config .goreleaser_ui.yaml --rm-dist ${{ env.flags }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
- name: upload non tags for debug purposes
uses: actions/upload-artifact@v2
uses: actions/upload-artifact@v3
with:
name: release-ui
path: dist/
@ -151,19 +157,21 @@ jobs:
release_ui_darwin:
runs-on: macos-11
steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
-
name: Checkout
uses: actions/checkout@v2
uses: actions/checkout@v3
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
-
name: Set up Go
uses: actions/setup-go@v2
uses: actions/setup-go@v4
with:
go-version: "1.20"
-
name: Cache Go modules
uses: actions/cache@v1
uses: actions/cache@v3
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-ui-go-${{ hashFiles('**/go.sum') }}
@ -175,15 +183,15 @@ jobs:
-
name: Run GoReleaser
id: goreleaser
uses: goreleaser/goreleaser-action@v2
uses: goreleaser/goreleaser-action@v4
with:
version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui_darwin.yaml --rm-dist
args: release --config .goreleaser_ui_darwin.yaml --rm-dist ${{ env.flags }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-
name: upload non tags for debug purposes
uses: actions/upload-artifact@v2
uses: actions/upload-artifact@v3
with:
name: release-ui-darwin
path: dist/

View File

@ -9,7 +9,6 @@ on:
- 'infrastructure_files/**'
- '.github/workflows/test-infrastructure-files.yml'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
@ -25,12 +24,12 @@ jobs:
run: sudo apt-get install -y curl
- name: Install Go
uses: actions/setup-go@v2
uses: actions/setup-go@v4
with:
go-version: "1.20.x"
- name: Cache Go modules
uses: actions/cache@v2
uses: actions/cache@v3
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@ -81,6 +80,7 @@ jobs:
CI_NETBIRD_MGMT_IDP: "none"
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_SIGNAL_PORT: 12345
run: |
grep AUTH_CLIENT_ID docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_ID
@ -92,6 +92,7 @@ jobs:
grep NETBIRD_MGMT_API_ENDPOINT docker-compose.yml | grep "$CI_NETBIRD_DOMAIN:33073"
grep AUTH_REDIRECT_URI docker-compose.yml | grep $CI_NETBIRD_AUTH_REDIRECT_URI
grep AUTH_SILENT_REDIRECT_URI docker-compose.yml | egrep 'AUTH_SILENT_REDIRECT_URI=$'
grep $CI_NETBIRD_SIGNAL_PORT docker-compose.yml | grep ':80'
grep LETSENCRYPT_DOMAIN docker-compose.yml | egrep 'LETSENCRYPT_DOMAIN=$'
grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE
grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM
@ -121,7 +122,7 @@ jobs:
- name: test running containers
run: |
count=$(docker compose ps --format json | jq '.[] | select(.Project | contains("infrastructure_files")) | .State' | grep -c running)
count=$(docker compose ps --format json | jq '. | select(.Name | contains("infrastructure_files")) | .State' | grep -c running)
test $count -eq 4
working-directory: infrastructure_files

1
.gitignore vendored
View File

@ -19,3 +19,4 @@ client/.distfiles/
infrastructure_files/setup.env
infrastructure_files/setup-*.env
.vscode
.DS_Store

54
.golangci.yaml Normal file
View File

@ -0,0 +1,54 @@
run:
# Timeout for analysis, e.g. 30s, 5m.
# Default: 1m
timeout: 6m
# This file contains only configs which differ from defaults.
# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml
linters-settings:
errcheck:
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
# Such cases aren't reported by default.
# Default: false
check-type-assertions: false
govet:
# Enable all analyzers.
# Default: false
enable-all: false
enable:
- nilness
linters:
disable-all: true
enable:
## enabled by default
- errcheck # checking for unchecked errors, these unchecked errors can be critical bugs in some cases
- gosimple # specializes in simplifying a code
- govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
- ineffassign # detects when assignments to existing variables are not used
- staticcheck # is a go vet on steroids, applying a ton of static analysis checks
- typecheck # like the front-end of a Go compiler, parses and type-checks Go code
- unused # checks for unused constants, variables, functions and types
## disable by default but the have interesting results so lets add them
- bodyclose # checks whether HTTP response body is closed successfully
- nilerr # finds the code that returns nil even if it checks that the error is not nil
- nilnil # checks that there is no simultaneous return of nil error and an invalid value
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
- wastedassign # wastedassign finds wasted assignment statements
issues:
# Maximum count of issues with the same text.
# Set to 0 to disable.
# Default: 3
max-same-issues: 5
exclude-rules:
- path: sharedsock/filter.go
linters:
- unused
- path: client/firewall/iptables/rule.go
linters:
- unused
- path: mock.go
linters:
- nilnil

View File

@ -1,6 +1,6 @@
<p align="center">
<strong>:hatching_chick: New Release! Peer expiration.</strong>
<a href="https://github.com/netbirdio/netbird/releases">
<strong>:hatching_chick: New Release! Self-hosting in under 5 min.</strong>
<a href="https://github.com/netbirdio/netbird#quickstart-with-self-hosted-netbird">
Learn more
</a>
</p>
@ -24,7 +24,7 @@
<p align="center">
<strong>
Start using NetBird at <a href="https://app.netbird.io/">app.netbird.io</a>
Start using NetBird at <a href="https://netbird.io/pricing">netbird.io</a>
<br/>
See <a href="https://netbird.io/docs/">Documentation</a>
<br/>
@ -40,9 +40,13 @@
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
**Secure.** NetBird isolates every machine and device by applying granular access policies, while allowing you to manage them intuitively from a single place.
**Secure.** NetBird enables secure remote access by applying granular access policies, while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
**Key features:**
### Secure peer-to-peer VPN with SSO and MFA in minutes
https://user-images.githubusercontent.com/700848/197345890-2e2cded5-7b7a-436f-a444-94e80dd24f46.mov
### Key features
| Connectivity | Management | Automation | Platforms |
|-------------------------------------------------------------------|--------------------------------------------------------------------------|----------------------------------------------------------------------------|---------------------------------------|
@ -57,10 +61,6 @@
| | <ul><li> - \[x] SSH access management </ul></li> | | |
### Secure peer-to-peer VPN with SSO and MFA in minutes
https://user-images.githubusercontent.com/700848/197345890-2e2cded5-7b7a-436f-a444-94e80dd24f46.mov
### Quickstart with NetBird Cloud
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)

View File

@ -18,10 +18,9 @@ func Encode(num uint32) string {
}
var encoded strings.Builder
remainder := uint32(0)
for num > 0 {
remainder = num % base
remainder := num % base
encoded.WriteByte(alphabet[remainder])
num /= base
}

View File

@ -1,7 +1,5 @@
FROM gcr.io/distroless/base:debug
FROM alpine:3
RUN apk add --no-cache ca-certificates iptables ip6tables
ENV NB_FOREGROUND_MODE=true
ENV PATH=/sbin:/usr/sbin:/bin:/usr/bin:/busybox
SHELL ["/busybox/sh","-c"]
RUN sed -i -E 's/(^root:.+)\/sbin\/nologin/\1\/busybox\/sh/g' /etc/passwd
ENTRYPOINT [ "/go/bin/netbird","up"]
COPY netbird /go/bin/netbird

View File

@ -55,7 +55,6 @@ type Client struct {
ctxCancelLock *sync.Mutex
deviceName string
routeListener routemanager.RouteListener
onHostDnsFn func([]string)
}
// NewClient instantiate a new Client
@ -97,7 +96,30 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.onHostDnsFn = func([]string) {}
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener, dns.items, dnsReadyListener)
}
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
// In this case make no sense handle registration steps.
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener) error {
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
ConfigPath: c.cfgFile,
})
if err != nil {
return err
}
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
var ctx context.Context
//nolint
ctxWithValues := context.WithValue(context.Background(), system.DeviceNameCtxKey, c.deviceName)
c.ctxCancelLock.Lock()
ctx, c.ctxCancel = context.WithCancel(ctxWithValues)
defer c.ctxCancel()
c.ctxCancelLock.Unlock()
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener, dns.items, dnsReadyListener)
}

View File

@ -84,10 +84,14 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
supportsSSO := true
err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
s, ok := gstatus.FromError(err)
if !ok {
return err
}
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
supportsSSO = false
err = nil
}

View File

@ -3,21 +3,19 @@ package cmd
import (
"context"
"fmt"
"github.com/netbirdio/netbird/client/internal/auth"
"strings"
"time"
"github.com/skratchdot/open-golang/open"
"github.com/spf13/cobra"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/util"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/util"
)
var loginCmd = &cobra.Command{
@ -191,17 +189,16 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
var codeMsg string
if userCode != "" {
if !strings.Contains(verificationURIComplete, userCode) {
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
}
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
}
err := open.Run(verificationURIComplete)
cmd.Printf("Please do the SSO login in your browser. \n" +
cmd.Println("Please do the SSO login in your browser. \n" +
"If your browser didn't open automatically, use this URL to log in:\n\n" +
" " + verificationURIComplete + " " + codeMsg + " \n\n")
if err != nil {
cmd.Printf("Alternatively, you may want to use a setup key, see:\n\n https://www.netbird.io/docs/overview/setup-keys\n")
verificationURIComplete + " " + codeMsg)
cmd.Println("")
if err := open.Run(verificationURIComplete); err != nil {
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
}
}

View File

@ -109,9 +109,9 @@ func statusFunc(cmd *cobra.Command, args []string) error {
ctx := internal.CtxInitState(context.Background())
resp, _ := getStatus(ctx, cmd)
resp, err := getStatus(ctx, cmd)
if err != nil {
return nil
return err
}
if resp.GetStatus() == string(internal.StatusNeedsLogin) || resp.GetStatus() == string(internal.StatusLoginFailed) {
@ -120,7 +120,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
" netbird up \n\n"+
"If you are running a self-hosted version and no SSO provider has been configured in your Management Server,\n"+
"you can use a setup-key:\n\n netbird up --management-url <YOUR_MANAGEMENT_URL> --setup-key <YOUR_SETUP_KEY>\n\n"+
"More info: https://www.netbird.io/docs/overview/setup-keys\n\n",
"More info: https://docs.netbird.io/how-to/register-machines-using-setup-keys\n\n",
resp.GetStatus(),
)
return nil
@ -133,7 +133,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
outputInformationHolder := convertToStatusOutputOverview(resp)
statusOutputString := ""
var statusOutputString string
switch {
case detailFlag:
statusOutputString = parseToFullDetailSummary(outputInformationHolder)

View File

@ -76,12 +76,12 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
return nil, nil
}
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "",
eventStore)
eventStore, false)
if err != nil {
t.Fatal(err)
}
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil {
t.Fatal(err)
}

View File

@ -40,6 +40,9 @@ const (
// It declares methods which handle actions required by the
// Netbird client for ACL and routing functionality
type Manager interface {
// AllowNetbird allows netbird interface traffic
AllowNetbird() error
// AddFiltering rule to the firewall
//
// If comment argument is empty firewall manager should set

View File

@ -44,6 +44,7 @@ type Manager struct {
type iFaceMapper interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
}
type ruleset struct {
@ -52,7 +53,7 @@ type ruleset struct {
}
// Create iptables firewall manager
func Create(wgIface iFaceMapper) (*Manager, error) {
func Create(wgIface iFaceMapper, ipv6Supported bool) (*Manager, error) {
m := &Manager{
wgIface: wgIface,
inputDefaultRuleSpecs: []string{
@ -62,26 +63,26 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
rulesets: make(map[string]ruleset),
}
if err := ipset.Init(); err != nil {
err := ipset.Init()
if err != nil {
return nil, fmt.Errorf("init ipset: %w", err)
}
// init clients for booth ipv4 and ipv6
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
m.ipv4Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil {
return nil, fmt.Errorf("iptables is not installed in the system or not supported")
}
if isIptablesClientAvailable(ipv4Client) {
m.ipv4Client = ipv4Client
if ipv6Supported {
m.ipv6Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv6)
if err != nil {
log.Warnf("ip6tables is not installed in the system or not supported: %v. Access rules for this protocol won't be applied.", err)
}
}
ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
if err != nil {
log.Errorf("ip6tables is not installed in the system or not supported: %v", err)
} else {
if isIptablesClientAvailable(ipv6Client) {
m.ipv6Client = ipv6Client
}
if m.ipv4Client == nil && m.ipv6Client == nil {
return nil, fmt.Errorf("iptables is not installed in the system or not enough permissions to use it")
}
if err := m.Reset(); err != nil {
@ -90,11 +91,6 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
return m, nil
}
func isIptablesClientAvailable(client *iptables.IPTables) bool {
_, err := client.ListChains("filter")
return err == nil
}
// AddFiltering rule to the firewall
//
// If comment is empty rule ID is used as comment
@ -276,6 +272,38 @@ func (m *Manager) Reset() error {
return nil
}
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
if m.wgIface.IsUserspaceBind() {
_, err := m.AddFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
nil,
fw.RuleDirectionIN,
fw.ActionAccept,
"",
"allow netbird interface traffic",
)
if err != nil {
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
}
_, err = m.AddFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
nil,
fw.RuleDirectionOUT,
fw.ActionAccept,
"",
"allow netbird interface traffic",
)
return err
}
return nil
}
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
@ -406,7 +434,7 @@ func (m *Manager) client(ip net.IP) (*iptables.IPTables, error) {
return nil, fmt.Errorf("failed to create default drop all in netbird input chain: %w", err)
}
if err := client.AppendUnique("filter", "INPUT", m.inputDefaultRuleSpecs...); err != nil {
if err := client.Insert("filter", "INPUT", 1, m.inputDefaultRuleSpecs...); err != nil {
return nil, fmt.Errorf("failed to create input chain jump rule: %w", err)
}

View File

@ -33,6 +33,8 @@ func (i *iFaceMock) Address() iface.WGAddress {
panic("AddressFunc is not set")
}
func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestIptablesManager(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err)
@ -53,7 +55,7 @@ func TestIptablesManager(t *testing.T) {
}
// just check on the local interface
manager, err := Create(mock)
manager, err := Create(mock, true)
require.NoError(t, err)
time.Sleep(time.Second)
@ -141,7 +143,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
}
// just check on the local interface
manager, err := Create(mock)
manager, err := Create(mock, true)
require.NoError(t, err)
time.Sleep(time.Second)
@ -229,7 +231,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface
manager, err := Create(mock)
manager, err := Create(mock, true)
require.NoError(t, err)
time.Sleep(time.Second)

View File

@ -29,6 +29,8 @@ const (
// FilterOutputChainName is the name of the chain that is used for filtering outgoing packets
FilterOutputChainName = "netbird-acl-output-filter"
AllowNetbirdInputRuleID = "allow Netbird incoming traffic"
)
var anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
@ -379,7 +381,7 @@ func (m *Manager) chain(
if c != nil {
return c, nil
}
return m.createChainIfNotExists(tf, name, hook, priority, cType)
return m.createChainIfNotExists(tf, FilterTableName, name, hook, priority, cType)
}
if ip.To4() != nil {
@ -399,13 +401,20 @@ func (m *Manager) chain(
}
// table returns the table for the given family of the IP address
func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
func (m *Manager) table(
family nftables.TableFamily, tableName string,
) (*nftables.Table, error) {
// we cache access to Netbird ACL table only
if tableName != FilterTableName {
return m.createTableIfNotExists(nftables.TableFamilyIPv4, tableName)
}
if family == nftables.TableFamilyIPv4 {
if m.tableIPv4 != nil {
return m.tableIPv4, nil
}
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4)
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4, tableName)
if err != nil {
return nil, err
}
@ -417,7 +426,7 @@ func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
return m.tableIPv6, nil
}
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6)
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6, tableName)
if err != nil {
return nil, err
}
@ -425,19 +434,21 @@ func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
return m.tableIPv6, nil
}
func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables.Table, error) {
func (m *Manager) createTableIfNotExists(
family nftables.TableFamily, tableName string,
) (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(family)
if err != nil {
return nil, fmt.Errorf("list of tables: %w", err)
}
for _, t := range tables {
if t.Name == FilterTableName {
if t.Name == tableName {
return t, nil
}
}
table := m.rConn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4})
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
if err := m.rConn.Flush(); err != nil {
return nil, err
}
@ -446,12 +457,13 @@ func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables
func (m *Manager) createChainIfNotExists(
family nftables.TableFamily,
tableName string,
name string,
hooknum nftables.ChainHook,
priority nftables.ChainPriority,
chainType nftables.ChainType,
) (*nftables.Chain, error) {
table, err := m.table(family)
table, err := m.table(family, tableName)
if err != nil {
return nil, err
}
@ -638,6 +650,22 @@ func (m *Manager) Reset() error {
return fmt.Errorf("list of chains: %w", err)
}
for _, c := range chains {
// delete Netbird allow input traffic rule if it exists
if c.Table.Name == "filter" && c.Name == "INPUT" {
rules, err := m.rConn.GetRules(c.Table, c)
if err != nil {
log.Errorf("get rules for chain %q: %v", c.Name, err)
continue
}
for _, r := range rules {
if bytes.Equal(r.UserData, []byte(AllowNetbirdInputRuleID)) {
if err := m.rConn.DelRule(r); err != nil {
log.Errorf("delete rule: %v", err)
}
}
}
}
if c.Name == FilterInputChainName || c.Name == FilterOutputChainName {
m.rConn.DelChain(c)
}
@ -702,6 +730,53 @@ func (m *Manager) Flush() error {
return nil
}
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
m.mutex.Lock()
defer m.mutex.Unlock()
tf := nftables.TableFamilyIPv4
if m.wgIface.Address().IP.To4() == nil {
tf = nftables.TableFamilyIPv6
}
chains, err := m.rConn.ListChainsOfTableFamily(tf)
if err != nil {
return fmt.Errorf("list of chains: %w", err)
}
var chain *nftables.Chain
for _, c := range chains {
if c.Table.Name == "filter" && c.Name == "INPUT" {
chain = c
break
}
}
if chain == nil {
log.Debugf("chain INPUT not found. Skiping add allow netbird rule")
return nil
}
rules, err := m.rConn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("failed to get rules for the INPUT chain: %v", err)
}
if rule := m.detectAllowNetbirdRule(rules); rule != nil {
log.Debugf("allow netbird rule already exists: %v", rule)
return nil
}
m.applyAllowNetbirdRules(chain)
err = m.rConn.Flush()
if err != nil {
return fmt.Errorf("failed to flush allow input netbird rules: %v", err)
}
return nil
}
func (m *Manager) flushWithBackoff() (err error) {
backoff := 4
backoffTime := 1000 * time.Millisecond
@ -745,6 +820,44 @@ func (m *Manager) refreshRuleHandles(table *nftables.Table, chain *nftables.Chai
return nil
}
func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
rule := &nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
UserData: []byte(AllowNetbirdInputRuleID),
}
_ = m.rConn.InsertRule(rule)
}
func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
ifName := ifname(m.wgIface.Name())
for _, rule := range existedRules {
if rule.Table.Name == "filter" && rule.Chain.Name == "INPUT" {
if len(rule.Exprs) < 4 {
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
continue
}
if e, ok := rule.Exprs[1].(*expr.Cmp); !ok || e.Op != expr.CmpOpEq || !bytes.Equal(e.Data, ifName) {
continue
}
return rule
}
}
}
return nil
}
func encodePort(port fw.Port) []byte {
bs := make([]byte, 2)
binary.BigEndian.PutUint16(bs, uint16(port.Values[0]))

View File

@ -0,0 +1,19 @@
//go:build !windows && !linux
package uspfilter
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
return nil
}
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
return nil
}

View File

@ -0,0 +1,21 @@
package uspfilter
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
return nil
}
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
if m.resetHook != nil {
return m.resetHook()
}
return nil
}

View File

@ -0,0 +1,91 @@
package uspfilter
import (
"errors"
"fmt"
"os/exec"
"strings"
"syscall"
)
type action string
const (
addRule action = "add"
deleteRule action = "delete"
firewallRuleName = "Netbird"
noRulesMatchCriteria = "No rules match the specified criteria"
)
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
return fmt.Errorf("couldn't remove windows firewall: %w", err)
}
return nil
}
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
return manageFirewallRule(firewallRuleName,
addRule,
"dir=in",
"enable=yes",
"action=allow",
"profile=any",
"localip="+m.wgIface.Address().IP.String(),
)
}
func manageFirewallRule(ruleName string, action action, args ...string) error {
active, err := isFirewallRuleActive(ruleName)
if err != nil {
return err
}
if (action == addRule && !active) || (action == deleteRule && active) {
baseArgs := []string{"advfirewall", "firewall", string(action), "rule", "name=" + ruleName}
args := append(baseArgs, args...)
cmd := exec.Command("netsh", args...)
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
return cmd.Run()
}
return nil
}
func isFirewallRuleActive(ruleName string) (bool, error) {
args := []string{"advfirewall", "firewall", "show", "rule", "name=" + ruleName}
cmd := exec.Command("netsh", args...)
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
output, err := cmd.Output()
if err != nil {
var exitError *exec.ExitError
if errors.As(err, &exitError) {
// if the firewall rule is not active, we expect last exit code to be 1
exitStatus := exitError.Sys().(syscall.WaitStatus).ExitStatus()
if exitStatus == 1 {
if strings.Contains(string(output), noRulesMatchCriteria) {
return false, nil
}
}
}
return false, err
}
if strings.Contains(string(output), noRulesMatchCriteria) {
return false, nil
}
return true, nil
}

View File

@ -19,6 +19,7 @@ const layerTypeAll = 0
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
SetFilter(iface.PacketFilter) error
Address() iface.WGAddress
}
// RuleSet is a set of rules grouped by a string key
@ -30,6 +31,8 @@ type Manager struct {
incomingRules map[string]RuleSet
wgNetwork *net.IPNet
decoders sync.Pool
wgIface IFaceMapper
resetHook func() error
mutex sync.RWMutex
}
@ -65,6 +68,7 @@ func Create(iface IFaceMapper) (*Manager, error) {
},
outgoingRules: make(map[string]RuleSet),
incomingRules: make(map[string]RuleSet),
wgIface: iface,
}
if err := iface.SetFilter(m); err != nil {
@ -171,17 +175,6 @@ func (m *Manager) DeleteRule(rule fw.Rule) error {
return nil
}
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
return nil
}
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
@ -375,3 +368,8 @@ func (m *Manager) RemovePacketHook(hookID string) error {
}
return fmt.Errorf("hook with given id not found")
}
// SetResetHook which will be executed in the end of Reset method
func (m *Manager) SetResetHook(hook func() error) {
m.resetHook = hook
}

View File

@ -16,6 +16,7 @@ import (
type IFaceMock struct {
SetFilterFunc func(iface.PacketFilter) error
AddressFunc func() iface.WGAddress
}
func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
@ -25,6 +26,13 @@ func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
return i.SetFilterFunc(iface)
}
func (i *IFaceMock) Address() iface.WGAddress {
if i.AddressFunc == nil {
return iface.WGAddress{}
}
return i.AddressFunc()
}
func TestManagerCreate(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(iface.PacketFilter) error { return nil },

View File

@ -146,12 +146,11 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
// if this rule is member of rule selection with more than DefaultIPsCountForSet
// it's IP address can be used in the ipset for firewall manager which supports it
ipset := ipsetByRuleSelectors[d.getRuleGroupingSelector(r)]
ipsetName := ""
if ipset.name == "" {
d.ipsetCounter++
ipset.name = fmt.Sprintf("nb%07d", d.ipsetCounter)
}
ipsetName = ipset.name
ipsetName := ipset.name
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
if err != nil {
log.Errorf("failed to apply firewall rule: %+v, %v", r, err)

View File

@ -1,4 +1,4 @@
//go:build !linux
//go:build !linux || android
package acl
@ -6,6 +6,8 @@ import (
"fmt"
"runtime"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
)
@ -17,6 +19,9 @@ func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
if err != nil {
return nil, err
}
if err := fm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
return newDefaultManager(fm), nil
}
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)

View File

@ -1,3 +1,5 @@
//go:build !android
package acl
import (
@ -7,26 +9,68 @@ import (
"github.com/netbirdio/netbird/client/firewall/iptables"
"github.com/netbirdio/netbird/client/firewall/nftables"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/checkfw"
)
// Create creates a firewall manager instance for the Linux
func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
func Create(iface IFaceMapper) (*DefaultManager, error) {
// on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall
var fm firewall.Manager
var err error
checkResult := checkfw.Check()
switch checkResult {
case checkfw.IPTABLES, checkfw.IPTABLESWITHV6:
log.Debug("creating an iptables firewall manager for access control")
ipv6Supported := checkResult == checkfw.IPTABLESWITHV6
if fm, err = iptables.Create(iface, ipv6Supported); err != nil {
log.Infof("failed to create iptables manager for access control: %s", err)
}
case checkfw.NFTABLES:
log.Debug("creating an nftables firewall manager for access control")
if fm, err = nftables.Create(iface); err != nil {
log.Debugf("failed to create nftables manager for access control: %s", err)
}
}
var resetHookForUserspace func() error
if fm != nil && err == nil {
// err shadowing is used here, to ignore this error
if err := fm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
resetHookForUserspace = fm.Reset
}
if iface.IsUserspaceBind() {
// use userspace packet filtering firewall
if fm, err = uspfilter.Create(iface); err != nil {
usfm, err := uspfilter.Create(iface)
if err != nil {
log.Debugf("failed to create userspace filtering firewall: %s", err)
return nil, err
}
} else {
if fm, err = nftables.Create(iface); err != nil {
log.Debugf("failed to create nftables manager: %s", err)
// fallback to iptables
if fm, err = iptables.Create(iface); err != nil {
log.Errorf("failed to create iptables manager: %s", err)
return nil, err
}
// set kernel space firewall Reset as hook for userspace firewall
// manager Reset method, to clean up
if resetHookForUserspace != nil {
usfm.SetResetHook(resetHookForUserspace)
}
// to be consistent for any future extensions.
// ignore this error
if err := usfm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
fm = usfm
}
if fm == nil || err != nil {
log.Errorf("failed to create firewall manager: %s", err)
// no firewall manager found or initialized correctly
return nil, err
}
return newDefaultManager(fm), nil

View File

@ -1,11 +1,13 @@
package acl
import (
"net"
"testing"
"github.com/golang/mock/gomock"
"github.com/netbirdio/netbird/client/internal/acl/mocks"
"github.com/netbirdio/netbird/iface"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
@ -32,13 +34,22 @@ func TestDefaultManager(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
iface := mocks.NewMockIFaceMapper(ctrl)
iface.EXPECT().IsUserspaceBind().Return(true)
// iface.EXPECT().Name().Return("lo")
iface.EXPECT().SetFilter(gomock.Any())
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true)
ifaceMock.EXPECT().SetFilter(gomock.Any())
ip, network, err := net.ParseCIDR("172.0.0.1/32")
if err != nil {
t.Fatalf("failed to parse IP address: %v", err)
}
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
IP: ip,
Network: network,
}).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it
acl, err := Create(iface)
acl, err := Create(ifaceMock)
if err != nil {
t.Errorf("create ACL manager: %v", err)
return
@ -311,13 +322,22 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
iface := mocks.NewMockIFaceMapper(ctrl)
iface.EXPECT().IsUserspaceBind().Return(true)
// iface.EXPECT().Name().Return("lo")
iface.EXPECT().SetFilter(gomock.Any())
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true)
ifaceMock.EXPECT().SetFilter(gomock.Any())
ip, network, err := net.ParseCIDR("172.0.0.1/32")
if err != nil {
t.Fatalf("failed to parse IP address: %v", err)
}
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
IP: ip,
Network: network,
}).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it
acl, err := Create(iface)
acl, err := Create(ifaceMock)
if err != nil {
t.Errorf("create ACL manager: %v", err)
return

View File

@ -4,8 +4,8 @@ import (
"context"
"fmt"
"net/http"
"runtime"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
@ -57,34 +57,50 @@ func (t TokenInfo) GetTokenToUse() string {
return t.AccessToken
}
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration.
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration
//
// It starts by initializing the PKCE.If this process fails, it resorts to the Device Code Flow,
// and if that also fails, the authentication process is deemed unsuccessful
//
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
func NewOAuthFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
log.Debug("getting device authorization flow info")
// Try to initialize the Device Authorization Flow
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err == nil {
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
if runtime.GOOS == "linux" && !isLinuxRunningDesktop() {
return authenticateWithDeviceCodeFlow(ctx, config)
}
log.Debugf("getting device authorization flow info failed with error: %v", err)
log.Debugf("falling back to pkce authorization flow info")
pkceFlow, err := authenticateWithPKCEFlow(ctx, config)
if err != nil {
// fallback to device code flow
return authenticateWithDeviceCodeFlow(ctx, config)
}
return pkceFlow, nil
}
// If Device Authorization Flow failed, try the PKCE Authorization Flow
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
func authenticateWithPKCEFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err != nil {
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
}
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
}
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
func authenticateWithDeviceCodeFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err != nil {
s, ok := gstatus.FromError(err)
if ok && s.Code() == codes.NotFound {
return nil, fmt.Errorf("no SSO provider returned from management. " +
"If you are using hosting Netbird see documentation at " +
"https://github.com/netbirdio/netbird/tree/main/management for details")
"Please proceed with setting up this device using setup keys " +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
} else if ok && s.Code() == codes.Unimplemented {
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
"please update your server or use Setup Keys to login", config.ManagementURL)
} else {
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
return nil, fmt.Errorf("getting device authorization flow info failed with error: %v", err)
}
}
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
}

View File

@ -5,6 +5,7 @@ import (
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"errors"
"fmt"
"html/template"
"net"
@ -78,7 +79,7 @@ func (p *PKCEAuthorizationFlow) GetClientID(_ context.Context) string {
}
// RequestAuthInfo requests a authorization code login flow information.
func (p *PKCEAuthorizationFlow) RequestAuthInfo(_ context.Context) (AuthFlowInfo, error) {
func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
state, err := randomBytesInHex(24)
if err != nil {
return AuthFlowInfo{}, fmt.Errorf("could not generate random state: %v", err)
@ -112,60 +113,37 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
tokenChan := make(chan *oauth2.Token, 1)
errChan := make(chan error, 1)
go p.startServer(tokenChan, errChan)
parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL)
if err != nil {
return TokenInfo{}, fmt.Errorf("failed to parse redirect URL: %v", err)
}
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
defer func() {
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
if err := server.Shutdown(shutdownCtx); err != nil {
log.Errorf("failed to close the server: %v", err)
}
}()
go p.startServer(server, tokenChan, errChan)
select {
case <-ctx.Done():
return TokenInfo{}, ctx.Err()
case token := <-tokenChan:
return p.handleOAuthToken(token)
return p.parseOAuthToken(token)
case err := <-errChan:
return TokenInfo{}, err
}
}
func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errChan chan<- error) {
parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL)
if err != nil {
errChan <- fmt.Errorf("failed to parse redirect URL: %v", err)
return
}
port := parsedURL.Port()
server := http.Server{Addr: fmt.Sprintf(":%s", port)}
defer func() {
if err := server.Shutdown(context.Background()); err != nil {
log.Errorf("error while shutting down pkce flow server: %v", err)
}
}()
http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
tokenValidatorFunc := func() (*oauth2.Token, error) {
query := req.URL.Query()
if authError := query.Get(queryError); authError != "" {
authErrorDesc := query.Get(queryErrorDesc)
return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
}
// Prevent timing attacks on state
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
return nil, fmt.Errorf("invalid state")
}
code := query.Get(queryCode)
if code == "" {
return nil, fmt.Errorf("missing code")
}
return p.oAuthConfig.Exchange(
req.Context(),
code,
oauth2.SetAuthURLParam("code_verifier", p.codeVerifier),
)
}
token, err := tokenValidatorFunc()
func (p *PKCEAuthorizationFlow) startServer(server *http.Server, tokenChan chan<- *oauth2.Token, errChan chan<- error) {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
token, err := p.handleRequest(req)
if err != nil {
renderPKCEFlowTmpl(w, err)
errChan <- fmt.Errorf("PKCE authorization flow failed: %v", err)
@ -176,12 +154,38 @@ func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errC
tokenChan <- token
})
if err := server.ListenAndServe(); err != nil {
server.Handler = mux
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
errChan <- err
}
}
func (p *PKCEAuthorizationFlow) handleOAuthToken(token *oauth2.Token) (TokenInfo, error) {
func (p *PKCEAuthorizationFlow) handleRequest(req *http.Request) (*oauth2.Token, error) {
query := req.URL.Query()
if authError := query.Get(queryError); authError != "" {
authErrorDesc := query.Get(queryErrorDesc)
return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
}
// Prevent timing attacks on the state
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
return nil, fmt.Errorf("invalid state")
}
code := query.Get(queryCode)
if code == "" {
return nil, fmt.Errorf("missing code")
}
return p.oAuthConfig.Exchange(
req.Context(),
code,
oauth2.SetAuthURLParam("code_verifier", p.codeVerifier),
)
}
func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo, error) {
tokenInfo := TokenInfo{
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,

View File

@ -7,6 +7,7 @@ import (
"encoding/json"
"fmt"
"io"
"os"
"reflect"
"strings"
)
@ -60,3 +61,8 @@ func isValidAccessToken(token string, audience string) error {
return fmt.Errorf("invalid JWT token audience field")
}
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment
func isLinuxRunningDesktop() bool {
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
}

View File

@ -0,0 +1,3 @@
//go:build !linux || android
package checkfw

View File

@ -0,0 +1,56 @@
//go:build !android
package checkfw
import (
"os"
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
)
const (
// UNKNOWN is the default value for the firewall type for unknown firewall type
UNKNOWN FWType = iota
// IPTABLES is the value for the iptables firewall type
IPTABLES
// IPTABLESWITHV6 is the value for the iptables firewall type with ipv6
IPTABLESWITHV6
// NFTABLES is the value for the nftables firewall type
NFTABLES
)
// SKIP_NFTABLES_ENV is the environment variable to skip nftables check
const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type
type FWType int
// Check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
func Check() FWType {
nf := nftables.Conn{}
if _, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
return NFTABLES
}
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err == nil {
if isIptablesClientAvailable(ip) {
ipSupport := IPTABLES
ipv6, ip6Err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
if ip6Err == nil {
if isIptablesClientAvailable(ipv6) {
ipSupport = IPTABLESWITHV6
}
}
return ipSupport
}
}
return UNKNOWN
}
func isIptablesClientAvailable(client *iptables.IPTables) bool {
_, err := client.ListChains("filter")
return err == nil
}

View File

@ -23,9 +23,6 @@ func TestGetConfig(t *testing.T) {
assert.Equal(t, config.ManagementURL.String(), DefaultManagementURL)
assert.Equal(t, config.AdminURL.String(), DefaultAdminURL)
if err != nil {
return
}
managementURL := "https://test.management.url:33071"
adminURL := "https://app.admin.url:443"
path := filepath.Join(t.TempDir(), "config.json")

View File

@ -192,8 +192,6 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
log.Print("Netbird engine started, my IP is: ", peerConfig.Address)
state.Set(StatusConnected)
statusRecorder.ClientStart()
<-engineCtx.Done()
statusRecorder.ClientTeardown()
@ -214,6 +212,7 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
return nil
}
statusRecorder.ClientStart()
err = backoff.Retry(operation, backOff)
if err != nil {
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)

View File

@ -238,7 +238,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
hostUpdate := s.currentConfig
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
"Learn more at: https://netbird.io/docs/how-to-guides/nameservers#local-resolver")
"Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver")
hostUpdate.routeAll = false
}

View File

@ -777,7 +777,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
newNet, err := stdnet.NewNet(nil)
if err != nil {
t.Fatalf("create stdnet: %v", err)
return nil, nil
return nil, err
}
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", iface.DefaultMTU, nil, newNet)

View File

@ -11,6 +11,9 @@ import (
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
)
const (
@ -24,10 +27,11 @@ type serviceViaListener struct {
dnsMux *dns.ServeMux
customAddr *netip.AddrPort
server *dns.Server
runtimeIP string
runtimePort int
listenIP string
listenPort int
listenerIsRunning bool
listenerFlagLock sync.Mutex
ebpfService ebpfMgr.Manager
}
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener {
@ -43,6 +47,7 @@ func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *service
UDPSize: 65535,
},
}
return s
}
@ -55,13 +60,21 @@ func (s *serviceViaListener) Listen() error {
}
var err error
s.runtimeIP, s.runtimePort, err = s.evalRuntimeAddress()
s.listenIP, s.listenPort, err = s.evalListenAddress()
if err != nil {
log.Errorf("failed to eval runtime address: %s", err)
return err
}
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
s.server.Addr = fmt.Sprintf("%s:%d", s.listenIP, s.listenPort)
if s.shouldApplyPortFwd() {
s.ebpfService = ebpf.GetEbpfManagerInstance()
err = s.ebpfService.LoadDNSFwd(s.listenIP, s.listenPort)
if err != nil {
log.Warnf("failed to load DNS port forwarder, custom port may not work well on some Linux operating systems: %s", err)
s.ebpfService = nil
}
}
log.Debugf("starting dns on %s", s.server.Addr)
go func() {
s.setListenerStatus(true)
@ -69,9 +82,10 @@ func (s *serviceViaListener) Listen() error {
err := s.server.ListenAndServe()
if err != nil {
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.listenPort, err)
}
}()
return nil
}
@ -90,6 +104,13 @@ func (s *serviceViaListener) Stop() {
if err != nil {
log.Errorf("stopping dns server listener returned an error: %v", err)
}
if s.ebpfService != nil {
err = s.ebpfService.FreeDNSFwd()
if err != nil {
log.Errorf("stopping traffic forwarder returned an error: %v", err)
}
}
}
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
@ -101,11 +122,18 @@ func (s *serviceViaListener) DeregisterMux(pattern string) {
}
func (s *serviceViaListener) RuntimePort() int {
return s.runtimePort
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if s.ebpfService != nil {
return defaultPort
} else {
return s.listenPort
}
}
func (s *serviceViaListener) RuntimeIP() string {
return s.runtimeIP
return s.listenIP
}
func (s *serviceViaListener) setListenerStatus(running bool) {
@ -136,10 +164,30 @@ func (s *serviceViaListener) getFirstListenerAvailable() (string, int, error) {
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
}
func (s *serviceViaListener) evalRuntimeAddress() (string, int, error) {
func (s *serviceViaListener) evalListenAddress() (string, int, error) {
if s.customAddr != nil {
return s.customAddr.Addr().String(), int(s.customAddr.Port()), nil
}
return s.getFirstListenerAvailable()
}
// shouldApplyPortFwd decides whether to apply eBPF program to capture DNS traffic on port 53.
// This is needed because on some operating systems if we start a DNS server not on a default port 53, the domain name
// resolution won't work.
// So, in case we are running on Linux and picked a non-default port (53) we should fall back to the eBPF solution that will capture
// traffic on port 53 and forward it to a local DNS server running on 5053.
func (s *serviceViaListener) shouldApplyPortFwd() bool {
if runtime.GOOS != "linux" {
return false
}
if s.customAddr != nil {
return false
}
if s.listenPort == defaultPort {
return false
}
return true
}

View File

@ -54,13 +54,16 @@ type bpfSpecs struct {
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfProgramSpecs struct {
NbWgProxy *ebpf.ProgramSpec `ebpf:"nb_wg_proxy"`
NbXdpProg *ebpf.ProgramSpec `ebpf:"nb_xdp_prog"`
}
// bpfMapSpecs contains maps before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfMapSpecs struct {
NbFeatures *ebpf.MapSpec `ebpf:"nb_features"`
NbMapDnsIp *ebpf.MapSpec `ebpf:"nb_map_dns_ip"`
NbMapDnsPort *ebpf.MapSpec `ebpf:"nb_map_dns_port"`
NbWgProxySettingsMap *ebpf.MapSpec `ebpf:"nb_wg_proxy_settings_map"`
}
@ -83,11 +86,17 @@ func (o *bpfObjects) Close() error {
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfMaps struct {
NbFeatures *ebpf.Map `ebpf:"nb_features"`
NbMapDnsIp *ebpf.Map `ebpf:"nb_map_dns_ip"`
NbMapDnsPort *ebpf.Map `ebpf:"nb_map_dns_port"`
NbWgProxySettingsMap *ebpf.Map `ebpf:"nb_wg_proxy_settings_map"`
}
func (m *bpfMaps) Close() error {
return _BpfClose(
m.NbFeatures,
m.NbMapDnsIp,
m.NbMapDnsPort,
m.NbWgProxySettingsMap,
)
}
@ -96,12 +105,12 @@ func (m *bpfMaps) Close() error {
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfPrograms struct {
NbWgProxy *ebpf.Program `ebpf:"nb_wg_proxy"`
NbXdpProg *ebpf.Program `ebpf:"nb_xdp_prog"`
}
func (p *bpfPrograms) Close() error {
return _BpfClose(
p.NbWgProxy,
p.NbXdpProg,
)
}

Binary file not shown.

View File

@ -54,13 +54,16 @@ type bpfSpecs struct {
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfProgramSpecs struct {
NbWgProxy *ebpf.ProgramSpec `ebpf:"nb_wg_proxy"`
NbXdpProg *ebpf.ProgramSpec `ebpf:"nb_xdp_prog"`
}
// bpfMapSpecs contains maps before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfMapSpecs struct {
NbFeatures *ebpf.MapSpec `ebpf:"nb_features"`
NbMapDnsIp *ebpf.MapSpec `ebpf:"nb_map_dns_ip"`
NbMapDnsPort *ebpf.MapSpec `ebpf:"nb_map_dns_port"`
NbWgProxySettingsMap *ebpf.MapSpec `ebpf:"nb_wg_proxy_settings_map"`
}
@ -83,11 +86,17 @@ func (o *bpfObjects) Close() error {
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfMaps struct {
NbFeatures *ebpf.Map `ebpf:"nb_features"`
NbMapDnsIp *ebpf.Map `ebpf:"nb_map_dns_ip"`
NbMapDnsPort *ebpf.Map `ebpf:"nb_map_dns_port"`
NbWgProxySettingsMap *ebpf.Map `ebpf:"nb_wg_proxy_settings_map"`
}
func (m *bpfMaps) Close() error {
return _BpfClose(
m.NbFeatures,
m.NbMapDnsIp,
m.NbMapDnsPort,
m.NbWgProxySettingsMap,
)
}
@ -96,12 +105,12 @@ func (m *bpfMaps) Close() error {
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfPrograms struct {
NbWgProxy *ebpf.Program `ebpf:"nb_wg_proxy"`
NbXdpProg *ebpf.Program `ebpf:"nb_xdp_prog"`
}
func (p *bpfPrograms) Close() error {
return _BpfClose(
p.NbWgProxy,
p.NbXdpProg,
)
}

Binary file not shown.

View File

@ -0,0 +1,51 @@
package ebpf
import (
"encoding/binary"
"net"
log "github.com/sirupsen/logrus"
)
const (
mapKeyDNSIP uint32 = 0
mapKeyDNSPort uint32 = 1
)
func (tf *GeneralManager) LoadDNSFwd(ip string, dnsPort int) error {
log.Debugf("load ebpf DNS forwarder: address: %s:%d", ip, dnsPort)
tf.lock.Lock()
defer tf.lock.Unlock()
err := tf.loadXdp()
if err != nil {
return err
}
err = tf.bpfObjs.NbMapDnsIp.Put(mapKeyDNSIP, ip2int(ip))
if err != nil {
return err
}
err = tf.bpfObjs.NbMapDnsPort.Put(mapKeyDNSPort, uint16(dnsPort))
if err != nil {
return err
}
tf.setFeatureFlag(featureFlagDnsForwarder)
err = tf.bpfObjs.NbFeatures.Put(mapKeyFeatures, tf.featureFlags)
if err != nil {
return err
}
return nil
}
func (tf *GeneralManager) FreeDNSFwd() error {
log.Debugf("free ebpf DNS forwarder")
return tf.unsetFeatureFlag(featureFlagDnsForwarder)
}
func ip2int(ipString string) uint32 {
ip := net.ParseIP(ipString)
return binary.BigEndian.Uint32(ip.To4())
}

View File

@ -0,0 +1,116 @@
package ebpf
import (
_ "embed"
"net"
"sync"
"github.com/cilium/ebpf/link"
"github.com/cilium/ebpf/rlimit"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/ebpf/manager"
)
const (
mapKeyFeatures uint32 = 0
featureFlagWGProxy = 0b00000001
featureFlagDnsForwarder = 0b00000010
)
var (
singleton manager.Manager
singletonLock = &sync.Mutex{}
)
// required packages libbpf-dev, libc6-dev-i386-amd64-cross
// GeneralManager is used to load multiple eBPF programs with a custom check (if then) done in prog.c
// The manager simply adds a feature (byte) of each program to a map that is shared between the userspace and kernel.
// When packet arrives, the C code checks for each feature (if it is set) and executes each enabled program (e.g., dns_fwd.c and wg_proxy.c).
//
//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc clang-14 bpf src/prog.c -- -I /usr/x86_64-linux-gnu/include
type GeneralManager struct {
lock sync.Mutex
link link.Link
featureFlags uint16
bpfObjs bpfObjects
}
// GetEbpfManagerInstance return a static eBpf Manager instance
func GetEbpfManagerInstance() manager.Manager {
singletonLock.Lock()
defer singletonLock.Unlock()
if singleton != nil {
return singleton
}
singleton = &GeneralManager{}
return singleton
}
func (tf *GeneralManager) setFeatureFlag(feature uint16) {
tf.featureFlags = tf.featureFlags | feature
}
func (tf *GeneralManager) loadXdp() error {
if tf.link != nil {
return nil
}
// it required for Docker
err := rlimit.RemoveMemlock()
if err != nil {
return err
}
iFace, err := net.InterfaceByName("lo")
if err != nil {
return err
}
// load pre-compiled programs into the kernel.
err = loadBpfObjects(&tf.bpfObjs, nil)
if err != nil {
return err
}
tf.link, err = link.AttachXDP(link.XDPOptions{
Program: tf.bpfObjs.NbXdpProg,
Interface: iFace.Index,
})
if err != nil {
_ = tf.bpfObjs.Close()
tf.link = nil
return err
}
return nil
}
func (tf *GeneralManager) unsetFeatureFlag(feature uint16) error {
tf.lock.Lock()
defer tf.lock.Unlock()
tf.featureFlags &^= feature
if tf.link == nil {
return nil
}
if tf.featureFlags == 0 {
return tf.close()
}
return tf.bpfObjs.NbFeatures.Put(mapKeyFeatures, tf.featureFlags)
}
func (tf *GeneralManager) close() error {
log.Debugf("detach ebpf program ")
err := tf.bpfObjs.Close()
if err != nil {
log.Warnf("failed to close eBpf objects: %s", err)
}
err = tf.link.Close()
tf.link = nil
return err
}

View File

@ -0,0 +1,40 @@
package ebpf
import (
"testing"
)
func TestManager_setFeatureFlag(t *testing.T) {
mgr := GeneralManager{}
mgr.setFeatureFlag(featureFlagWGProxy)
if mgr.featureFlags != 1 {
t.Errorf("invalid faeture state")
}
mgr.setFeatureFlag(featureFlagDnsForwarder)
if mgr.featureFlags != 3 {
t.Errorf("invalid faeture state")
}
}
func TestManager_unsetFeatureFlag(t *testing.T) {
mgr := GeneralManager{}
mgr.setFeatureFlag(featureFlagWGProxy)
mgr.setFeatureFlag(featureFlagDnsForwarder)
err := mgr.unsetFeatureFlag(featureFlagWGProxy)
if err != nil {
t.Errorf("unexpected error: %s", err)
}
if mgr.featureFlags != 2 {
t.Errorf("invalid faeture state, expected: %d, got: %d", 2, mgr.featureFlags)
}
err = mgr.unsetFeatureFlag(featureFlagDnsForwarder)
if err != nil {
t.Errorf("unexpected error: %s", err)
}
if mgr.featureFlags != 0 {
t.Errorf("invalid faeture state, expected: %d, got: %d", 0, mgr.featureFlags)
}
}

View File

@ -0,0 +1,64 @@
const __u32 map_key_dns_ip = 0;
const __u32 map_key_dns_port = 1;
struct bpf_map_def SEC("maps") nb_map_dns_ip = {
.type = BPF_MAP_TYPE_ARRAY,
.key_size = sizeof(__u32),
.value_size = sizeof(__u32),
.max_entries = 10,
};
struct bpf_map_def SEC("maps") nb_map_dns_port = {
.type = BPF_MAP_TYPE_ARRAY,
.key_size = sizeof(__u32),
.value_size = sizeof(__u16),
.max_entries = 10,
};
__be32 dns_ip = 0;
__be16 dns_port = 0;
// 13568 is 53 in big endian
__be16 GENERAL_DNS_PORT = 13568;
bool read_settings() {
__u16 *port_value;
__u32 *ip_value;
// read dns ip
ip_value = bpf_map_lookup_elem(&nb_map_dns_ip, &map_key_dns_ip);
if(!ip_value) {
return false;
}
dns_ip = htonl(*ip_value);
// read dns port
port_value = bpf_map_lookup_elem(&nb_map_dns_port, &map_key_dns_port);
if (!port_value) {
return false;
}
dns_port = htons(*port_value);
return true;
}
int xdp_dns_fwd(struct iphdr *ip, struct udphdr *udp) {
if (dns_port == 0) {
if(!read_settings()){
return XDP_PASS;
}
bpf_printk("dns port: %d", ntohs(dns_port));
bpf_printk("dns ip: %d", ntohl(dns_ip));
}
if (udp->dest == GENERAL_DNS_PORT && ip->daddr == dns_ip) {
udp->dest = dns_port;
return XDP_PASS;
}
if (udp->source == dns_port && ip->saddr == dns_ip) {
udp->source = GENERAL_DNS_PORT;
return XDP_PASS;
}
return XDP_PASS;
}

View File

@ -0,0 +1,66 @@
#include <stdbool.h>
#include <linux/if_ether.h> // ETH_P_IP
#include <linux/udp.h>
#include <linux/ip.h>
#include <netinet/in.h>
#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>
#include "dns_fwd.c"
#include "wg_proxy.c"
#define bpf_printk(fmt, ...) \
({ \
char ____fmt[] = fmt; \
bpf_trace_printk(____fmt, sizeof(____fmt), ##__VA_ARGS__); \
})
const __u16 flag_feature_wg_proxy = 0b01;
const __u16 flag_feature_dns_fwd = 0b10;
const __u32 map_key_features = 0;
struct bpf_map_def SEC("maps") nb_features = {
.type = BPF_MAP_TYPE_ARRAY,
.key_size = sizeof(__u32),
.value_size = sizeof(__u16),
.max_entries = 10,
};
SEC("xdp")
int nb_xdp_prog(struct xdp_md *ctx) {
__u16 *features;
features = bpf_map_lookup_elem(&nb_features, &map_key_features);
if (!features) {
return XDP_PASS;
}
void *data = (void *)(long)ctx->data;
void *data_end = (void *)(long)ctx->data_end;
struct ethhdr *eth = data;
struct iphdr *ip = (data + sizeof(struct ethhdr));
struct udphdr *udp = (data + sizeof(struct ethhdr) + sizeof(struct iphdr));
// return early if not enough data
if (data + sizeof(struct ethhdr) + sizeof(struct iphdr) + sizeof(struct udphdr) > data_end){
return XDP_PASS;
}
// skip non IPv4 packages
if (eth->h_proto != htons(ETH_P_IP)) {
return XDP_PASS;
}
// skip non UPD packages
if (ip->protocol != IPPROTO_UDP) {
return XDP_PASS;
}
if (*features & flag_feature_dns_fwd) {
xdp_dns_fwd(ip, udp);
}
if (*features & flag_feature_wg_proxy) {
xdp_wg_proxy(ip, udp);
}
return XDP_PASS;
}
char _license[] SEC("license") = "GPL";

View File

@ -0,0 +1,54 @@
const __u32 map_key_proxy_port = 0;
const __u32 map_key_wg_port = 1;
struct bpf_map_def SEC("maps") nb_wg_proxy_settings_map = {
.type = BPF_MAP_TYPE_ARRAY,
.key_size = sizeof(__u32),
.value_size = sizeof(__u16),
.max_entries = 10,
};
__u16 proxy_port = 0;
__u16 wg_port = 0;
bool read_port_settings() {
__u16 *value;
value = bpf_map_lookup_elem(&nb_wg_proxy_settings_map, &map_key_proxy_port);
if (!value) {
return false;
}
proxy_port = *value;
value = bpf_map_lookup_elem(&nb_wg_proxy_settings_map, &map_key_wg_port);
if (!value) {
return false;
}
wg_port = htons(*value);
return true;
}
int xdp_wg_proxy(struct iphdr *ip, struct udphdr *udp) {
if (proxy_port == 0 || wg_port == 0) {
if (!read_port_settings()){
return XDP_PASS;
}
bpf_printk("proxy port: %d, wg port: %d", proxy_port, wg_port);
}
// 2130706433 = 127.0.0.1
if (ip->daddr != htonl(2130706433)) {
return XDP_PASS;
}
if (udp->source != wg_port){
return XDP_PASS;
}
__be16 new_src_port = udp->dest;
__be16 new_dst_port = htons(proxy_port);
udp->dest = new_dst_port;
udp->source = new_src_port;
return XDP_PASS;
}

View File

@ -0,0 +1,41 @@
package ebpf
import log "github.com/sirupsen/logrus"
const (
mapKeyProxyPort uint32 = 0
mapKeyWgPort uint32 = 1
)
func (tf *GeneralManager) LoadWgProxy(proxyPort, wgPort int) error {
log.Debugf("load ebpf WG proxy")
tf.lock.Lock()
defer tf.lock.Unlock()
err := tf.loadXdp()
if err != nil {
return err
}
err = tf.bpfObjs.NbWgProxySettingsMap.Put(mapKeyProxyPort, uint16(proxyPort))
if err != nil {
return err
}
err = tf.bpfObjs.NbWgProxySettingsMap.Put(mapKeyWgPort, uint16(wgPort))
if err != nil {
return err
}
tf.setFeatureFlag(featureFlagWGProxy)
err = tf.bpfObjs.NbFeatures.Put(mapKeyFeatures, tf.featureFlags)
if err != nil {
return err
}
return nil
}
func (tf *GeneralManager) FreeWGProxy() error {
log.Debugf("free ebpf WG proxy")
return tf.unsetFeatureFlag(featureFlagWGProxy)
}

View File

@ -0,0 +1,15 @@
//go:build !android
package ebpf
import (
"github.com/netbirdio/netbird/client/internal/ebpf/ebpf"
"github.com/netbirdio/netbird/client/internal/ebpf/manager"
)
// GetEbpfManagerInstance is a wrapper function. This encapsulation is required because if the code import the internal
// ebpf package the Go compiler will include the object files. But it is not supported on Android. It can cause instant
// panic on older Android version.
func GetEbpfManagerInstance() manager.Manager {
return ebpf.GetEbpfManagerInstance()
}

View File

@ -0,0 +1,10 @@
//go:build !linux || android
package ebpf
import "github.com/netbirdio/netbird/client/internal/ebpf/manager"
// GetEbpfManagerInstance return error because ebpf is not supported on all os
func GetEbpfManagerInstance() manager.Manager {
panic("unsupported os")
}

View File

@ -0,0 +1,9 @@
package manager
// Manager is used to load multiple eBPF programs. E.g., current DNS programs and WireGuard proxy
type Manager interface {
LoadDNSFwd(ip string, dnsPort int) error
FreeDNSFwd() error
LoadWgProxy(proxyPort, wgPort int) error
FreeWGProxy() error
}

View File

@ -995,14 +995,12 @@ func (e *Engine) parseNATExternalIPMappings() []string {
log.Warnf("invalid external IP, %s, ignoring external IP mapping '%s'", external, mapping)
break
}
if externalIP != nil {
mappedIP := externalIP.String()
if internalIP != nil {
mappedIP = mappedIP + "/" + internalIP.String()
}
mappedIPs = append(mappedIPs, mappedIP)
log.Infof("parsed external IP mapping of '%s' as '%s'", mapping, mappedIP)
mappedIP := externalIP.String()
if internalIP != nil {
mappedIP = mappedIP + "/" + internalIP.String()
}
mappedIPs = append(mappedIPs, mappedIP)
log.Infof("parsed external IP mapping of '%s' as '%s'", mapping, mappedIP)
}
if len(mappedIPs) != len(e.config.NATExternalIPs) {
log.Warnf("one or more external IP mappings failed to parse, ignoring all mappings")

View File

@ -1046,15 +1046,15 @@ func startManagement(dataDir string) (*grpc.Server, string, error) {
peersUpdateManager := server.NewPeersUpdateManager()
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, "", nil
return nil, "", err
}
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "",
eventStore)
eventStore, false)
if err != nil {
return nil, "", err
}
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil)
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil {
return nil, "", err
}

View File

@ -61,7 +61,7 @@ func (n *notifier) clientStart() {
n.serverStateLock.Lock()
defer n.serverStateLock.Unlock()
n.currentClientState = true
n.lastNotification = stateConnected
n.lastNotification = stateConnecting
n.notify(n.lastNotification)
}
@ -114,7 +114,7 @@ func (n *notifier) calculateState(managementConn, signalConn bool) int {
return stateConnected
}
if !managementConn && !signalConn {
if !managementConn && !signalConn && !n.currentClientState {
return stateDisconnected
}

View File

@ -155,7 +155,10 @@ func (c *clientNetwork) startPeersStatusChangeWatcher() {
func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
state, err := c.statusRecorder.GetPeer(peerKey)
if err != nil || state.ConnStatus != peer.StatusConnected {
if err != nil {
return err
}
if state.ConnStatus != peer.StatusConnected {
return nil
}

View File

@ -7,6 +7,8 @@ import (
"fmt"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/checkfw"
)
const (
@ -26,15 +28,20 @@ func genKey(format string, input string) string {
return fmt.Sprintf(format, input)
}
// NewFirewall if supported, returns an iptables manager, otherwise returns a nftables manager
func NewFirewall(parentCTX context.Context) firewallManager {
manager, err := newNFTablesManager(parentCTX)
if err == nil {
log.Debugf("nftables firewall manager will be used")
return manager
// newFirewall if supported, returns an iptables manager, otherwise returns a nftables manager
func newFirewall(parentCTX context.Context) (firewallManager, error) {
checkResult := checkfw.Check()
switch checkResult {
case checkfw.IPTABLES, checkfw.IPTABLESWITHV6:
log.Debug("creating an iptables firewall manager for route rules")
ipv6Supported := checkResult == checkfw.IPTABLESWITHV6
return newIptablesManager(parentCTX, ipv6Supported)
case checkfw.NFTABLES:
log.Info("creating an nftables firewall manager for route rules")
return newNFTablesManager(parentCTX), nil
}
log.Debugf("fallback to iptables firewall manager: %s", err)
return newIptablesManager(parentCTX)
return nil, fmt.Errorf("couldn't initialize nftables or iptables clients. Using a dummy firewall manager for route rules")
}
func getInPair(pair routerPair) routerPair {

View File

@ -3,24 +3,13 @@
package routemanager
import "context"
import (
"context"
"fmt"
"runtime"
)
type unimplementedFirewall struct{}
func (unimplementedFirewall) RestoreOrCreateContainers() error {
return nil
}
func (unimplementedFirewall) InsertRoutingRules(pair routerPair) error {
return nil
}
func (unimplementedFirewall) RemoveRoutingRules(pair routerPair) error {
return nil
}
func (unimplementedFirewall) CleanRoutingRules() {
}
// NewFirewall returns an unimplemented Firewall manager
func NewFirewall(parentCtx context.Context) firewallManager {
return unimplementedFirewall{}
// newFirewall returns a nil manager
func newFirewall(context.Context) (firewallManager, error) {
return nil, fmt.Errorf("firewall not supported on %s", runtime.GOOS)
}

View File

@ -49,30 +49,28 @@ type iptablesManager struct {
mux sync.Mutex
}
func newIptablesManager(parentCtx context.Context) *iptablesManager {
ctx, cancel := context.WithCancel(parentCtx)
func newIptablesManager(parentCtx context.Context, ipv6Supported bool) (*iptablesManager, error) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil {
log.Debugf("failed to initialize iptables for ipv4: %s", err)
} else if !isIptablesClientAvailable(ipv4Client) {
log.Infof("iptables is missing for ipv4")
ipv4Client = nil
}
ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
if err != nil {
log.Debugf("failed to initialize iptables for ipv6: %s", err)
} else if !isIptablesClientAvailable(ipv6Client) {
log.Infof("iptables is missing for ipv6")
ipv6Client = nil
return nil, fmt.Errorf("failed to initialize iptables for ipv4: %s", err)
}
return &iptablesManager{
ctx, cancel := context.WithCancel(parentCtx)
manager := &iptablesManager{
ctx: ctx,
stop: cancel,
ipv4Client: ipv4Client,
ipv6Client: ipv6Client,
rules: make(map[string]map[string][]string),
}
if ipv6Supported {
manager.ipv6Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv6)
if err != nil {
log.Warnf("failed to initialize iptables for ipv6: %s. Routes for this protocol won't be applied.", err)
}
}
return manager, nil
}
// CleanRoutingRules cleans existing iptables resources that we created by the agent
@ -395,6 +393,10 @@ func (i *iptablesManager) insertRoutingRule(keyFormat, table, chain, jump string
ipVersion = ipv6
}
if iptablesClient == nil {
return fmt.Errorf("unable to insert iptables routing rules. Iptables client is not initialized")
}
ruleKey := genKey(keyFormat, pair.ID)
rule := genRuleSpec(jump, ruleKey, pair.source, pair.destination)
existingRule, found := i.rules[ipVersion][ruleKey]
@ -459,6 +461,10 @@ func (i *iptablesManager) removeRoutingRule(keyFormat, table, chain string, pair
ipVersion = ipv6
}
if iptablesClient == nil {
return fmt.Errorf("unable to remove iptables routing rules. Iptables client is not initialized")
}
ruleKey := genKey(keyFormat, pair.ID)
existingRule, found := i.rules[ipVersion][ruleKey]
if found {
@ -479,8 +485,3 @@ func getIptablesRuleType(table string) string {
}
return ruleType
}
func isIptablesClientAvailable(client *iptables.IPTables) bool {
_, err := client.ListChains("filter")
return err == nil
}

View File

@ -16,11 +16,12 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
t.SkipNow()
}
manager := newIptablesManager(context.TODO())
manager, err := newIptablesManager(context.TODO(), true)
require.NoError(t, err, "should return a valid iptables manager")
defer manager.CleanRoutingRules()
err := manager.RestoreOrCreateContainers()
err = manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error")
require.Len(t, manager.rules, 2, "should have created maps for ipv4 and ipv6")

View File

@ -27,7 +27,7 @@ type DefaultManager struct {
stop context.CancelFunc
mux sync.Mutex
clientNetworks map[string]*clientNetwork
serverRouter *serverRouter
serverRouter serverRouter
statusRecorder *peer.Status
wgInterface *iface.WGIface
pubKey string
@ -36,13 +36,17 @@ type DefaultManager struct {
// NewManager returns a new route manager
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status, initialRoutes []*route.Route) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx)
srvRouter, err := newServerRouter(ctx, wgInterface)
if err != nil {
log.Errorf("server router is not supported: %s", err)
}
mCTX, cancel := context.WithCancel(ctx)
dm := &DefaultManager{
ctx: mCTX,
stop: cancel,
clientNetworks: make(map[string]*clientNetwork),
serverRouter: newServerRouter(ctx, wgInterface),
serverRouter: srvRouter,
statusRecorder: statusRecorder,
wgInterface: wgInterface,
pubKey: pubKey,
@ -59,7 +63,9 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
// Stop stops the manager watchers and clean firewall rules
func (m *DefaultManager) Stop() {
m.stop()
m.serverRouter.cleanUp()
if m.serverRouter != nil {
m.serverRouter.cleanUp()
}
m.ctx = nil
}
@ -77,9 +83,12 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
m.notifier.onNewRoutes(newClientRoutesIDMap)
err := m.serverRouter.updateRoutes(newServerRoutesMap)
if err != nil {
return err
if m.serverRouter != nil {
err := m.serverRouter.updateRoutes(newServerRoutesMap)
if err != nil {
return err
}
}
return nil

View File

@ -3,11 +3,12 @@ package routemanager
import (
"context"
"fmt"
"github.com/pion/transport/v2/stdnet"
"net/netip"
"runtime"
"testing"
"github.com/pion/transport/v2/stdnet"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/peer"
@ -30,7 +31,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
inputInitRoutes []*route.Route
inputRoutes []*route.Route
inputSerial uint64
shouldCheckServerRoutes bool
removeSrvRouter bool
serverRoutesExpected int
clientNetworkWatchersExpected int
}{
@ -87,7 +88,6 @@ func TestManagerUpdateRoutes(t *testing.T) {
},
},
inputSerial: 1,
shouldCheckServerRoutes: runtime.GOOS == "linux",
serverRoutesExpected: 2,
clientNetworkWatchersExpected: 0,
},
@ -116,10 +116,38 @@ func TestManagerUpdateRoutes(t *testing.T) {
},
},
inputSerial: 1,
shouldCheckServerRoutes: runtime.GOOS == "linux",
serverRoutesExpected: 1,
clientNetworkWatchersExpected: 1,
},
{
name: "Should Create 1 Route For Client and Skip Server Route On Empty Server Router",
inputRoutes: []*route.Route{
{
ID: "a",
NetID: "routeA",
Peer: localPeerKey,
Network: netip.MustParsePrefix("100.64.30.250/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
{
ID: "b",
NetID: "routeB",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("8.8.9.9/32"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
},
inputSerial: 1,
removeSrvRouter: true,
serverRoutesExpected: 0,
clientNetworkWatchersExpected: 1,
},
{
name: "Should Create 1 HA Route and 1 Standalone",
inputRoutes: []*route.Route{
@ -174,25 +202,6 @@ func TestManagerUpdateRoutes(t *testing.T) {
inputSerial: 1,
clientNetworkWatchersExpected: 0,
},
{
name: "No Server Routes Should Be Added To Non Linux",
inputRoutes: []*route.Route{
{
ID: "a",
NetID: "routeA",
Peer: localPeerKey,
Network: netip.MustParsePrefix("1.2.3.4/32"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
},
inputSerial: 1,
shouldCheckServerRoutes: runtime.GOOS != "linux",
serverRoutesExpected: 0,
clientNetworkWatchersExpected: 0,
},
{
name: "Remove 1 Client Route",
inputInitRoutes: []*route.Route{
@ -335,7 +344,6 @@ func TestManagerUpdateRoutes(t *testing.T) {
},
inputRoutes: []*route.Route{},
inputSerial: 1,
shouldCheckServerRoutes: true,
serverRoutesExpected: 0,
clientNetworkWatchersExpected: 0,
},
@ -384,7 +392,6 @@ func TestManagerUpdateRoutes(t *testing.T) {
},
},
inputSerial: 1,
shouldCheckServerRoutes: runtime.GOOS == "linux",
serverRoutesExpected: 2,
clientNetworkWatchersExpected: 1,
},
@ -409,6 +416,10 @@ func TestManagerUpdateRoutes(t *testing.T) {
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil)
defer routeManager.Stop()
if testCase.removeSrvRouter {
routeManager.serverRouter = nil
}
if len(testCase.inputInitRoutes) > 0 {
err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
require.NoError(t, err, "should update routes with init routes")
@ -419,8 +430,9 @@ func TestManagerUpdateRoutes(t *testing.T) {
require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match")
if testCase.shouldCheckServerRoutes {
require.Len(t, routeManager.serverRouter.routes, testCase.serverRoutesExpected, "server networks size should match")
if runtime.GOOS == "linux" && routeManager.serverRouter != nil {
sr := routeManager.serverRouter.(*defaultServerRouter)
require.Len(t, sr.routes, testCase.serverRoutesExpected, "server networks size should match")
}
})
}

View File

@ -86,10 +86,10 @@ type nftablesManager struct {
mux sync.Mutex
}
func newNFTablesManager(parentCtx context.Context) (*nftablesManager, error) {
func newNFTablesManager(parentCtx context.Context) *nftablesManager {
ctx, cancel := context.WithCancel(parentCtx)
mgr := &nftablesManager{
return &nftablesManager{
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{},
@ -97,18 +97,6 @@ func newNFTablesManager(parentCtx context.Context) (*nftablesManager, error) {
rules: make(map[string]*nftables.Rule),
defaultForwardRules: make([]*nftables.Rule, 2),
}
err := mgr.isSupported()
if err != nil {
return nil, err
}
err = mgr.readFilterTable()
if err != nil {
return nil, err
}
return mgr, nil
}
// CleanRoutingRules cleans existing nftables rules from the system
@ -147,6 +135,10 @@ func (n *nftablesManager) RestoreOrCreateContainers() error {
}
for _, table := range tables {
if table.Name == "filter" {
n.filterTable = table
continue
}
if table.Name == nftablesTable {
if table.Family == nftables.TableFamilyIPv4 {
n.tableIPv4 = table
@ -259,21 +251,6 @@ func (n *nftablesManager) refreshRulesMap() error {
return nil
}
func (n *nftablesManager) readFilterTable() error {
tables, err := n.conn.ListTables()
if err != nil {
return err
}
for _, t := range tables {
if t.Name == "filter" {
n.filterTable = t
return nil
}
}
return nil
}
func (n *nftablesManager) eraseDefaultForwardRule() error {
if n.defaultForwardRules[0] == nil {
return nil
@ -544,14 +521,6 @@ func (n *nftablesManager) removeRoutingRule(format string, pair routerPair) erro
return nil
}
func (n *nftablesManager) isSupported() error {
_, err := n.conn.ListChains()
if err != nil {
return fmt.Errorf("nftables is not supported: %s", err)
}
return nil
}
// getPayloadDirectives get expression directives based on ip version and direction
func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) {
switch {

View File

@ -10,20 +10,23 @@ import (
"github.com/google/nftables/expr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/checkfw"
)
func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {
manager, err := newNFTablesManager(context.TODO())
if err != nil {
t.Fatalf("failed to create nftables manager: %s", err)
if checkfw.Check() != checkfw.NFTABLES {
t.Skip("nftables not supported on this OS")
}
manager := newNFTablesManager(context.TODO())
nftablesTestingClient := &nftables.Conn{}
defer manager.CleanRoutingRules()
err = manager.RestoreOrCreateContainers()
err := manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error")
require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6")
@ -126,19 +129,19 @@ func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {
}
func TestNftablesManager_InsertRoutingRules(t *testing.T) {
if checkfw.Check() != checkfw.NFTABLES {
t.Skip("nftables not supported on this OS")
}
for _, testCase := range insertRuleTestCases {
t.Run(testCase.name, func(t *testing.T) {
manager, err := newNFTablesManager(context.TODO())
if err != nil {
t.Fatalf("failed to create nftables manager: %s", err)
}
manager := newNFTablesManager(context.TODO())
nftablesTestingClient := &nftables.Conn{}
defer manager.CleanRoutingRules()
err = manager.RestoreOrCreateContainers()
err := manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error")
err = manager.InsertRoutingRules(testCase.inputPair)
@ -226,19 +229,19 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
}
func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
if checkfw.Check() != checkfw.NFTABLES {
t.Skip("nftables not supported on this OS")
}
for _, testCase := range removeRuleTestCases {
t.Run(testCase.name, func(t *testing.T) {
manager, err := newNFTablesManager(context.TODO())
if err != nil {
t.Fatalf("failed to create nftables manager: %s", err)
}
manager := newNFTablesManager(context.TODO())
nftablesTestingClient := &nftables.Conn{}
defer manager.CleanRoutingRules()
err = manager.RestoreOrCreateContainers()
err := manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error")
table := manager.tableIPv4

View File

@ -0,0 +1,9 @@
package routemanager
import "github.com/netbirdio/netbird/route"
type serverRouter interface {
updateRoutes(map[string]*route.Route) error
removeFromServerNetwork(*route.Route) error
cleanUp()
}

View File

@ -2,20 +2,11 @@ package routemanager
import (
"context"
"fmt"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
)
type serverRouter struct {
func newServerRouter(context.Context, *iface.WGIface) (serverRouter, error) {
return nil, fmt.Errorf("server route not supported on this os")
}
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) *serverRouter {
return &serverRouter{}
}
func (r *serverRouter) updateRoutes(routesMap map[string]*route.Route) error {
return nil
}
func (r *serverRouter) cleanUp() {}

View File

@ -13,7 +13,7 @@ import (
"github.com/netbirdio/netbird/route"
)
type serverRouter struct {
type defaultServerRouter struct {
mux sync.Mutex
ctx context.Context
routes map[string]*route.Route
@ -21,16 +21,21 @@ type serverRouter struct {
wgInterface *iface.WGIface
}
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) *serverRouter {
return &serverRouter{
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) (serverRouter, error) {
firewall, err := newFirewall(ctx)
if err != nil {
return nil, err
}
return &defaultServerRouter{
ctx: ctx,
routes: make(map[string]*route.Route),
firewall: NewFirewall(ctx),
firewall: firewall,
wgInterface: wgInterface,
}
}, nil
}
func (m *serverRouter) updateRoutes(routesMap map[string]*route.Route) error {
func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) error {
serverRoutesToRemove := make([]string, 0)
if len(routesMap) > 0 {
@ -81,7 +86,7 @@ func (m *serverRouter) updateRoutes(routesMap map[string]*route.Route) error {
return nil
}
func (m *serverRouter) removeFromServerNetwork(route *route.Route) error {
func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("not removing from server network because context is done")
@ -98,7 +103,7 @@ func (m *serverRouter) removeFromServerNetwork(route *route.Route) error {
}
}
func (m *serverRouter) addToServerNetwork(route *route.Route) error {
func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("not adding to server network because context is done")
@ -115,6 +120,6 @@ func (m *serverRouter) addToServerNetwork(route *route.Route) error {
}
}
func (m *serverRouter) cleanUp() {
func (m *defaultServerRouter) cleanUp() {
m.firewall.CleanRoutingRules()
}

View File

@ -20,7 +20,7 @@ func InterfaceFilter(disallowList []string) func(string) bool {
for _, s := range disallowList {
if strings.HasPrefix(iFace, s) {
log.Debugf("ignoring interface %s - it is not allowed", iFace)
log.Tracef("ignoring interface %s - it is not allowed", iFace)
return false
}
}

View File

@ -1,84 +0,0 @@
//go:build linux && !android
package ebpf
import (
_ "embed"
"net"
"github.com/cilium/ebpf/link"
"github.com/cilium/ebpf/rlimit"
)
const (
mapKeyProxyPort uint32 = 0
mapKeyWgPort uint32 = 1
)
//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc clang-14 bpf src/portreplace.c --
// EBPF is a wrapper for eBPF program
type EBPF struct {
link link.Link
}
// NewEBPF create new EBPF instance
func NewEBPF() *EBPF {
return &EBPF{}
}
// Load load ebpf program
func (l *EBPF) Load(proxyPort, wgPort int) error {
// it required for Docker
err := rlimit.RemoveMemlock()
if err != nil {
return err
}
ifce, err := net.InterfaceByName("lo")
if err != nil {
return err
}
// Load pre-compiled programs into the kernel.
objs := bpfObjects{}
err = loadBpfObjects(&objs, nil)
if err != nil {
return err
}
defer func() {
_ = objs.Close()
}()
err = objs.NbWgProxySettingsMap.Put(mapKeyProxyPort, uint16(proxyPort))
if err != nil {
return err
}
err = objs.NbWgProxySettingsMap.Put(mapKeyWgPort, uint16(wgPort))
if err != nil {
return err
}
defer func() {
_ = objs.NbWgProxySettingsMap.Close()
}()
l.link, err = link.AttachXDP(link.XDPOptions{
Program: objs.NbWgProxy,
Interface: ifce.Index,
})
if err != nil {
return err
}
return err
}
// Free ebpf program
func (l *EBPF) Free() error {
if l.link != nil {
return l.link.Close()
}
return nil
}

View File

@ -1,18 +0,0 @@
//go:build linux
package ebpf
import (
"testing"
)
func Test_newEBPF(t *testing.T) {
ebpf := NewEBPF()
err := ebpf.Load(1234, 51892)
defer func() {
_ = ebpf.Free()
}()
if err != nil {
t.Errorf("%s", err)
}
}

View File

@ -1,90 +0,0 @@
#include <stdbool.h>
#include <linux/if_ether.h> // ETH_P_IP
#include <linux/udp.h>
#include <linux/ip.h>
#include <netinet/in.h>
#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>
#define bpf_printk(fmt, ...) \
({ \
char ____fmt[] = fmt; \
bpf_trace_printk(____fmt, sizeof(____fmt), ##__VA_ARGS__); \
})
const __u32 map_key_proxy_port = 0;
const __u32 map_key_wg_port = 1;
struct bpf_map_def SEC("maps") nb_wg_proxy_settings_map = {
.type = BPF_MAP_TYPE_ARRAY,
.key_size = sizeof(__u32),
.value_size = sizeof(__u16),
.max_entries = 10,
};
__u16 proxy_port = 0;
__u16 wg_port = 0;
bool read_port_settings() {
__u16 *value;
value = bpf_map_lookup_elem(&nb_wg_proxy_settings_map, &map_key_proxy_port);
if(!value) {
return false;
}
proxy_port = *value;
value = bpf_map_lookup_elem(&nb_wg_proxy_settings_map, &map_key_wg_port);
if(!value) {
return false;
}
wg_port = *value;
return true;
}
SEC("xdp")
int nb_wg_proxy(struct xdp_md *ctx) {
if(proxy_port == 0 || wg_port == 0) {
if(!read_port_settings()){
return XDP_PASS;
}
bpf_printk("proxy port: %d, wg port: %d", proxy_port, wg_port);
}
void *data = (void *)(long)ctx->data;
void *data_end = (void *)(long)ctx->data_end;
struct ethhdr *eth = data;
struct iphdr *ip = (data + sizeof(struct ethhdr));
struct udphdr *udp = (data + sizeof(struct ethhdr) + sizeof(struct iphdr));
// return early if not enough data
if (data + sizeof(struct ethhdr) + sizeof(struct iphdr) + sizeof(struct udphdr) > data_end){
return XDP_PASS;
}
// skip non IPv4 packages
if (eth->h_proto != htons(ETH_P_IP)) {
return XDP_PASS;
}
if (ip->protocol != IPPROTO_UDP) {
return XDP_PASS;
}
// 2130706433 = 127.0.0.1
if (ip->daddr != htonl(2130706433)) {
return XDP_PASS;
}
if (udp->source != htons(wg_port)){
return XDP_PASS;
}
__be16 new_src_port = udp->dest;
__be16 new_dst_port = htons(proxy_port);
udp->dest = new_dst_port;
udp->source = new_src_port;
return XDP_PASS;
}
char _license[] SEC("license") = "GPL";

View File

@ -12,15 +12,15 @@ import (
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
log "github.com/sirupsen/logrus"
ebpf2 "github.com/netbirdio/netbird/client/internal/wgproxy/ebpf"
"github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
)
// WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct {
ebpf *ebpf2.EBPF
ebpfManager ebpfMgr.Manager
lastUsedPort uint16
localWGListenPort int
@ -36,7 +36,7 @@ func NewWGEBPFProxy(wgPort int) *WGEBPFProxy {
log.Debugf("instantiate ebpf proxy")
wgProxy := &WGEBPFProxy{
localWGListenPort: wgPort,
ebpf: ebpf2.NewEBPF(),
ebpfManager: ebpf.GetEbpfManagerInstance(),
lastUsedPort: 0,
turnConnStore: make(map[uint16]net.Conn),
}
@ -56,7 +56,7 @@ func (p *WGEBPFProxy) Listen() error {
return err
}
err = p.ebpf.Load(wgPorxyPort, p.localWGListenPort)
err = p.ebpfManager.LoadWgProxy(wgPorxyPort, p.localWGListenPort)
if err != nil {
return err
}
@ -110,7 +110,7 @@ func (p *WGEBPFProxy) Free() error {
err1 = p.conn.Close()
}
err2 = p.ebpf.Free()
err2 = p.ebpfManager.FreeWGProxy()
if p.rawConn != nil {
err3 = p.rawConn.Close()
}
@ -135,6 +135,7 @@ func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
}
p.removeTurnConn(endpointPort)
log.Infof("stop forward turn packages to port: %d. error: %s", endpointPort, err)
return
}
err = p.sendPkg(buf[:n], endpointPort)
@ -158,7 +159,7 @@ func (p *WGEBPFProxy) proxyToRemote() {
conn, ok := p.turnConnStore[uint16(addr.Port)]
p.turnConnMutex.Unlock()
if !ok {
log.Errorf("turn conn not found by port: %d", addr.Port)
log.Infof("turn conn not found by port: %d", addr.Port)
continue
}

77
client/netbird.wxs Normal file
View File

@ -0,0 +1,77 @@
<Wix
xmlns="http://wixtoolset.org/schemas/v4/wxs">
<Package Name="NetBird" Version="$(env.NETBIRD_VERSION)" Manufacturer="Wiretrustee UG (haftungsbeschreankt)" Language="1033" UpgradeCode="6456ec4e-3ad6-4b9b-a2be-98e81cb21ccf"
InstallerVersion="500" Compressed="yes" Codepage="utf-8" >
<MediaTemplate EmbedCab="yes" />
<Feature Id="NetbirdFeature" Title="Netbird" Level="1">
<ComponentGroupRef Id="NetbirdFilesComponent" />
</Feature>
<MajorUpgrade AllowSameVersionUpgrades='yes' DowngradeErrorMessage="A newer version of [ProductName] is already installed. Setup will now exit."/>
<StandardDirectory Id="ProgramFiles64Folder">
<Directory Id="NetbirdInstallDir" Name="Netbird">
<Component Id="NetbirdFiles" Guid="db3165de-cc6e-4922-8396-9d892950e23e" Bitness="always64">
<File ProcessorArchitecture="x64" Source=".\dist\netbird_windows_amd64\netbird.exe" KeyPath="yes" />
<File ProcessorArchitecture="x64" Source=".\dist\netbird_windows_amd64\netbird-ui.exe">
<Shortcut Id="NetbirdDesktopShortcut" Directory="DesktopFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" />
<Shortcut Id="NetbirdStartMenuShortcut" Directory="StartMenuFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" />
</File>
<File ProcessorArchitecture="x64" Source=".\dist\netbird_windows_amd64\wintun.dll" />
<ServiceInstall
Id="NetBirdService"
Name="NetBird"
DisplayName="NetBird"
Description="A WireGuard-based mesh network that connects your devices into a single private network."
Start="auto" Type="ownProcess"
ErrorControl="normal"
Account="LocalSystem"
Vital="yes"
Interactive="no"
Arguments='service run config [CommonAppDataFolder]Netbird\config.json log-level info'
/>
<ServiceControl Id="NetBirdService" Name="NetBird" Start="install" Stop="both" Remove="uninstall" Wait="yes" />
<Environment Id="UpdatePath" Name="PATH" Value="[NetbirdInstallDir]" Part="last" Action="set" System="yes" />
</Component>
</Directory>
</StandardDirectory>
<ComponentGroup Id="NetbirdFilesComponent">
<ComponentRef Id="NetbirdFiles" />
</ComponentGroup>
<Property Id="cmd" Value="cmd.exe"/>
<CustomAction Id="KillDaemon"
ExeCommand='/c "taskkill /im netbird.exe"'
Execute="deferred"
Property="cmd"
Impersonate="no"
Return="ignore"
/>
<CustomAction Id="KillUI"
ExeCommand='/c "taskkill /im netbird-ui.exe"'
Execute="deferred"
Property="cmd"
Impersonate="no"
Return="ignore"
/>
<InstallExecuteSequence>
<!-- For Uninstallation -->
<Custom Action="KillDaemon" Before="RemoveFiles" Condition="Installed"/>
<Custom Action="KillUI" After="KillDaemon" Condition="Installed"/>
</InstallExecuteSequence>
<!-- Icons -->
<Icon Id="NetbirdIcon" SourceFile=".\client\ui\netbird.ico" />
<Property Id="ARPPRODUCTICON" Value="NetbirdIcon" />
</Package>
</Wix>

View File

@ -3,10 +3,11 @@ package server
import (
"context"
"fmt"
"github.com/netbirdio/netbird/client/internal/auth"
"sync"
"time"
"github.com/netbirdio/netbird/client/internal/auth"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
@ -315,7 +316,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
tokenInfo, err := s.oauthAuthFlow.flow.WaitToken(waitCTX, flowInfo)
if err != nil {
if err == context.Canceled {
return nil, nil
return nil, nil //nolint:nilnil
}
s.mutex.Lock()
s.oauthAuthFlow.expiresAt = time.Now()

View File

@ -130,16 +130,22 @@ func ParseNameServerURL(nsURL string) (NameServer, error) {
// Copy copies a nameserver group object
func (g *NameServerGroup) Copy() *NameServerGroup {
return &NameServerGroup{
nsGroup := &NameServerGroup{
ID: g.ID,
Name: g.Name,
Description: g.Description,
NameServers: g.NameServers,
Groups: g.Groups,
NameServers: make([]NameServer, len(g.NameServers)),
Groups: make([]string, len(g.Groups)),
Enabled: g.Enabled,
Primary: g.Primary,
Domains: g.Domains,
Domains: make([]string, len(g.Domains)),
}
copy(nsGroup.NameServers, g.NameServers)
copy(nsGroup.Groups, g.Groups)
copy(nsGroup.Domains, g.Domains)
return nsGroup
}
// IsEqual compares one nameserver group with the other

2
go.mod
View File

@ -31,7 +31,7 @@ require (
fyne.io/fyne/v2 v2.1.4
github.com/c-robinson/iplib v1.0.3
github.com/cilium/ebpf v0.10.0
github.com/coreos/go-iptables v0.6.0
github.com/coreos/go-iptables v0.7.0
github.com/creack/pty v1.1.18
github.com/eko/gocache/v3 v3.1.1
github.com/getlantern/systray v1.2.1

4
go.sum
View File

@ -131,8 +131,8 @@ github.com/containerd/typeurl v1.0.2/go.mod h1:9trJWW2sRlGub4wZJRTW83VtbOLS6hwcD
github.com/coocood/freecache v1.2.1 h1:/v1CqMq45NFH9mp/Pt142reundeBM0dVUD3osQBeu/U=
github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk=
github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE=
github.com/coreos/go-iptables v0.6.0 h1:is9qnZMPYjLd8LYqmm/qlE+wwEgJIkTYdhV3rfZo4jk=
github.com/coreos/go-iptables v0.6.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q=
github.com/coreos/go-iptables v0.7.0 h1:XWM3V+MPRr5/q51NuWSgU0fqMad64Zyxs8ZUoMsamr8=
github.com/coreos/go-iptables v0.7.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q=
github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
github.com/coreos/go-systemd/v22 v22.0.0/go.mod h1:xO0FLkIi5MaZafQlIrOotqXZ90ih+1atmu1JpKERPPk=

View File

@ -112,7 +112,7 @@ func (w *WGIface) Close() error {
return w.tun.Close()
}
// SetFilter sets packet filters for the userspace impelemntation
// SetFilter sets packet filters for the userspace implementation
func (w *WGIface) SetFilter(filter PacketFilter) error {
w.mu.Lock()
defer w.mu.Unlock()

View File

@ -161,7 +161,7 @@ func getModulePath(name string) (string, error) {
}
if err != nil {
// skip broken files
return nil
return nil //nolint:nilerr
}
if !info.Type().IsRegular() {

View File

@ -146,9 +146,6 @@ func (c *wGConfigurer) removeAllowedIP(peerKey string, allowedIP string) error {
}
}
if err != nil {
return err
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
UpdateOnly: true,

View File

@ -168,4 +168,4 @@ env | grep NETBIRD
envsubst <docker-compose.yml.tmpl >docker-compose.yml
envsubst <management.json.tmpl >management.json
envsubst <turnserver.conf.tmpl >turnserver.conf
envsubst <turnserver.conf.tmpl >turnserver.conf

View File

@ -36,7 +36,7 @@ services:
volumes:
- $SIGNAL_VOLUMENAME:/var/lib/netbird
ports:
- 10000:80
- $NETBIRD_SIGNAL_PORT:80
# # port and command for Let's Encrypt validation
# - 443:443
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]

View File

@ -487,7 +487,48 @@ renderCaddyfile() {
}
}
(security_headers) {
header * {
# enable HSTS
# https://cheatsheetseries.owasp.org/cheatsheets/HTTP_Headers_Cheat_Sheet.html#strict-transport-security-hsts
# NOTE: Read carefully how this header works before using it.
# If the HSTS header is misconfigured or if there is a problem with
# the SSL/TLS certificate being used, legitimate users might be unable
# to access the website. For example, if the HSTS header is set to a
# very long duration and the SSL/TLS certificate expires or is revoked,
# legitimate users might be unable to access the website until
# the HSTS header duration has expired.
# The recommended value for the max-age is 2 year (63072000 seconds).
# But we are using 1 hour (3600 seconds) for testing purposes
# and ensure that the website is working properly before setting
# to two years.
Strict-Transport-Security "max-age=3600; includeSubDomains; preload"
# disable clients from sniffing the media type
# https://cheatsheetseries.owasp.org/cheatsheets/HTTP_Headers_Cheat_Sheet.html#x-content-type-options
X-Content-Type-Options "nosniff"
# clickjacking protection
# https://cheatsheetseries.owasp.org/cheatsheets/HTTP_Headers_Cheat_Sheet.html#x-frame-options
X-Frame-Options "DENY"
# xss protection
# https://cheatsheetseries.owasp.org/cheatsheets/HTTP_Headers_Cheat_Sheet.html#x-xss-protection
X-XSS-Protection "1; mode=block"
# Remove -Server header, which is an information leak
# Remove Caddy from Headers
-Server
# keep referrer data off of HTTP connections
# https://cheatsheetseries.owasp.org/cheatsheets/HTTP_Headers_Cheat_Sheet.html#referrer-policy
Referrer-Policy strict-origin-when-cross-origin
}
}
:80${CADDY_SECURE_DOMAIN} {
import security_headers
# Signal
reverse_proxy /signalexchange.SignalExchange/* h2c://signal:10000
# Management

View File

@ -21,4 +21,5 @@ NETBIRD_AUTH_USER_ID_CLAIM="email"
NETBIRD_AUTH_DEVICE_AUTH_SCOPE="openid email"
NETBIRD_MGMT_IDP=$CI_NETBIRD_MGMT_IDP
NETBIRD_IDP_MGMT_CLIENT_ID=$CI_NETBIRD_IDP_MGMT_CLIENT_ID
NETBIRD_IDP_MGMT_CLIENT_SECRET=$CI_NETBIRD_IDP_MGMT_CLIENT_SECRET
NETBIRD_IDP_MGMT_CLIENT_SECRET=$CI_NETBIRD_IDP_MGMT_CLIENT_SECRET
NETBIRD_SIGNAL_PORT=12345

View File

@ -61,12 +61,12 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
peersUpdateManager := mgmt.NewPeersUpdateManager()
eventStore := &activity.InMemoryEventStore{}
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "",
eventStore)
eventStore, false)
if err != nil {
t.Fatal(err)
}
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil {
t.Fatal(err)
}

View File

@ -19,28 +19,26 @@ import (
"github.com/google/uuid"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"golang.org/x/crypto/acme/autocert"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"github.com/netbirdio/netbird/management/server/activity/sqlite"
httpapi "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/util"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/encryption"
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/activity/sqlite"
httpapi "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/util"
)
// ManagementLegacyPort is the port that was used before by the Management gRPC server.
@ -72,10 +70,15 @@ var (
Use: "management",
Short: "start NetBird Management Server",
PreRunE: func(cmd *cobra.Command, args []string) error {
flag.Parse()
err := util.InitLog(logLevel, logFile)
if err != nil {
return fmt.Errorf("failed initializing log %v", err)
}
// detect whether user specified a port
userPort := cmd.Flag("port").Changed
var err error
config, err = loadMgmtConfig(mgmtConfig)
if err != nil {
return fmt.Errorf("failed reading provided config file: %s: %v", mgmtConfig, err)
@ -104,13 +107,7 @@ var (
return nil
},
RunE: func(cmd *cobra.Command, args []string) error {
flag.Parse()
err := util.InitLog(logLevel, logFile)
if err != nil {
return fmt.Errorf("failed initializing log %v", err)
}
err = handleRebrand(cmd)
err := handleRebrand(cmd)
if err != nil {
return fmt.Errorf("failed to migrate files %v", err)
}
@ -146,12 +143,22 @@ var (
if disableSingleAccMode {
mgmtSingleAccModeDomain = ""
}
eventStore, err := sqlite.NewSQLiteStore(config.Datadir)
eventStore, key, err := initEventStore(config.Datadir, config.DataStoreEncryptionKey)
if err != nil {
return err
return fmt.Errorf("failed to initialize database: %s", err)
}
if key != "" {
log.Debugf("update config with activity store key")
config.DataStoreEncryptionKey = key
err := updateMgmtConfig(mgmtConfig, config)
if err != nil {
return fmt.Errorf("failed to write out store encryption key: %s", err)
}
}
accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
dnsDomain, eventStore)
dnsDomain, eventStore, userDeleteFromIDPEnabled)
if err != nil {
return fmt.Errorf("failed to build default manager: %v", err)
}
@ -202,8 +209,11 @@ var (
return fmt.Errorf("failed creating HTTP API handler: %v", err)
}
ephemeralManager := server.NewEphemeralManager(store, accountManager)
ephemeralManager.LoadInitialPeers()
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, appMetrics)
srv, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, appMetrics, ephemeralManager)
if err != nil {
return fmt.Errorf("failed creating gRPC API handler: %v", err)
}
@ -272,6 +282,7 @@ var (
SetupCloseHandler()
<-stopCh
ephemeralManager.Stop()
_ = appMetrics.Close()
_ = listener.Close()
if certManager != nil {
@ -287,6 +298,20 @@ var (
}
)
func initEventStore(dataDir string, key string) (activity.Store, string, error) {
var err error
if key == "" {
log.Debugf("generate new activity store encryption key")
key, err = sqlite.GenerateKey()
if err != nil {
return nil, "", err
}
}
store, err := sqlite.NewSQLiteStore(dataDir, key)
return store, key, err
}
func notifyStop(msg string) {
select {
case stopCh <- 1:
@ -440,6 +465,10 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
return loadedConfig, err
}
func updateMgmtConfig(path string, config *server.Config) error {
return util.WriteJson(path, config)
}
// OIDCConfigResponse used for parsing OIDC config response
type OIDCConfigResponse struct {
Issuer string `json:"issuer"`

View File

@ -24,6 +24,7 @@ var (
disableMetrics bool
disableSingleAccMode bool
idpSignKeyRefreshEnabled bool
userDeleteFromIDPEnabled bool
rootCmd = &cobra.Command{
Use: "netbird-mgmt",
@ -56,6 +57,7 @@ func init() {
mgmtCmd.Flags().BoolVar(&disableMetrics, "disable-anonymous-metrics", false, "disables push of anonymous usage metrics to NetBird")
mgmtCmd.Flags().StringVar(&dnsDomain, "dns-domain", defaultSingleAccModeDomain, fmt.Sprintf("Domain used for peer resolution. This is appended to the peer's name, e.g. pi-server. %s. Max lenght is 192 characters to allow appending to a peer name with up to 63 characters.", defaultSingleAccModeDomain))
mgmtCmd.Flags().BoolVar(&idpSignKeyRefreshEnabled, "idp-sign-key-refresh-enabled", false, "Enable cache headers evaluation to determine signing key rotation period. This will refresh the signing key upon expiry.")
mgmtCmd.Flags().BoolVar(&userDeleteFromIDPEnabled, "user-delete-from-idp", false, "Allows to delete user from IDP when user is deleted from account")
rootCmd.MarkFlagRequired("config") //nolint
rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "")

View File

@ -49,7 +49,7 @@ func cacheEntryExpiration() time.Duration {
type AccountManager interface {
GetOrCreateAccountByUser(userId, domain string) (*Account, error)
CreateSetupKey(accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration,
autoGroups []string, usageLimit int, userID string) (*SetupKey, error)
autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error)
SaveSetupKey(accountID string, key *SetupKey, userID string) (*SetupKey, error)
CreateUser(accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error)
DeleteUser(accountID, initiatorUserID string, targetUserID string) error
@ -80,7 +80,6 @@ type AccountManager interface {
GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error)
GetGroup(accountId, groupID string) (*Group, error)
SaveGroup(accountID, userID string, group *Group) error
UpdateGroup(accountID string, groupID string, operations []GroupUpdateOperation) (*Group, error)
DeleteGroup(accountId, userId, groupID string) error
ListGroups(accountId string) ([]*Group, error)
GroupAddPeer(accountId, groupID, peerID string) error
@ -93,13 +92,11 @@ type AccountManager interface {
GetRoute(accountID, routeID, userID string) (*route.Route, error)
CreateRoute(accountID string, prefix, peerID, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
SaveRoute(accountID, userID string, route *route.Route) error
UpdateRoute(accountID, routeID string, operations []RouteUpdateOperation) (*route.Route, error)
DeleteRoute(accountID, routeID, userID string) error
ListRoutes(accountID, userID string) ([]*route.Route, error)
GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error)
SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
UpdateNameServerGroup(accountID, nsGroupID, userID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error)
DeleteNameServerGroup(accountID, nsGroupID, userID string) error
ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error)
GetDNSDomain() string
@ -133,6 +130,9 @@ type DefaultAccountManager struct {
// dnsDomain is used for peer resolution. This is appended to the peer's name
dnsDomain string
peerLoginExpiry Scheduler
// userDeleteFromIDPEnabled allows to delete user from IDP when user is deleted from account
userDeleteFromIDPEnabled bool
}
// Settings represents Account settings structure that can be modified via API and Dashboard
@ -189,14 +189,15 @@ type Account struct {
}
type UserInfo struct {
ID string `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
Role string `json:"role"`
AutoGroups []string `json:"auto_groups"`
Status string `json:"-"`
IsServiceUser bool `json:"is_service_user"`
IsBlocked bool `json:"is_blocked"`
ID string `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
Role string `json:"role"`
AutoGroups []string `json:"auto_groups"`
Status string `json:"-"`
IsServiceUser bool `json:"is_service_user"`
IsBlocked bool `json:"is_blocked"`
LastLogin time.Time `json:"last_login"`
}
// getRoutesToSync returns the enabled routes for the peer ID and the routes
@ -629,8 +630,8 @@ func (a *Account) GetPeer(peerID string) *Peer {
return a.Peers[peerID]
}
// AddJWTGroups to account and to user autoassigned groups
func (a *Account) AddJWTGroups(userID string, groups []string) bool {
// SetJWTGroups to account and to user autoassigned groups
func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool {
user, ok := a.Users[userID]
if !ok {
return false
@ -641,13 +642,21 @@ func (a *Account) AddJWTGroups(userID string, groups []string) bool {
existedGroupsByName[group.Name] = group
}
autoGroups := make(map[string]struct{})
for _, groupID := range user.AutoGroups {
autoGroups[groupID] = struct{}{}
// remove JWT groups from the autogroups, to sync them again
removed := 0
jwtAutoGroups := make(map[string]struct{})
for i, id := range user.AutoGroups {
if group, ok := a.Groups[id]; ok && group.Issued == GroupIssuedJWT {
jwtAutoGroups[group.Name] = struct{}{}
user.AutoGroups = append(user.AutoGroups[:i-removed], user.AutoGroups[i-removed+1:]...)
removed++
}
}
// create JWT groups if they doesn't exist
// and all of them to the autogroups
var modified bool
for _, name := range groups {
for _, name := range groupsNames {
group, ok := existedGroupsByName[name]
if !ok {
group = &Group{
@ -656,16 +665,22 @@ func (a *Account) AddJWTGroups(userID string, groups []string) bool {
Issued: GroupIssuedJWT,
}
a.Groups[group.ID] = group
modified = true
}
if _, ok := autoGroups[group.ID]; !ok {
if group.Issued == GroupIssuedJWT {
user.AutoGroups = append(user.AutoGroups, group.ID)
// only JWT groups will be synced
if group.Issued == GroupIssuedJWT {
user.AutoGroups = append(user.AutoGroups, group.ID)
if _, ok := jwtAutoGroups[name]; !ok {
modified = true
}
delete(jwtAutoGroups, name)
}
}
// if not empty it means we removed some groups
if len(jwtAutoGroups) > 0 {
modified = true
}
return modified
}
@ -723,18 +738,19 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
// BuildManager creates a new DefaultAccountManager with a provided Store
func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store,
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, userDeleteFromIDPEnabled bool,
) (*DefaultAccountManager, error) {
am := &DefaultAccountManager{
Store: store,
peersUpdateManager: peersUpdateManager,
idpManager: idpManager,
ctx: context.Background(),
cacheMux: sync.Mutex{},
cacheLoading: map[string]chan struct{}{},
dnsDomain: dnsDomain,
eventStore: eventStore,
peerLoginExpiry: NewDefaultScheduler(),
Store: store,
peersUpdateManager: peersUpdateManager,
idpManager: idpManager,
ctx: context.Background(),
cacheMux: sync.Mutex{},
cacheLoading: map[string]chan struct{}{},
dnsDomain: dnsDomain,
eventStore: eventStore,
peerLoginExpiry: NewDefaultScheduler(),
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
}
allAccounts := store.GetAllAccounts()
// enable single account mode only if configured by user and number of existing accounts is not grater than 1
@ -859,32 +875,19 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(accountID string) func()
return account.GetNextPeerExpiration()
}
expiredPeers := account.GetExpiredPeers()
var peerIDs []string
for _, peer := range account.GetExpiredPeers() {
if peer.Status.LoginExpired {
continue
}
for _, peer := range expiredPeers {
peerIDs = append(peerIDs, peer.ID)
peer.MarkLoginExpired(true)
account.UpdatePeer(peer)
err = am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status)
if err != nil {
log.Errorf("failed saving peer status while expiring peer %s", peer.ID)
return account.GetNextPeerExpiration()
}
}
log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id)
if len(peerIDs) != 0 {
// this will trigger peer disconnect from the management service
am.peersUpdateManager.CloseChannels(peerIDs)
err = am.updateAccountPeers(account)
if err != nil {
log.Errorf("failed updating account peers while expiring peers for account %s", accountID)
return account.GetNextPeerExpiration()
}
if err := am.expireAndUpdatePeers(account, expiredPeers); err != nil {
log.Errorf("failed updating account peers while expiring peers for account %s", account.Id)
return account.GetNextPeerExpiration()
}
return account.GetNextPeerExpiration()
}
}
@ -1006,7 +1009,7 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountI
}
}
return nil, nil
return nil, nil //nolint:nilnil
}
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
@ -1029,7 +1032,7 @@ func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Accou
}
}
return nil, nil
return nil, nil //nolint:nilnil
}
func (am *DefaultAccountManager) refreshCache(accountID string) ([]*idp.UserData, error) {
@ -1357,23 +1360,58 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
}
if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok {
if slice, ok := claim.([]interface{}); ok {
var groups []string
var groupsNames []string
for _, item := range slice {
if g, ok := item.(string); ok {
groups = append(groups, g)
groupsNames = append(groupsNames, g)
} else {
log.Errorf("JWT claim %q is not a string: %v", account.Settings.JWTGroupsClaimName, item)
}
}
oldGroups := make([]string, len(user.AutoGroups))
copy(oldGroups, user.AutoGroups)
// if groups were added or modified, save the account
if account.AddJWTGroups(claims.UserId, groups) {
if account.SetJWTGroups(claims.UserId, groupsNames) {
if account.Settings.GroupsPropagationEnabled {
if user, err := account.FindUser(claims.UserId); err == nil {
account.UserGroupsAddToPeers(claims.UserId, append(user.AutoGroups, groups...)...)
addNewGroups := difference(user.AutoGroups, oldGroups)
removeOldGroups := difference(oldGroups, user.AutoGroups)
account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
account.Network.IncSerial()
if err := am.Store.SaveAccount(account); err != nil {
log.Errorf("failed to save account: %v", err)
} else {
if err := am.updateAccountPeers(account); err != nil {
log.Errorf("failed updating account peers while updating user %s", account.Id)
}
for _, g := range addNewGroups {
if group := account.GetGroup(g); group != nil {
am.storeEvent(user.Id, user.Id, account.Id, activity.GroupAddedToUser,
map[string]any{
"group": group.Name,
"group_id": group.ID,
"is_service_user": user.IsServiceUser,
"user_name": user.ServiceUserName})
}
}
for _, g := range removeOldGroups {
if group := account.GetGroup(g); group != nil {
am.storeEvent(user.Id, user.Id, account.Id, activity.GroupRemovedFromUser,
map[string]any{
"group": group.Name,
"group_id": group.ID,
"is_service_user": user.IsServiceUser,
"user_name": user.ServiceUserName})
}
}
}
}
} else {
if err := am.Store.SaveAccount(account); err != nil {
log.Errorf("failed to save account: %v", err)
}
}
if err := am.Store.SaveAccount(account); err != nil {
log.Errorf("failed to save account: %v", err)
}
}
} else {
@ -1420,7 +1458,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
if _, ok := accountFromID.Users[claims.UserId]; !ok {
return nil, fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
}
if accountFromID.DomainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory {
if accountFromID.DomainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || accountFromID.Domain != claims.Domain {
return accountFromID, nil
}
}
@ -1554,19 +1592,3 @@ func newAccountWithId(accountID, userID, domain string) *Account {
}
return acc
}
func removeFromList(inputList []string, toRemove []string) []string {
toRemoveMap := make(map[string]struct{})
for _, item := range toRemove {
toRemoveMap[item] = struct{}{}
}
var resultList []string
for _, item := range inputList {
_, ok := toRemoveMap[item]
if !ok {
resultList = append(resultList, item)
}
}
return resultList
}

View File

@ -3,6 +3,7 @@ package server
import (
"crypto/sha256"
b64 "encoding/base64"
"encoding/json"
"fmt"
"net"
"reflect"
@ -11,6 +12,7 @@ import (
"time"
"github.com/golang-jwt/jwt"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/route"
@ -781,11 +783,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
serial := account.Network.CurrentSerial() // should be 0
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID)
if err != nil {
return
}
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
if err != nil {
t.Fatal("error creating setup key")
return
@ -928,11 +926,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
t.Fatal(err)
}
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID)
if err != nil {
return
}
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
if err != nil {
t.Fatal("error creating setup key")
return
@ -1112,11 +1106,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
t.Fatal(err)
}
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID)
if err != nil {
return
}
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
if err != nil {
t.Fatal("error creating setup key")
return
@ -1348,6 +1338,11 @@ func TestAccount_Copy(t *testing.T) {
Peers: map[string]*Peer{
"peer1": {
Key: "key1",
Status: &PeerStatus{
LastSeen: time.Now(),
Connected: true,
LoginExpired: false,
},
},
},
Users: map[string]*User{
@ -1370,28 +1365,36 @@ func TestAccount_Copy(t *testing.T) {
},
Groups: map[string]*Group{
"group1": {
ID: "group1",
ID: "group1",
Peers: []string{"peer1"},
},
},
Rules: map[string]*Rule{
"rule1": {
ID: "rule1",
ID: "rule1",
Destination: []string{},
Source: []string{},
},
},
Policies: []*Policy{
{
ID: "policy1",
Enabled: true,
Rules: make([]*PolicyRule, 0),
},
},
Routes: map[string]*route.Route{
"route1": {
ID: "route1",
ID: "route1",
Groups: []string{"group1"},
},
},
NameServerGroups: map[string]*nbdns.NameServerGroup{
"nsGroup1": {
ID: "nsGroup1",
ID: "nsGroup1",
Domains: []string{},
Groups: []string{},
NameServers: []nbdns.NameServer{},
},
},
DNSSettings: &DNSSettings{DisabledManagementGroups: []string{}},
@ -1402,10 +1405,20 @@ func TestAccount_Copy(t *testing.T) {
t.Fatal(err)
}
accountCopy := account.Copy()
assert.Equal(t, account, accountCopy, "account copy returned a different value than expected")
accBytes, err := json.Marshal(account)
if err != nil {
t.Fatal(err)
}
account.Peers["peer1"].Status.Connected = false // we change original object to confirm that copy wont change
accCopyBytes, err := json.Marshal(accountCopy)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, string(accBytes), string(accCopyBytes), "account copy returned a different value than expected")
}
// hasNilField validates pointers, maps and slices if they are nil
// TODO: make it check nested fields too
func hasNilField(x interface{}) error {
rv := reflect.ValueOf(x)
rv = rv.Elem()
@ -1930,7 +1943,7 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
}
}
func TestAccount_AddJWTGroups(t *testing.T) {
func TestAccount_SetJWTGroups(t *testing.T) {
// create a new account
account := &Account{
Peers: map[string]*Peer{
@ -1951,13 +1964,13 @@ func TestAccount_AddJWTGroups(t *testing.T) {
}
t.Run("api group already exists", func(t *testing.T) {
updated := account.AddJWTGroups("user1", []string{"group1"})
updated := account.SetJWTGroups("user1", []string{"group1"})
assert.False(t, updated, "account should not be updated")
assert.Empty(t, account.Users["user1"].AutoGroups, "auto groups must be empty")
})
t.Run("add jwt group", func(t *testing.T) {
updated := account.AddJWTGroups("user1", []string{"group1", "group2"})
updated := account.SetJWTGroups("user1", []string{"group1", "group2"})
assert.True(t, updated, "account should be updated")
assert.Len(t, account.Groups, 2, "new group should be added")
assert.Len(t, account.Users["user1"].AutoGroups, 1, "new group should be added")
@ -1965,13 +1978,13 @@ func TestAccount_AddJWTGroups(t *testing.T) {
})
t.Run("existed group not update", func(t *testing.T) {
updated := account.AddJWTGroups("user1", []string{"group2"})
updated := account.SetJWTGroups("user1", []string{"group2"})
assert.False(t, updated, "account should not be updated")
assert.Len(t, account.Groups, 2, "groups count should not be changed")
})
t.Run("add new group", func(t *testing.T) {
updated := account.AddJWTGroups("user2", []string{"group1", "group3"})
updated := account.SetJWTGroups("user2", []string{"group1", "group3"})
assert.True(t, updated, "account should be updated")
assert.Len(t, account.Groups, 3, "new group should be added")
assert.Len(t, account.Users["user2"].AutoGroups, 1, "new group should be added")
@ -2050,7 +2063,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
return nil, err
}
eventStore := &activity.InMemoryEventStore{}
return BuildManager(store, NewPeersUpdateManager(), nil, "", "netbird.cloud", eventStore)
return BuildManager(store, NewPeersUpdateManager(), nil, "", "netbird.cloud", eventStore, false)
}
func createStore(t *testing.T) (Store, error) {

View File

@ -1,5 +1,14 @@
package activity
// Activity that triggered an Event
type Activity int
// Code is an activity string representation
type Code struct {
message string
code string
}
const (
// PeerAddedByUser indicates that a user added a new peer to the system
PeerAddedByUser Activity = iota
@ -95,316 +104,85 @@ const (
UserBlocked
// UserUnblocked indicates that a user unblocked another user
UserUnblocked
// UserDeleted indicates that a user deleted another user
UserDeleted
// GroupDeleted indicates that a user deleted group
GroupDeleted
// UserLoggedInPeer indicates that user logged in their peer with an interactive SSO login
UserLoggedInPeer
// PeerLoginExpired indicates that the user peer login has been expired and peer disconnected
PeerLoginExpired
// DashboardLogin indicates that the user logged in to the dashboard
DashboardLogin
)
const (
// PeerAddedByUserMessage is a human-readable text message of the PeerAddedByUser activity
PeerAddedByUserMessage string = "Peer added"
// PeerAddedWithSetupKeyMessage is a human-readable text message of the PeerAddedWithSetupKey activity
PeerAddedWithSetupKeyMessage = PeerAddedByUserMessage
// UserJoinedMessage is a human-readable text message of the UserJoined activity
UserJoinedMessage string = "User joined"
// UserInvitedMessage is a human-readable text message of the UserInvited activity
UserInvitedMessage string = "User invited"
// AccountCreatedMessage is a human-readable text message of the AccountCreated activity
AccountCreatedMessage string = "Account created"
// PeerRemovedByUserMessage is a human-readable text message of the PeerRemovedByUser activity
PeerRemovedByUserMessage string = "Peer deleted"
// RuleAddedMessage is a human-readable text message of the RuleAdded activity
RuleAddedMessage string = "Rule added"
// RuleRemovedMessage is a human-readable text message of the RuleRemoved activity
RuleRemovedMessage string = "Rule deleted"
// RuleUpdatedMessage is a human-readable text message of the RuleRemoved activity
RuleUpdatedMessage string = "Rule updated"
// PolicyAddedMessage is a human-readable text message of the PolicyAdded activity
PolicyAddedMessage string = "Policy added"
// PolicyRemovedMessage is a human-readable text message of the PolicyRemoved activity
PolicyRemovedMessage string = "Policy deleted"
// PolicyUpdatedMessage is a human-readable text message of the PolicyRemoved activity
PolicyUpdatedMessage string = "Policy updated"
// SetupKeyCreatedMessage is a human-readable text message of the SetupKeyCreated activity
SetupKeyCreatedMessage string = "Setup key created"
// SetupKeyUpdatedMessage is a human-readable text message of the SetupKeyUpdated activity
SetupKeyUpdatedMessage string = "Setup key updated"
// SetupKeyRevokedMessage is a human-readable text message of the SetupKeyRevoked activity
SetupKeyRevokedMessage string = "Setup key revoked"
// SetupKeyOverusedMessage is a human-readable text message of the SetupKeyOverused activity
SetupKeyOverusedMessage string = "Setup key overused"
// GroupCreatedMessage is a human-readable text message of the GroupCreated activity
GroupCreatedMessage string = "Group created"
// GroupUpdatedMessage is a human-readable text message of the GroupUpdated activity
GroupUpdatedMessage string = "Group updated"
// GroupAddedToPeerMessage is a human-readable text message of the GroupAddedToPeer activity
GroupAddedToPeerMessage string = "Group added to peer"
// GroupRemovedFromPeerMessage is a human-readable text message of the GroupRemovedFromPeer activity
GroupRemovedFromPeerMessage string = "Group removed from peer"
// GroupAddedToUserMessage is a human-readable text message of the GroupAddedToUser activity
GroupAddedToUserMessage string = "Group added to user"
// GroupRemovedFromUserMessage is a human-readable text message of the GroupRemovedFromUser activity
GroupRemovedFromUserMessage string = "Group removed from user"
// UserRoleUpdatedMessage is a human-readable text message of the UserRoleUpdatedMessage activity
UserRoleUpdatedMessage string = "User role updated"
// GroupAddedToSetupKeyMessage is a human-readable text message of the GroupAddedToSetupKey activity
GroupAddedToSetupKeyMessage string = "Group added to setup key"
// GroupRemovedFromSetupKeyMessage is a human-readable text message of the GroupRemovedFromSetupKey activity
GroupRemovedFromSetupKeyMessage string = "Group removed from user setup key"
// GroupAddedToDisabledManagementGroupsMessage is a human-readable text message of the GroupAddedToDisabledManagementGroups activity
GroupAddedToDisabledManagementGroupsMessage string = "Group added to disabled management DNS setting"
// GroupRemovedFromDisabledManagementGroupsMessage is a human-readable text message of the GroupRemovedFromDisabledManagementGroups activity
GroupRemovedFromDisabledManagementGroupsMessage string = "Group removed from disabled management DNS setting"
// RouteCreatedMessage is a human-readable text message of the RouteCreated activity
RouteCreatedMessage string = "Route created"
// RouteRemovedMessage is a human-readable text message of the RouteRemoved activity
RouteRemovedMessage string = "Route deleted"
// RouteUpdatedMessage is a human-readable text message of the RouteUpdated activity
RouteUpdatedMessage string = "Route updated"
// PeerSSHEnabledMessage is a human-readable text message of the PeerSSHEnabled activity
PeerSSHEnabledMessage string = "Peer SSH server enabled"
// PeerSSHDisabledMessage is a human-readable text message of the PeerSSHDisabled activity
PeerSSHDisabledMessage string = "Peer SSH server disabled"
// PeerRenamedMessage is a human-readable text message of the PeerRenamed activity
PeerRenamedMessage string = "Peer renamed"
// PeerLoginExpirationDisabledMessage is a human-readable text message of the PeerLoginExpirationDisabled activity
PeerLoginExpirationDisabledMessage string = "Peer login expiration disabled"
// PeerLoginExpirationEnabledMessage is a human-readable text message of the PeerLoginExpirationEnabled activity
PeerLoginExpirationEnabledMessage string = "Peer login expiration enabled"
// NameserverGroupCreatedMessage is a human-readable text message of the NameserverGroupCreated activity
NameserverGroupCreatedMessage string = "Nameserver group created"
// NameserverGroupDeletedMessage is a human-readable text message of the NameserverGroupDeleted activity
NameserverGroupDeletedMessage string = "Nameserver group deleted"
// NameserverGroupUpdatedMessage is a human-readable text message of the NameserverGroupUpdated activity
NameserverGroupUpdatedMessage string = "Nameserver group updated"
// AccountPeerLoginExpirationEnabledMessage is a human-readable text message of the AccountPeerLoginExpirationEnabled activity
AccountPeerLoginExpirationEnabledMessage string = "Peer login expiration enabled for the account"
// AccountPeerLoginExpirationDisabledMessage is a human-readable text message of the AccountPeerLoginExpirationDisabled activity
AccountPeerLoginExpirationDisabledMessage string = "Peer login expiration disabled for the account"
// AccountPeerLoginExpirationDurationUpdatedMessage is a human-readable text message of the AccountPeerLoginExpirationDurationUpdated activity
AccountPeerLoginExpirationDurationUpdatedMessage string = "Peer login expiration duration updated"
// PersonalAccessTokenCreatedMessage is a human-readable text message of the PersonalAccessTokenCreated activity
PersonalAccessTokenCreatedMessage string = "Personal access token created"
// PersonalAccessTokenDeletedMessage is a human-readable text message of the PersonalAccessTokenDeleted activity
PersonalAccessTokenDeletedMessage string = "Personal access token deleted"
// ServiceUserCreatedMessage is a human-readable text message of the ServiceUserCreated activity
ServiceUserCreatedMessage string = "Service user created"
// ServiceUserDeletedMessage is a human-readable text message of the ServiceUserDeleted activity
ServiceUserDeletedMessage string = "Service user deleted"
// UserBlockedMessage is a human-readable text message of the UserBlocked activity
UserBlockedMessage string = "User blocked"
// UserUnblockedMessage is a human-readable text message of the UserUnblocked activity
UserUnblockedMessage string = "User unblocked"
// GroupDeletedMessage is a human-readable text message of the GroupDeleted activity
GroupDeletedMessage string = "Group deleted"
)
// Activity that triggered an Event
type Activity int
// Message returns a string representation of an activity
func (a Activity) Message() string {
switch a {
case PeerAddedByUser:
return PeerAddedByUserMessage
case PeerRemovedByUser:
return PeerRemovedByUserMessage
case PeerAddedWithSetupKey:
return PeerAddedWithSetupKeyMessage
case UserJoined:
return UserJoinedMessage
case UserInvited:
return UserInvitedMessage
case AccountCreated:
return AccountCreatedMessage
case RuleAdded:
return RuleAddedMessage
case RuleRemoved:
return RuleRemovedMessage
case RuleUpdated:
return RuleUpdatedMessage
case PolicyAdded:
return PolicyAddedMessage
case PolicyRemoved:
return PolicyRemovedMessage
case PolicyUpdated:
return PolicyUpdatedMessage
case SetupKeyCreated:
return SetupKeyCreatedMessage
case SetupKeyUpdated:
return SetupKeyUpdatedMessage
case SetupKeyRevoked:
return SetupKeyRevokedMessage
case SetupKeyOverused:
return SetupKeyOverusedMessage
case GroupCreated:
return GroupCreatedMessage
case GroupUpdated:
return GroupUpdatedMessage
case GroupAddedToPeer:
return GroupAddedToPeerMessage
case GroupRemovedFromPeer:
return GroupRemovedFromPeerMessage
case GroupRemovedFromUser:
return GroupRemovedFromUserMessage
case GroupAddedToUser:
return GroupAddedToUserMessage
case UserRoleUpdated:
return UserRoleUpdatedMessage
case GroupAddedToSetupKey:
return GroupAddedToSetupKeyMessage
case GroupRemovedFromSetupKey:
return GroupRemovedFromSetupKeyMessage
case GroupAddedToDisabledManagementGroups:
return GroupAddedToDisabledManagementGroupsMessage
case GroupRemovedFromDisabledManagementGroups:
return GroupRemovedFromDisabledManagementGroupsMessage
case RouteCreated:
return RouteCreatedMessage
case RouteRemoved:
return RouteRemovedMessage
case RouteUpdated:
return RouteUpdatedMessage
case PeerSSHEnabled:
return PeerSSHEnabledMessage
case PeerSSHDisabled:
return PeerSSHDisabledMessage
case PeerLoginExpirationEnabled:
return PeerLoginExpirationEnabledMessage
case PeerLoginExpirationDisabled:
return PeerLoginExpirationDisabledMessage
case PeerRenamed:
return PeerRenamedMessage
case NameserverGroupCreated:
return NameserverGroupCreatedMessage
case NameserverGroupDeleted:
return NameserverGroupDeletedMessage
case NameserverGroupUpdated:
return NameserverGroupUpdatedMessage
case AccountPeerLoginExpirationEnabled:
return AccountPeerLoginExpirationEnabledMessage
case AccountPeerLoginExpirationDisabled:
return AccountPeerLoginExpirationDisabledMessage
case AccountPeerLoginExpirationDurationUpdated:
return AccountPeerLoginExpirationDurationUpdatedMessage
case PersonalAccessTokenCreated:
return PersonalAccessTokenCreatedMessage
case PersonalAccessTokenDeleted:
return PersonalAccessTokenDeletedMessage
case ServiceUserCreated:
return ServiceUserCreatedMessage
case ServiceUserDeleted:
return ServiceUserDeletedMessage
case UserBlocked:
return UserBlockedMessage
case UserUnblocked:
return UserUnblockedMessage
case GroupDeleted:
return GroupDeletedMessage
default:
return "UNKNOWN_ACTIVITY"
}
var activityMap = map[Activity]Code{
PeerAddedByUser: {"Peer added", "user.peer.add"},
PeerAddedWithSetupKey: {"Peer added", "setupkey.peer.add"},
UserJoined: {"User joined", "user.join"},
UserInvited: {"User invited", "user.invite"},
AccountCreated: {"Account created", "account.create"},
PeerRemovedByUser: {"Peer deleted", "user.peer.delete"},
RuleAdded: {"Rule added", "rule.add"},
RuleUpdated: {"Rule updated", "rule.update"},
RuleRemoved: {"Rule deleted", "rule.delete"},
PolicyAdded: {"Policy added", "policy.add"},
PolicyUpdated: {"Policy updated", "policy.update"},
PolicyRemoved: {"Policy deleted", "policy.delete"},
SetupKeyCreated: {"Setup key created", "setupkey.add"},
SetupKeyUpdated: {"Setup key updated", "setupkey.update"},
SetupKeyRevoked: {"Setup key revoked", "setupkey.revoke"},
SetupKeyOverused: {"Setup key overused", "setupkey.overuse"},
GroupCreated: {"Group created", "group.add"},
GroupUpdated: {"Group updated", "group.update"},
GroupAddedToPeer: {"Group added to peer", "peer.group.add"},
GroupRemovedFromPeer: {"Group removed from peer", "peer.group.delete"},
GroupAddedToUser: {"Group added to user", "user.group.add"},
GroupRemovedFromUser: {"Group removed from user", "user.group.delete"},
UserRoleUpdated: {"User role updated", "user.role.update"},
GroupAddedToSetupKey: {"Group added to setup key", "setupkey.group.add"},
GroupRemovedFromSetupKey: {"Group removed from user setup key", "setupkey.group.delete"},
GroupAddedToDisabledManagementGroups: {"Group added to disabled management DNS setting", "dns.setting.disabled.management.group.add"},
GroupRemovedFromDisabledManagementGroups: {"Group removed from disabled management DNS setting", "dns.setting.disabled.management.group.delete"},
RouteCreated: {"Route created", "route.add"},
RouteRemoved: {"Route deleted", "route.delete"},
RouteUpdated: {"Route updated", "route.update"},
PeerSSHEnabled: {"Peer SSH server enabled", "peer.ssh.enable"},
PeerSSHDisabled: {"Peer SSH server disabled", "peer.ssh.disable"},
PeerRenamed: {"Peer renamed", "peer.rename"},
PeerLoginExpirationEnabled: {"Peer login expiration enabled", "peer.login.expiration.enable"},
PeerLoginExpirationDisabled: {"Peer login expiration disabled", "peer.login.expiration.disable"},
NameserverGroupCreated: {"Nameserver group created", "nameserver.group.add"},
NameserverGroupDeleted: {"Nameserver group deleted", "nameserver.group.delete"},
NameserverGroupUpdated: {"Nameserver group updated", "nameserver.group.update"},
AccountPeerLoginExpirationDurationUpdated: {"Account peer login expiration duration updated", "account.setting.peer.login.expiration.update"},
AccountPeerLoginExpirationEnabled: {"Account peer login expiration enabled", "account.setting.peer.login.expiration.enable"},
AccountPeerLoginExpirationDisabled: {"Account peer login expiration disabled", "account.setting.peer.login.expiration.disable"},
PersonalAccessTokenCreated: {"Personal access token created", "personal.access.token.create"},
PersonalAccessTokenDeleted: {"Personal access token deleted", "personal.access.token.delete"},
ServiceUserCreated: {"Service user created", "service.user.create"},
ServiceUserDeleted: {"Service user deleted", "service.user.delete"},
UserBlocked: {"User blocked", "user.block"},
UserUnblocked: {"User unblocked", "user.unblock"},
UserDeleted: {"User deleted", "user.delete"},
GroupDeleted: {"Group deleted", "group.delete"},
UserLoggedInPeer: {"User logged in peer", "user.peer.login"},
PeerLoginExpired: {"Peer login expired", "peer.login.expire"},
DashboardLogin: {"Dashboard login", "dashboard.login"},
}
// StringCode returns a string code of the activity
func (a Activity) StringCode() string {
switch a {
case PeerAddedByUser:
return "user.peer.add"
case PeerRemovedByUser:
return "user.peer.delete"
case PeerAddedWithSetupKey:
return "setupkey.peer.add"
case UserJoined:
return "user.join"
case UserInvited:
return "user.invite"
case UserBlocked:
return "user.block"
case UserUnblocked:
return "user.unblock"
case AccountCreated:
return "account.create"
case RuleAdded:
return "rule.add"
case RuleRemoved:
return "rule.delete"
case RuleUpdated:
return "rule.update"
case PolicyAdded:
return "policy.add"
case PolicyRemoved:
return "policy.delete"
case PolicyUpdated:
return "policy.update"
case SetupKeyCreated:
return "setupkey.add"
case SetupKeyRevoked:
return "setupkey.revoke"
case SetupKeyOverused:
return "setupkey.overuse"
case SetupKeyUpdated:
return "setupkey.update"
case GroupCreated:
return "group.add"
case GroupUpdated:
return "group.update"
case GroupDeleted:
return "group.delete"
case GroupRemovedFromPeer:
return "peer.group.delete"
case GroupAddedToPeer:
return "peer.group.add"
case GroupAddedToUser:
return "user.group.add"
case GroupRemovedFromUser:
return "user.group.delete"
case UserRoleUpdated:
return "user.role.update"
case GroupAddedToSetupKey:
return "setupkey.group.add"
case GroupRemovedFromSetupKey:
return "setupkey.group.delete"
case GroupAddedToDisabledManagementGroups:
return "dns.setting.disabled.management.group.add"
case GroupRemovedFromDisabledManagementGroups:
return "dns.setting.disabled.management.group.delete"
case RouteCreated:
return "route.add"
case RouteRemoved:
return "route.delete"
case RouteUpdated:
return "route.update"
case PeerRenamed:
return "peer.rename"
case PeerSSHEnabled:
return "peer.ssh.enable"
case PeerSSHDisabled:
return "peer.ssh.disable"
case PeerLoginExpirationDisabled:
return "peer.login.expiration.disable"
case PeerLoginExpirationEnabled:
return "peer.login.expiration.enable"
case NameserverGroupCreated:
return "nameserver.group.add"
case NameserverGroupDeleted:
return "nameserver.group.delete"
case NameserverGroupUpdated:
return "nameserver.group.update"
case AccountPeerLoginExpirationDurationUpdated:
return "account.setting.peer.login.expiration.update"
case AccountPeerLoginExpirationEnabled:
return "account.setting.peer.login.expiration.enable"
case AccountPeerLoginExpirationDisabled:
return "account.setting.peer.login.expiration.disable"
case PersonalAccessTokenCreated:
return "personal.access.token.create"
case PersonalAccessTokenDeleted:
return "personal.access.token.delete"
case ServiceUserCreated:
return "service.user.create"
case ServiceUserDeleted:
return "service.user.delete"
default:
return "UNKNOWN_ACTIVITY"
if code, ok := activityMap[a]; ok {
return code.code
}
return "UNKNOWN_ACTIVITY"
}
// Message returns a string representation of an activity
func (a Activity) Message() string {
if code, ok := activityMap[a]; ok {
return code.message
}
return "UNKNOWN_ACTIVITY"
}

View File

@ -4,6 +4,10 @@ import (
"time"
)
const (
SystemInitiator = "sys"
)
// Event represents a network/system activity event.
type Event struct {
// Timestamp of the event
@ -14,10 +18,13 @@ type Event struct {
ID uint64
// InitiatorID is the ID of an object that initiated the event (e.g., a user)
InitiatorID string
// InitiatorEmail is the email address of an object that initiated the event. This will be set on deleted users only
InitiatorEmail string
// TargetID is the ID of an object that was effected by the event (e.g., a peer)
TargetID string
// AccountID is the ID of an account where the event happened
AccountID string
// Meta of the event, e.g. deleted peer information like name, IP, etc
Meta map[string]any
}
@ -31,12 +38,13 @@ func (e *Event) Copy() *Event {
}
return &Event{
Timestamp: e.Timestamp,
Activity: e.Activity,
ID: e.ID,
InitiatorID: e.InitiatorID,
TargetID: e.TargetID,
AccountID: e.AccountID,
Meta: meta,
Timestamp: e.Timestamp,
Activity: e.Activity,
ID: e.ID,
InitiatorID: e.InitiatorID,
InitiatorEmail: e.InitiatorEmail,
TargetID: e.TargetID,
AccountID: e.AccountID,
Meta: meta,
}
}

View File

@ -0,0 +1,81 @@
package sqlite
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"fmt"
)
var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05}
type EmailEncrypt struct {
block cipher.Block
}
func GenerateKey() (string, error) {
key := make([]byte, 32)
_, err := rand.Read(key)
if err != nil {
return "", err
}
readableKey := base64.StdEncoding.EncodeToString(key)
return readableKey, nil
}
func NewEmailEncrypt(key string) (*EmailEncrypt, error) {
binKey, err := base64.StdEncoding.DecodeString(key)
if err != nil {
return nil, err
}
block, err := aes.NewCipher(binKey)
if err != nil {
return nil, err
}
ec := &EmailEncrypt{
block: block,
}
return ec, nil
}
func (ec *EmailEncrypt) Encrypt(payload string) string {
plainText := pkcs5Padding([]byte(payload))
cipherText := make([]byte, len(plainText))
cbc := cipher.NewCBCEncrypter(ec.block, iv)
cbc.CryptBlocks(cipherText, plainText)
return base64.StdEncoding.EncodeToString(cipherText)
}
func (ec *EmailEncrypt) Decrypt(data string) (string, error) {
cipherText, err := base64.StdEncoding.DecodeString(data)
if err != nil {
return "", err
}
cbc := cipher.NewCBCDecrypter(ec.block, iv)
cbc.CryptBlocks(cipherText, cipherText)
payload, err := pkcs5UnPadding(cipherText)
if err != nil {
return "", err
}
return string(payload), nil
}
func pkcs5Padding(ciphertext []byte) []byte {
padding := aes.BlockSize - len(ciphertext)%aes.BlockSize
padText := bytes.Repeat([]byte{byte(padding)}, padding)
return append(ciphertext, padText...)
}
func pkcs5UnPadding(src []byte) ([]byte, error) {
srcLen := len(src)
paddingLen := int(src[srcLen-1])
if paddingLen >= srcLen || paddingLen > aes.BlockSize {
return nil, fmt.Errorf("padding size error")
}
return src[:srcLen-paddingLen], nil
}

View File

@ -0,0 +1,63 @@
package sqlite
import (
"testing"
)
func TestGenerateKey(t *testing.T) {
testData := "exampl@netbird.io"
key, err := GenerateKey()
if err != nil {
t.Fatalf("failed to generate key: %s", err)
}
ee, err := NewEmailEncrypt(key)
if err != nil {
t.Fatalf("failed to init email encryption: %s", err)
}
encrypted := ee.Encrypt(testData)
if encrypted == "" {
t.Fatalf("invalid encrypted text")
}
decrypted, err := ee.Decrypt(encrypted)
if err != nil {
t.Fatalf("failed to decrypt data: %s", err)
}
if decrypted != testData {
t.Fatalf("decrypted data is not match with test data: %s, %s", testData, decrypted)
}
}
func TestCorruptKey(t *testing.T) {
testData := "exampl@netbird.io"
key, err := GenerateKey()
if err != nil {
t.Fatalf("failed to generate key: %s", err)
}
ee, err := NewEmailEncrypt(key)
if err != nil {
t.Fatalf("failed to init email encryption: %s", err)
}
encrypted := ee.Encrypt(testData)
if encrypted == "" {
t.Fatalf("invalid encrypted text")
}
newKey, err := GenerateKey()
if err != nil {
t.Fatalf("failed to generate key: %s", err)
}
ee, err = NewEmailEncrypt(newKey)
if err != nil {
t.Fatalf("failed to init email encryption: %s", err)
}
res, err := ee.Decrypt(encrypted)
if err == nil || res == testData {
t.Fatalf("incorrect decryption, the result is: %s", res)
}
}

View File

@ -4,12 +4,13 @@ import (
"database/sql"
"encoding/json"
"fmt"
"github.com/netbirdio/netbird/management/server/activity"
// sqlite driver
_ "github.com/mattn/go-sqlite3"
"path/filepath"
"time"
_ "github.com/mattn/go-sqlite3" // sqlite driver
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
)
const (
@ -24,44 +25,106 @@ const (
"meta TEXT," +
" target_id TEXT);"
selectStatement = "SELECT id, activity, timestamp, initiator_id, target_id, account_id, meta" +
" FROM events WHERE account_id = ? ORDER BY timestamp %s LIMIT ? OFFSET ?;"
insertStatement = "INSERT INTO events(activity, timestamp, initiator_id, target_id, account_id, meta) " +
creatTableAccountEmailQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL);`
selectDescQuery = `SELECT events.id, activity, timestamp, initiator_id, i.email as "initiator_email", target_id, t.email as "target_email", account_id, meta
FROM events
LEFT JOIN deleted_users i ON events.initiator_id = i.id
LEFT JOIN deleted_users t ON events.target_id = t.id
WHERE account_id = ?
ORDER BY timestamp DESC LIMIT ? OFFSET ?;`
selectAscQuery = `SELECT events.id, activity, timestamp, initiator_id, i.email as "initiator_email", target_id, t.email as "target_email", account_id, meta
FROM events
LEFT JOIN deleted_users i ON events.initiator_id = i.id
LEFT JOIN deleted_users t ON events.target_id = t.id
WHERE account_id = ?
ORDER BY timestamp ASC LIMIT ? OFFSET ?;`
insertQuery = "INSERT INTO events(activity, timestamp, initiator_id, target_id, account_id, meta) " +
"VALUES(?, ?, ?, ?, ?, ?)"
insertDeleteUserQuery = `INSERT INTO deleted_users(id, email) VALUES(?, ?)`
)
// Store is the implementation of the activity.Store interface backed by SQLite
type Store struct {
db *sql.DB
db *sql.DB
emailEncrypt *EmailEncrypt
insertStatement *sql.Stmt
selectAscStatement *sql.Stmt
selectDescStatement *sql.Stmt
deleteUserStmt *sql.Stmt
}
// NewSQLiteStore creates a new Store with an event table if not exists.
func NewSQLiteStore(dataDir string) (*Store, error) {
func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) {
dbFile := filepath.Join(dataDir, eventSinkDB)
db, err := sql.Open("sqlite3", dbFile)
if err != nil {
return nil, err
}
crypt, err := NewEmailEncrypt(encryptionKey)
if err != nil {
return nil, err
}
_, err = db.Exec(createTableQuery)
if err != nil {
return nil, err
}
return &Store{db: db}, nil
_, err = db.Exec(creatTableAccountEmailQuery)
if err != nil {
return nil, err
}
insertStmt, err := db.Prepare(insertQuery)
if err != nil {
return nil, err
}
selectDescStmt, err := db.Prepare(selectDescQuery)
if err != nil {
return nil, err
}
selectAscStmt, err := db.Prepare(selectAscQuery)
if err != nil {
return nil, err
}
deleteUserStmt, err := db.Prepare(insertDeleteUserQuery)
if err != nil {
return nil, err
}
s := &Store{
db: db,
emailEncrypt: crypt,
insertStatement: insertStmt,
selectDescStatement: selectDescStmt,
selectAscStatement: selectAscStmt,
deleteUserStmt: deleteUserStmt,
}
return s, nil
}
func processResult(result *sql.Rows) ([]*activity.Event, error) {
func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) {
events := make([]*activity.Event, 0)
for result.Next() {
var id int64
var operation activity.Activity
var timestamp time.Time
var initiator string
var initiatorEmail *string
var target string
var targetEmail *string
var account string
var jsonMeta string
err := result.Scan(&id, &operation, &timestamp, &initiator, &target, &account, &jsonMeta)
err := result.Scan(&id, &operation, &timestamp, &initiator, &initiatorEmail, &target, &targetEmail, &account, &jsonMeta)
if err != nil {
return nil, err
}
@ -74,7 +137,17 @@ func processResult(result *sql.Rows) ([]*activity.Event, error) {
}
}
events = append(events, &activity.Event{
if targetEmail != nil {
email, err := store.emailEncrypt.Decrypt(*targetEmail)
if err != nil {
log.Errorf("failed to decrypt email address for target id: %s", target)
meta["email"] = ""
} else {
meta["email"] = email
}
}
event := &activity.Event{
Timestamp: timestamp,
Activity: operation,
ID: uint64(id),
@ -82,7 +155,18 @@ func processResult(result *sql.Rows) ([]*activity.Event, error) {
TargetID: target,
AccountID: account,
Meta: meta,
})
}
if initiatorEmail != nil {
email, err := store.emailEncrypt.Decrypt(*initiatorEmail)
if err != nil {
log.Errorf("failed to decrypt email address of initiator: %s", initiator)
} else {
event.InitiatorEmail = email
}
}
events = append(events, event)
}
return events, nil
@ -90,13 +174,9 @@ func processResult(result *sql.Rows) ([]*activity.Event, error) {
// Get returns "limit" number of events from index ordered descending or ascending by a timestamp
func (store *Store) Get(accountID string, offset, limit int, descending bool) ([]*activity.Event, error) {
order := "DESC"
stmt := store.selectDescStatement
if !descending {
order = "ASC"
}
stmt, err := store.db.Prepare(fmt.Sprintf(selectStatement, order))
if err != nil {
return nil, err
stmt = store.selectAscStatement
}
result, err := stmt.Query(accountID, limit, offset)
@ -105,19 +185,18 @@ func (store *Store) Get(accountID string, offset, limit int, descending bool) ([
}
defer result.Close() //nolint
return processResult(result)
return store.processResult(result)
}
// Save an event in the SQLite events table
// Save an event in the SQLite events table end encrypt the "email" element in meta map
func (store *Store) Save(event *activity.Event) (*activity.Event, error) {
stmt, err := store.db.Prepare(insertStatement)
var jsonMeta string
meta, err := store.saveDeletedUserEmailInEncrypted(event)
if err != nil {
return nil, err
}
var jsonMeta string
if event.Meta != nil {
if meta != nil {
metaBytes, err := json.Marshal(event.Meta)
if err != nil {
return nil, err
@ -125,7 +204,7 @@ func (store *Store) Save(event *activity.Event) (*activity.Event, error) {
jsonMeta = string(metaBytes)
}
result, err := stmt.Exec(event.Activity, event.Timestamp, event.InitiatorID, event.TargetID, event.AccountID, jsonMeta)
result, err := store.insertStatement.Exec(event.Activity, event.Timestamp, event.InitiatorID, event.TargetID, event.AccountID, jsonMeta)
if err != nil {
return nil, err
}
@ -140,6 +219,29 @@ func (store *Store) Save(event *activity.Event) (*activity.Event, error) {
return eventCopy, nil
}
// saveDeletedUserEmailInEncrypted if the meta contains email then store it in encrypted way and delete this item from
// meta map
func (store *Store) saveDeletedUserEmailInEncrypted(event *activity.Event) (map[string]any, error) {
email, ok := event.Meta["email"]
if !ok {
return event.Meta, nil
}
delete(event.Meta, "email")
encrypted := store.emailEncrypt.Encrypt(fmt.Sprintf("%s", email))
_, err := store.deleteUserStmt.Exec(event.TargetID, encrypted)
if err != nil {
return nil, err
}
if len(event.Meta) == 1 {
return nil, nil // nolint
}
delete(event.Meta, "email")
return event.Meta, nil
}
// Close the Store
func (store *Store) Close() error {
if store.db != nil {

Some files were not shown because too many files have changed in this diff Show More