Merge branch 'main' into feature/port-forwarding
160
.github/workflows/golang-test-linux.yml
vendored
@ -1,4 +1,4 @@
|
|||||||
name: Test Code Linux
|
name: Linux
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
@ -12,11 +12,21 @@ concurrency:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build-cache:
|
build-cache:
|
||||||
|
name: "Build Cache"
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
|
outputs:
|
||||||
|
management: ${{ steps.filter.outputs.management }}
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- uses: dorny/paths-filter@v3
|
||||||
|
id: filter
|
||||||
|
with:
|
||||||
|
filters: |
|
||||||
|
management:
|
||||||
|
- 'management/**'
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
@ -39,7 +49,6 @@ jobs:
|
|||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||||
|
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
if: steps.cache.outputs.cache-hit != 'true'
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||||
@ -89,6 +98,7 @@ jobs:
|
|||||||
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
|
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
|
||||||
|
|
||||||
test:
|
test:
|
||||||
|
name: "Client / Unit"
|
||||||
needs: [build-cache]
|
needs: [build-cache]
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
@ -134,9 +144,116 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management)
|
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
|
||||||
|
|
||||||
|
test_relay:
|
||||||
|
name: "Relay / Unit"
|
||||||
|
needs: [build-cache]
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
arch: [ '386','amd64' ]
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: "1.23.x"
|
||||||
|
cache: false
|
||||||
|
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Get Go environment
|
||||||
|
run: |
|
||||||
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Cache Go modules
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
${{ env.cache }}
|
||||||
|
${{ env.modcache }}
|
||||||
|
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||||
|
|
||||||
|
- name: Install 32-bit libpcap
|
||||||
|
if: matrix.arch == '386'
|
||||||
|
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||||
|
|
||||||
|
- name: Install modules
|
||||||
|
run: go mod tidy
|
||||||
|
|
||||||
|
- name: check git status
|
||||||
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
|
- name: Test
|
||||||
|
run: |
|
||||||
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
|
go test \
|
||||||
|
-exec 'sudo' \
|
||||||
|
-timeout 10m ./signal/...
|
||||||
|
|
||||||
|
test_signal:
|
||||||
|
name: "Signal / Unit"
|
||||||
|
needs: [build-cache]
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
arch: [ '386','amd64' ]
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: "1.23.x"
|
||||||
|
cache: false
|
||||||
|
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Get Go environment
|
||||||
|
run: |
|
||||||
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Cache Go modules
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
${{ env.cache }}
|
||||||
|
${{ env.modcache }}
|
||||||
|
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||||
|
|
||||||
|
- name: Install 32-bit libpcap
|
||||||
|
if: matrix.arch == '386'
|
||||||
|
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||||
|
|
||||||
|
- name: Install modules
|
||||||
|
run: go mod tidy
|
||||||
|
|
||||||
|
- name: check git status
|
||||||
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
|
- name: Test
|
||||||
|
run: |
|
||||||
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
|
go test \
|
||||||
|
-exec 'sudo' \
|
||||||
|
-timeout 10m ./signal/...
|
||||||
|
|
||||||
test_management:
|
test_management:
|
||||||
|
name: "Management / Unit"
|
||||||
needs: [ build-cache ]
|
needs: [ build-cache ]
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
@ -194,10 +311,17 @@ jobs:
|
|||||||
run: docker pull mlsmaycon/warmed-mysql:8
|
run: docker pull mlsmaycon/warmed-mysql:8
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management)
|
run: |
|
||||||
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||||
|
go test -tags=devcert \
|
||||||
|
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
|
||||||
|
-timeout 10m ./management/...
|
||||||
|
|
||||||
benchmark:
|
benchmark:
|
||||||
|
name: "Management / Benchmark"
|
||||||
needs: [ build-cache ]
|
needs: [ build-cache ]
|
||||||
|
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
@ -254,10 +378,17 @@ jobs:
|
|||||||
run: docker pull mlsmaycon/warmed-mysql:8
|
run: docker pull mlsmaycon/warmed-mysql:8
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags devcert -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m ./...
|
run: |
|
||||||
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
||||||
|
go test -tags devcert -run=^$ -bench=. \
|
||||||
|
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||||
|
-timeout 20m ./...
|
||||||
|
|
||||||
api_benchmark:
|
api_benchmark:
|
||||||
|
name: "Management / Benchmark (API)"
|
||||||
needs: [ build-cache ]
|
needs: [ build-cache ]
|
||||||
|
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
@ -314,10 +445,19 @@ jobs:
|
|||||||
run: docker pull mlsmaycon/warmed-mysql:8
|
run: docker pull mlsmaycon/warmed-mysql:8
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -tags=benchmark -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 30m $(go list -tags=benchmark ./... | grep /management)
|
run: |
|
||||||
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
||||||
|
go test -tags=benchmark \
|
||||||
|
-run=^$ \
|
||||||
|
-bench=. \
|
||||||
|
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||||
|
-timeout 20m ./management/...
|
||||||
|
|
||||||
api_integration_test:
|
api_integration_test:
|
||||||
|
name: "Management / Integration"
|
||||||
needs: [ build-cache ]
|
needs: [ build-cache ]
|
||||||
|
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
@ -363,9 +503,15 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=integration -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 30m $(go list -tags=integration ./... | grep /management)
|
run: |
|
||||||
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
||||||
|
go test -tags=integration \
|
||||||
|
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||||
|
-timeout 10m ./management/...
|
||||||
|
|
||||||
test_client_on_docker:
|
test_client_on_docker:
|
||||||
|
name: "Client (Docker) / Unit"
|
||||||
needs: [ build-cache ]
|
needs: [ build-cache ]
|
||||||
runs-on: ubuntu-20.04
|
runs-on: ubuntu-20.04
|
||||||
steps:
|
steps:
|
||||||
|
@ -103,7 +103,7 @@ linters:
|
|||||||
- predeclared # predeclared finds code that shadows one of Go's predeclared identifiers
|
- predeclared # predeclared finds code that shadows one of Go's predeclared identifiers
|
||||||
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
|
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
|
||||||
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
|
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
|
||||||
- thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
|
# - thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
|
||||||
- wastedassign # wastedassign finds wasted assignment statements
|
- wastedassign # wastedassign finds wasted assignment statements
|
||||||
issues:
|
issues:
|
||||||
# Maximum count of issues with the same text.
|
# Maximum count of issues with the same text.
|
||||||
|
@ -53,7 +53,7 @@ nfpms:
|
|||||||
contents:
|
contents:
|
||||||
- src: client/ui/netbird.desktop
|
- src: client/ui/netbird.desktop
|
||||||
dst: /usr/share/applications/netbird.desktop
|
dst: /usr/share/applications/netbird.desktop
|
||||||
- src: client/ui/netbird-systemtray-connected.png
|
- src: client/ui/netbird.png
|
||||||
dst: /usr/share/pixmaps/netbird.png
|
dst: /usr/share/pixmaps/netbird.png
|
||||||
dependencies:
|
dependencies:
|
||||||
- netbird
|
- netbird
|
||||||
@ -70,7 +70,7 @@ nfpms:
|
|||||||
contents:
|
contents:
|
||||||
- src: client/ui/netbird.desktop
|
- src: client/ui/netbird.desktop
|
||||||
dst: /usr/share/applications/netbird.desktop
|
dst: /usr/share/applications/netbird.desktop
|
||||||
- src: client/ui/netbird-systemtray-connected.png
|
- src: client/ui/netbird.png
|
||||||
dst: /usr/share/pixmaps/netbird.png
|
dst: /usr/share/pixmaps/netbird.png
|
||||||
dependencies:
|
dependencies:
|
||||||
- netbird
|
- netbird
|
||||||
|
@ -9,6 +9,7 @@ USER netbird:netbird
|
|||||||
|
|
||||||
ENV NB_FOREGROUND_MODE=true
|
ENV NB_FOREGROUND_MODE=true
|
||||||
ENV NB_USE_NETSTACK_MODE=true
|
ENV NB_USE_NETSTACK_MODE=true
|
||||||
|
ENV NB_ENABLE_NETSTACK_LOCAL_FORWARDING=true
|
||||||
ENV NB_CONFIG=config.json
|
ENV NB_CONFIG=config.json
|
||||||
ENV NB_DAEMON_ADDR=unix://netbird.sock
|
ENV NB_DAEMON_ADDR=unix://netbird.sock
|
||||||
ENV NB_DISABLE_DNS=true
|
ENV NB_DISABLE_DNS=true
|
||||||
|
@ -39,7 +39,6 @@ type peerStateDetailOutput struct {
|
|||||||
TransferSent int64 `json:"transferSent" yaml:"transferSent"`
|
TransferSent int64 `json:"transferSent" yaml:"transferSent"`
|
||||||
Latency time.Duration `json:"latency" yaml:"latency"`
|
Latency time.Duration `json:"latency" yaml:"latency"`
|
||||||
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
|
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
|
||||||
Routes []string `json:"routes" yaml:"routes"`
|
|
||||||
Networks []string `json:"networks" yaml:"networks"`
|
Networks []string `json:"networks" yaml:"networks"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -98,10 +97,11 @@ type statusOutputOverview struct {
|
|||||||
FQDN string `json:"fqdn" yaml:"fqdn"`
|
FQDN string `json:"fqdn" yaml:"fqdn"`
|
||||||
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
|
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
|
||||||
RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"`
|
RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"`
|
||||||
Routes []string `json:"routes" yaml:"routes"`
|
|
||||||
Networks []string `json:"networks" yaml:"networks"`
|
Networks []string `json:"networks" yaml:"networks"`
|
||||||
NumberOfForwardingRules int `json:"forwardingRules" yaml:"forwardingRules"`
|
NumberOfForwardingRules int `json:"forwardingRules" yaml:"forwardingRules"`
|
||||||
NSServerGroups []nsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"`
|
NSServerGroups []nsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"`
|
||||||
|
Events []systemEventOutput `json:"events" yaml:"events"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -285,10 +285,11 @@ func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverv
|
|||||||
FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
|
FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
|
||||||
RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(),
|
RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(),
|
||||||
RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(),
|
RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(),
|
||||||
Routes: pbFullStatus.GetLocalPeerState().GetNetworks(),
|
|
||||||
Networks: pbFullStatus.GetLocalPeerState().GetNetworks(),
|
Networks: pbFullStatus.GetLocalPeerState().GetNetworks(),
|
||||||
NumberOfForwardingRules: int(pbFullStatus.GetNumberOfForwardingRules()),
|
NumberOfForwardingRules: int(pbFullStatus.GetNumberOfForwardingRules()),
|
||||||
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
|
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
|
||||||
|
Events: mapEvents(pbFullStatus.GetEvents()),
|
||||||
}
|
}
|
||||||
|
|
||||||
if anonymizeFlag {
|
if anonymizeFlag {
|
||||||
@ -395,7 +396,6 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
|
|||||||
TransferSent: transferSent,
|
TransferSent: transferSent,
|
||||||
Latency: pbPeerState.GetLatency().AsDuration(),
|
Latency: pbPeerState.GetLatency().AsDuration(),
|
||||||
RosenpassEnabled: pbPeerState.GetRosenpassEnabled(),
|
RosenpassEnabled: pbPeerState.GetRosenpassEnabled(),
|
||||||
Routes: pbPeerState.GetNetworks(),
|
|
||||||
Networks: pbPeerState.GetNetworks(),
|
Networks: pbPeerState.GetNetworks(),
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -561,7 +561,6 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
|
|||||||
"NetBird IP: %s\n"+
|
"NetBird IP: %s\n"+
|
||||||
"Interface type: %s\n"+
|
"Interface type: %s\n"+
|
||||||
"Quantum resistance: %s\n"+
|
"Quantum resistance: %s\n"+
|
||||||
"Routes: %s\n"+
|
|
||||||
"Networks: %s\n"+
|
"Networks: %s\n"+
|
||||||
"Forwarding rules: %d\n"+
|
"Forwarding rules: %d\n"+
|
||||||
"Peers count: %s\n",
|
"Peers count: %s\n",
|
||||||
@ -577,7 +576,6 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
|
|||||||
interfaceTypeString,
|
interfaceTypeString,
|
||||||
rosenpassEnabledStatus,
|
rosenpassEnabledStatus,
|
||||||
networks,
|
networks,
|
||||||
networks,
|
|
||||||
overview.NumberOfForwardingRules,
|
overview.NumberOfForwardingRules,
|
||||||
peersCountString,
|
peersCountString,
|
||||||
)
|
)
|
||||||
@ -586,13 +584,17 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
|
|||||||
|
|
||||||
func parseToFullDetailSummary(overview statusOutputOverview) string {
|
func parseToFullDetailSummary(overview statusOutputOverview) string {
|
||||||
parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive)
|
parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive)
|
||||||
|
parsedEventsString := parseEvents(overview.Events)
|
||||||
summary := parseGeneralSummary(overview, true, true, true)
|
summary := parseGeneralSummary(overview, true, true, true)
|
||||||
|
|
||||||
return fmt.Sprintf(
|
return fmt.Sprintf(
|
||||||
"Peers detail:"+
|
"Peers detail:"+
|
||||||
|
"%s\n"+
|
||||||
|
"Events:"+
|
||||||
"%s\n"+
|
"%s\n"+
|
||||||
"%s",
|
"%s",
|
||||||
parsedPeersString,
|
parsedPeersString,
|
||||||
|
parsedEventsString,
|
||||||
summary,
|
summary,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@ -661,7 +663,6 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
|
|||||||
" Last WireGuard handshake: %s\n"+
|
" Last WireGuard handshake: %s\n"+
|
||||||
" Transfer status (received/sent) %s/%s\n"+
|
" Transfer status (received/sent) %s/%s\n"+
|
||||||
" Quantum resistance: %s\n"+
|
" Quantum resistance: %s\n"+
|
||||||
" Routes: %s\n"+
|
|
||||||
" Networks: %s\n"+
|
" Networks: %s\n"+
|
||||||
" Latency: %s\n",
|
" Latency: %s\n",
|
||||||
peerState.FQDN,
|
peerState.FQDN,
|
||||||
@ -680,7 +681,6 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
|
|||||||
toIEC(peerState.TransferSent),
|
toIEC(peerState.TransferSent),
|
||||||
rosenpassEnabledStatus,
|
rosenpassEnabledStatus,
|
||||||
networks,
|
networks,
|
||||||
networks,
|
|
||||||
peerState.Latency.String(),
|
peerState.Latency.String(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -829,14 +829,6 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
|
|||||||
for i, route := range peer.Networks {
|
for i, route := range peer.Networks {
|
||||||
peer.Networks[i] = a.AnonymizeRoute(route)
|
peer.Networks[i] = a.AnonymizeRoute(route)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, route := range peer.Routes {
|
|
||||||
peer.Routes[i] = a.AnonymizeIPString(route)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, route := range peer.Routes {
|
|
||||||
peer.Routes[i] = a.AnonymizeRoute(route)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview) {
|
func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview) {
|
||||||
@ -874,9 +866,14 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview)
|
|||||||
overview.Networks[i] = a.AnonymizeRoute(route)
|
overview.Networks[i] = a.AnonymizeRoute(route)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, route := range overview.Routes {
|
|
||||||
overview.Routes[i] = a.AnonymizeRoute(route)
|
|
||||||
}
|
|
||||||
|
|
||||||
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
|
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
|
||||||
|
|
||||||
|
for i, event := range overview.Events {
|
||||||
|
overview.Events[i].Message = a.AnonymizeString(event.Message)
|
||||||
|
overview.Events[i].UserMessage = a.AnonymizeString(event.UserMessage)
|
||||||
|
|
||||||
|
for k, v := range event.Metadata {
|
||||||
|
event.Metadata[k] = a.AnonymizeString(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
69
client/cmd/status_event.go
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
type systemEventOutput struct {
|
||||||
|
ID string `json:"id" yaml:"id"`
|
||||||
|
Severity string `json:"severity" yaml:"severity"`
|
||||||
|
Category string `json:"category" yaml:"category"`
|
||||||
|
Message string `json:"message" yaml:"message"`
|
||||||
|
UserMessage string `json:"userMessage" yaml:"userMessage"`
|
||||||
|
Timestamp time.Time `json:"timestamp" yaml:"timestamp"`
|
||||||
|
Metadata map[string]string `json:"metadata" yaml:"metadata"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapEvents(protoEvents []*proto.SystemEvent) []systemEventOutput {
|
||||||
|
events := make([]systemEventOutput, len(protoEvents))
|
||||||
|
for i, event := range protoEvents {
|
||||||
|
events[i] = systemEventOutput{
|
||||||
|
ID: event.GetId(),
|
||||||
|
Severity: event.GetSeverity().String(),
|
||||||
|
Category: event.GetCategory().String(),
|
||||||
|
Message: event.GetMessage(),
|
||||||
|
UserMessage: event.GetUserMessage(),
|
||||||
|
Timestamp: event.GetTimestamp().AsTime(),
|
||||||
|
Metadata: event.GetMetadata(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return events
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseEvents(events []systemEventOutput) string {
|
||||||
|
if len(events) == 0 {
|
||||||
|
return " No events recorded"
|
||||||
|
}
|
||||||
|
|
||||||
|
var eventsString strings.Builder
|
||||||
|
for _, event := range events {
|
||||||
|
timeStr := timeAgo(event.Timestamp)
|
||||||
|
|
||||||
|
metadataStr := ""
|
||||||
|
if len(event.Metadata) > 0 {
|
||||||
|
pairs := make([]string, 0, len(event.Metadata))
|
||||||
|
for k, v := range event.Metadata {
|
||||||
|
pairs = append(pairs, fmt.Sprintf("%s: %s", k, v))
|
||||||
|
}
|
||||||
|
sort.Strings(pairs)
|
||||||
|
metadataStr = fmt.Sprintf("\n Metadata: %s", strings.Join(pairs, ", "))
|
||||||
|
}
|
||||||
|
|
||||||
|
eventsString.WriteString(fmt.Sprintf("\n [%s] %s (%s)"+
|
||||||
|
"\n Message: %s"+
|
||||||
|
"\n Time: %s%s",
|
||||||
|
event.Severity,
|
||||||
|
event.Category,
|
||||||
|
event.ID,
|
||||||
|
event.Message,
|
||||||
|
timeStr,
|
||||||
|
metadataStr,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
return eventsString.String()
|
||||||
|
}
|
@ -146,9 +146,6 @@ var overview = statusOutputOverview{
|
|||||||
LastWireguardHandshake: time.Date(2001, 1, 1, 1, 1, 2, 0, time.UTC),
|
LastWireguardHandshake: time.Date(2001, 1, 1, 1, 1, 2, 0, time.UTC),
|
||||||
TransferReceived: 200,
|
TransferReceived: 200,
|
||||||
TransferSent: 100,
|
TransferSent: 100,
|
||||||
Routes: []string{
|
|
||||||
"10.1.0.0/24",
|
|
||||||
},
|
|
||||||
Networks: []string{
|
Networks: []string{
|
||||||
"10.1.0.0/24",
|
"10.1.0.0/24",
|
||||||
},
|
},
|
||||||
@ -176,6 +173,7 @@ var overview = statusOutputOverview{
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
Events: []systemEventOutput{},
|
||||||
CliVersion: version.NetbirdVersion(),
|
CliVersion: version.NetbirdVersion(),
|
||||||
DaemonVersion: "0.14.1",
|
DaemonVersion: "0.14.1",
|
||||||
ManagementState: managementStateOutput{
|
ManagementState: managementStateOutput{
|
||||||
@ -230,9 +228,6 @@ var overview = statusOutputOverview{
|
|||||||
Error: "timeout",
|
Error: "timeout",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Routes: []string{
|
|
||||||
"10.10.0.0/24",
|
|
||||||
},
|
|
||||||
Networks: []string{
|
Networks: []string{
|
||||||
"10.10.0.0/24",
|
"10.10.0.0/24",
|
||||||
},
|
},
|
||||||
@ -299,9 +294,6 @@ func TestParsingToJSON(t *testing.T) {
|
|||||||
"transferSent": 100,
|
"transferSent": 100,
|
||||||
"latency": 10000000,
|
"latency": 10000000,
|
||||||
"quantumResistance": false,
|
"quantumResistance": false,
|
||||||
"routes": [
|
|
||||||
"10.1.0.0/24"
|
|
||||||
],
|
|
||||||
"networks": [
|
"networks": [
|
||||||
"10.1.0.0/24"
|
"10.1.0.0/24"
|
||||||
]
|
]
|
||||||
@ -327,7 +319,6 @@ func TestParsingToJSON(t *testing.T) {
|
|||||||
"transferSent": 1000,
|
"transferSent": 1000,
|
||||||
"latency": 10000000,
|
"latency": 10000000,
|
||||||
"quantumResistance": false,
|
"quantumResistance": false,
|
||||||
"routes": null,
|
|
||||||
"networks": null
|
"networks": null
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
@ -366,9 +357,6 @@ func TestParsingToJSON(t *testing.T) {
|
|||||||
"fqdn": "some-localhost.awesome-domain.com",
|
"fqdn": "some-localhost.awesome-domain.com",
|
||||||
"quantumResistance": false,
|
"quantumResistance": false,
|
||||||
"quantumResistancePermissive": false,
|
"quantumResistancePermissive": false,
|
||||||
"routes": [
|
|
||||||
"10.10.0.0/24"
|
|
||||||
],
|
|
||||||
"networks": [
|
"networks": [
|
||||||
"10.10.0.0/24"
|
"10.10.0.0/24"
|
||||||
],
|
],
|
||||||
@ -394,7 +382,8 @@ func TestParsingToJSON(t *testing.T) {
|
|||||||
"enabled": false,
|
"enabled": false,
|
||||||
"error": "timeout"
|
"error": "timeout"
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
|
"events": []
|
||||||
}`
|
}`
|
||||||
// @formatter:on
|
// @formatter:on
|
||||||
|
|
||||||
@ -430,8 +419,6 @@ func TestParsingToYAML(t *testing.T) {
|
|||||||
transferSent: 100
|
transferSent: 100
|
||||||
latency: 10ms
|
latency: 10ms
|
||||||
quantumResistance: false
|
quantumResistance: false
|
||||||
routes:
|
|
||||||
- 10.1.0.0/24
|
|
||||||
networks:
|
networks:
|
||||||
- 10.1.0.0/24
|
- 10.1.0.0/24
|
||||||
- fqdn: peer-2.awesome-domain.com
|
- fqdn: peer-2.awesome-domain.com
|
||||||
@ -452,7 +439,6 @@ func TestParsingToYAML(t *testing.T) {
|
|||||||
transferSent: 1000
|
transferSent: 1000
|
||||||
latency: 10ms
|
latency: 10ms
|
||||||
quantumResistance: false
|
quantumResistance: false
|
||||||
routes: []
|
|
||||||
networks: []
|
networks: []
|
||||||
cliVersion: development
|
cliVersion: development
|
||||||
daemonVersion: 0.14.1
|
daemonVersion: 0.14.1
|
||||||
@ -480,8 +466,6 @@ usesKernelInterface: true
|
|||||||
fqdn: some-localhost.awesome-domain.com
|
fqdn: some-localhost.awesome-domain.com
|
||||||
quantumResistance: false
|
quantumResistance: false
|
||||||
quantumResistancePermissive: false
|
quantumResistancePermissive: false
|
||||||
routes:
|
|
||||||
- 10.10.0.0/24
|
|
||||||
networks:
|
networks:
|
||||||
- 10.10.0.0/24
|
- 10.10.0.0/24
|
||||||
forwardingRules: 0
|
forwardingRules: 0
|
||||||
@ -499,6 +483,7 @@ dnsServers:
|
|||||||
- example.net
|
- example.net
|
||||||
enabled: false
|
enabled: false
|
||||||
error: timeout
|
error: timeout
|
||||||
|
events: []
|
||||||
`
|
`
|
||||||
|
|
||||||
assert.Equal(t, expectedYAML, yaml)
|
assert.Equal(t, expectedYAML, yaml)
|
||||||
@ -528,7 +513,6 @@ func TestParsingToDetail(t *testing.T) {
|
|||||||
Last WireGuard handshake: %s
|
Last WireGuard handshake: %s
|
||||||
Transfer status (received/sent) 200 B/100 B
|
Transfer status (received/sent) 200 B/100 B
|
||||||
Quantum resistance: false
|
Quantum resistance: false
|
||||||
Routes: 10.1.0.0/24
|
|
||||||
Networks: 10.1.0.0/24
|
Networks: 10.1.0.0/24
|
||||||
Latency: 10ms
|
Latency: 10ms
|
||||||
|
|
||||||
@ -545,10 +529,10 @@ func TestParsingToDetail(t *testing.T) {
|
|||||||
Last WireGuard handshake: %s
|
Last WireGuard handshake: %s
|
||||||
Transfer status (received/sent) 2.0 KiB/1000 B
|
Transfer status (received/sent) 2.0 KiB/1000 B
|
||||||
Quantum resistance: false
|
Quantum resistance: false
|
||||||
Routes: -
|
|
||||||
Networks: -
|
Networks: -
|
||||||
Latency: 10ms
|
Latency: 10ms
|
||||||
|
|
||||||
|
Events: No events recorded
|
||||||
OS: %s/%s
|
OS: %s/%s
|
||||||
Daemon version: 0.14.1
|
Daemon version: 0.14.1
|
||||||
CLI version: %s
|
CLI version: %s
|
||||||
@ -564,7 +548,6 @@ FQDN: some-localhost.awesome-domain.com
|
|||||||
NetBird IP: 192.168.178.100/16
|
NetBird IP: 192.168.178.100/16
|
||||||
Interface type: Kernel
|
Interface type: Kernel
|
||||||
Quantum resistance: false
|
Quantum resistance: false
|
||||||
Routes: 10.10.0.0/24
|
|
||||||
Networks: 10.10.0.0/24
|
Networks: 10.10.0.0/24
|
||||||
Forwarding rules: 0
|
Forwarding rules: 0
|
||||||
Peers count: 2/2 Connected
|
Peers count: 2/2 Connected
|
||||||
@ -587,7 +570,6 @@ FQDN: some-localhost.awesome-domain.com
|
|||||||
NetBird IP: 192.168.178.100/16
|
NetBird IP: 192.168.178.100/16
|
||||||
Interface type: Kernel
|
Interface type: Kernel
|
||||||
Quantum resistance: false
|
Quantum resistance: false
|
||||||
Routes: 10.10.0.0/24
|
|
||||||
Networks: 10.10.0.0/24
|
Networks: 10.10.0.0/24
|
||||||
Forwarding rules: 0
|
Forwarding rules: 0
|
||||||
Peers count: 2/2 Connected
|
Peers count: 2/2 Connected
|
||||||
|
137
client/cmd/trace.go
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
var traceCmd = &cobra.Command{
|
||||||
|
Use: "trace <direction> <source-ip> <dest-ip>",
|
||||||
|
Short: "Trace a packet through the firewall",
|
||||||
|
Example: `
|
||||||
|
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
|
||||||
|
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
|
||||||
|
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0
|
||||||
|
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
|
||||||
|
Args: cobra.ExactArgs(3),
|
||||||
|
RunE: tracePacket,
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
debugCmd.AddCommand(traceCmd)
|
||||||
|
|
||||||
|
traceCmd.Flags().StringP("protocol", "p", "tcp", "Protocol (tcp/udp/icmp)")
|
||||||
|
traceCmd.Flags().Uint16("sport", 0, "Source port")
|
||||||
|
traceCmd.Flags().Uint16("dport", 0, "Destination port")
|
||||||
|
traceCmd.Flags().Uint8("icmp-type", 0, "ICMP type")
|
||||||
|
traceCmd.Flags().Uint8("icmp-code", 0, "ICMP code")
|
||||||
|
traceCmd.Flags().Bool("syn", false, "TCP SYN flag")
|
||||||
|
traceCmd.Flags().Bool("ack", false, "TCP ACK flag")
|
||||||
|
traceCmd.Flags().Bool("fin", false, "TCP FIN flag")
|
||||||
|
traceCmd.Flags().Bool("rst", false, "TCP RST flag")
|
||||||
|
traceCmd.Flags().Bool("psh", false, "TCP PSH flag")
|
||||||
|
traceCmd.Flags().Bool("urg", false, "TCP URG flag")
|
||||||
|
}
|
||||||
|
|
||||||
|
func tracePacket(cmd *cobra.Command, args []string) error {
|
||||||
|
direction := strings.ToLower(args[0])
|
||||||
|
if direction != "in" && direction != "out" {
|
||||||
|
return fmt.Errorf("invalid direction: use 'in' or 'out'")
|
||||||
|
}
|
||||||
|
|
||||||
|
protocol := cmd.Flag("protocol").Value.String()
|
||||||
|
if protocol != "tcp" && protocol != "udp" && protocol != "icmp" {
|
||||||
|
return fmt.Errorf("invalid protocol: use tcp/udp/icmp")
|
||||||
|
}
|
||||||
|
|
||||||
|
sport, err := cmd.Flags().GetUint16("sport")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid source port: %v", err)
|
||||||
|
}
|
||||||
|
dport, err := cmd.Flags().GetUint16("dport")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid destination port: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// For TCP/UDP, generate random ephemeral port (49152-65535) if not specified
|
||||||
|
if protocol != "icmp" {
|
||||||
|
if sport == 0 {
|
||||||
|
sport = uint16(rand.Intn(16383) + 49152)
|
||||||
|
}
|
||||||
|
if dport == 0 {
|
||||||
|
dport = uint16(rand.Intn(16383) + 49152)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var tcpFlags *proto.TCPFlags
|
||||||
|
if protocol == "tcp" {
|
||||||
|
syn, _ := cmd.Flags().GetBool("syn")
|
||||||
|
ack, _ := cmd.Flags().GetBool("ack")
|
||||||
|
fin, _ := cmd.Flags().GetBool("fin")
|
||||||
|
rst, _ := cmd.Flags().GetBool("rst")
|
||||||
|
psh, _ := cmd.Flags().GetBool("psh")
|
||||||
|
urg, _ := cmd.Flags().GetBool("urg")
|
||||||
|
|
||||||
|
tcpFlags = &proto.TCPFlags{
|
||||||
|
Syn: syn,
|
||||||
|
Ack: ack,
|
||||||
|
Fin: fin,
|
||||||
|
Rst: rst,
|
||||||
|
Psh: psh,
|
||||||
|
Urg: urg,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
icmpType, _ := cmd.Flags().GetUint32("icmp-type")
|
||||||
|
icmpCode, _ := cmd.Flags().GetUint32("icmp-code")
|
||||||
|
|
||||||
|
conn, err := getClient(cmd)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
|
resp, err := client.TracePacket(cmd.Context(), &proto.TracePacketRequest{
|
||||||
|
SourceIp: args[1],
|
||||||
|
DestinationIp: args[2],
|
||||||
|
Protocol: protocol,
|
||||||
|
SourcePort: uint32(sport),
|
||||||
|
DestinationPort: uint32(dport),
|
||||||
|
Direction: direction,
|
||||||
|
TcpFlags: tcpFlags,
|
||||||
|
IcmpType: &icmpType,
|
||||||
|
IcmpCode: &icmpCode,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("trace failed: %v", status.Convert(err).Message())
|
||||||
|
}
|
||||||
|
|
||||||
|
printTrace(cmd, args[1], args[2], protocol, sport, dport, resp)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) {
|
||||||
|
cmd.Printf("Packet trace %s:%d -> %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
|
||||||
|
|
||||||
|
for _, stage := range resp.Stages {
|
||||||
|
if stage.ForwardingDetails != nil {
|
||||||
|
cmd.Printf("%s: %s [%s]\n", stage.Name, stage.Message, *stage.ForwardingDetails)
|
||||||
|
} else {
|
||||||
|
cmd.Printf("%s: %s\n", stage.Name, stage.Message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
disposition := map[bool]string{
|
||||||
|
true: "\033[32mALLOWED\033[0m", // Green
|
||||||
|
false: "\033[31mDENIED\033[0m", // Red
|
||||||
|
}[resp.FinalDisposition]
|
||||||
|
|
||||||
|
cmd.Printf("\nFinal disposition: %s\n", disposition)
|
||||||
|
}
|
@ -14,13 +14,13 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// NewFirewall creates a firewall manager instance
|
// NewFirewall creates a firewall manager instance
|
||||||
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) {
|
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
||||||
if !iface.IsUserspaceBind() {
|
if !iface.IsUserspaceBind() {
|
||||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||||
}
|
}
|
||||||
|
|
||||||
// use userspace packet filtering firewall
|
// use userspace packet filtering firewall
|
||||||
fm, err := uspfilter.Create(iface)
|
fm, err := uspfilter.Create(iface, disableServerRoutes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -33,12 +33,12 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
|||||||
// FWType is the type for the firewall type
|
// FWType is the type for the firewall type
|
||||||
type FWType int
|
type FWType int
|
||||||
|
|
||||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
|
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
||||||
// on the linux system we try to user nftables or iptables
|
// on the linux system we try to user nftables or iptables
|
||||||
// in any case, because we need to allow netbird interface traffic
|
// in any case, because we need to allow netbird interface traffic
|
||||||
// so we use AllowNetbird traffic from these firewall managers
|
// so we use AllowNetbird traffic from these firewall managers
|
||||||
// for the userspace packet filtering firewall
|
// for the userspace packet filtering firewall
|
||||||
fm, err := createNativeFirewall(iface, stateManager)
|
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes)
|
||||||
|
|
||||||
if !iface.IsUserspaceBind() {
|
if !iface.IsUserspaceBind() {
|
||||||
return fm, err
|
return fm, err
|
||||||
@ -47,10 +47,10 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewal
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||||
}
|
}
|
||||||
return createUserspaceFirewall(iface, fm)
|
return createUserspaceFirewall(iface, fm, disableServerRoutes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
|
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
|
||||||
fm, err := createFW(iface)
|
fm, err := createFW(iface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create firewall: %s", err)
|
return nil, fmt.Errorf("create firewall: %s", err)
|
||||||
@ -77,12 +77,12 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) {
|
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
||||||
var errUsp error
|
var errUsp error
|
||||||
if fm != nil {
|
if fm != nil {
|
||||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
|
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes)
|
||||||
} else {
|
} else {
|
||||||
fm, errUsp = uspfilter.Create(iface)
|
fm, errUsp = uspfilter.Create(iface, disableServerRoutes)
|
||||||
}
|
}
|
||||||
|
|
||||||
if errUsp != nil {
|
if errUsp != nil {
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
package firewall
|
package firewall
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -10,4 +12,6 @@ type IFaceMapper interface {
|
|||||||
Address() device.WGAddress
|
Address() device.WGAddress
|
||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
SetFilter(device.PacketFilter) error
|
SetFilter(device.PacketFilter) error
|
||||||
|
GetDevice() *device.FilteredDevice
|
||||||
|
GetWGDevice() *wgdevice.Device
|
||||||
}
|
}
|
||||||
|
@ -213,6 +213,19 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
// Flush doesn't need to be implemented for this manager
|
// Flush doesn't need to be implemented for this manager
|
||||||
func (m *Manager) Flush() error { return nil }
|
func (m *Manager) Flush() error { return nil }
|
||||||
|
|
||||||
|
// SetLogLevel sets the log level for the firewall manager
|
||||||
|
func (m *Manager) SetLogLevel(log.Level) {
|
||||||
|
// not supported
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) EnableRouting() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) DisableRouting() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// AddDNATRule adds a DNAT rule
|
// AddDNATRule adds a DNAT rule
|
||||||
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
|
@ -152,7 +152,16 @@ func (r *router) AddRouteFiltering(
|
|||||||
}
|
}
|
||||||
|
|
||||||
rule := genRouteFilteringRuleSpec(params)
|
rule := genRouteFilteringRuleSpec(params)
|
||||||
if err := r.iptablesClient.Append(tableFilter, chainRTFWDIN, rule...); err != nil {
|
// Insert DROP rules at the beginning, append ACCEPT rules at the end
|
||||||
|
var err error
|
||||||
|
if action == firewall.ActionDrop {
|
||||||
|
// after the established rule
|
||||||
|
err = r.iptablesClient.Insert(tableFilter, chainRTFWDIN, 2, rule...)
|
||||||
|
} else {
|
||||||
|
err = r.iptablesClient.Append(tableFilter, chainRTFWDIN, rule...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
return nil, fmt.Errorf("add route rule: %v", err)
|
return nil, fmt.Errorf("add route rule: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -100,6 +100,12 @@ type Manager interface {
|
|||||||
// Flush the changes to firewall controller
|
// Flush the changes to firewall controller
|
||||||
Flush() error
|
Flush() error
|
||||||
|
|
||||||
|
SetLogLevel(log.Level)
|
||||||
|
|
||||||
|
EnableRouting() error
|
||||||
|
|
||||||
|
DisableRouting() error
|
||||||
|
|
||||||
// AddDNATRule adds a DNAT rule
|
// AddDNATRule adds a DNAT rule
|
||||||
AddDNATRule(ForwardRule) (Rule, error)
|
AddDNATRule(ForwardRule) (Rule, error)
|
||||||
|
|
||||||
|
@ -318,6 +318,19 @@ func (m *Manager) cleanupNetbirdTables() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetLogLevel sets the log level for the firewall manager
|
||||||
|
func (m *Manager) SetLogLevel(log.Level) {
|
||||||
|
// not supported
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) EnableRouting() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) DisableRouting() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Flush rule/chain/set operations from the buffer
|
// Flush rule/chain/set operations from the buffer
|
||||||
//
|
//
|
||||||
// Method also get all rules after flush and refreshes handle values in the rulesets
|
// Method also get all rules after flush and refreshes handle values in the rulesets
|
||||||
|
@ -107,7 +107,7 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
Kind: expr.VerdictAccept,
|
Kind: expr.VerdictAccept,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
require.ElementsMatch(t, rules[0].Exprs, expectedExprs1, "expected the same expressions")
|
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
|
||||||
|
|
||||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||||
add := ipToAdd.Unmap()
|
add := ipToAdd.Unmap()
|
||||||
@ -307,3 +307,18 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
|||||||
stdout, stderr = runIptablesSave(t)
|
stdout, stderr = runIptablesSave(t)
|
||||||
verifyIptablesOutput(t, stdout, stderr)
|
verifyIptablesOutput(t, stdout, stderr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func compareExprsIgnoringCounters(t *testing.T, got, want []expr.Any) {
|
||||||
|
t.Helper()
|
||||||
|
require.Equal(t, len(got), len(want), "expression count mismatch")
|
||||||
|
|
||||||
|
for i := range got {
|
||||||
|
if _, isCounter := got[i].(*expr.Counter); isCounter {
|
||||||
|
_, wantIsCounter := want[i].(*expr.Counter)
|
||||||
|
require.True(t, wantIsCounter, "expected Counter at index %d", i)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, got[i], want[i], "expression mismatch at index %d", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -296,7 +296,13 @@ func (r *router) AddRouteFiltering(
|
|||||||
UserData: []byte(ruleKey),
|
UserData: []byte(ruleKey),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Insert DROP rules at the beginning, append ACCEPT rules at the end
|
||||||
|
if action == firewall.ActionDrop {
|
||||||
|
// TODO: Insert after the established rule
|
||||||
|
rule = r.conn.InsertRule(rule)
|
||||||
|
} else {
|
||||||
rule = r.conn.AddRule(rule)
|
rule = r.conn.AddRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
log.Tracef("Adding route rule %s", spew.Sdump(rule))
|
log.Tracef("Adding route rule %s", spew.Sdump(rule))
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
@ -3,6 +3,11 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
@ -17,17 +22,29 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
|||||||
|
|
||||||
if m.udpTracker != nil {
|
if m.udpTracker != nil {
|
||||||
m.udpTracker.Close()
|
m.udpTracker.Close()
|
||||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
|
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.icmpTracker != nil {
|
if m.icmpTracker != nil {
|
||||||
m.icmpTracker.Close()
|
m.icmpTracker.Close()
|
||||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
|
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.tcpTracker != nil {
|
if m.tcpTracker != nil {
|
||||||
m.tcpTracker.Close()
|
m.tcpTracker.Close()
|
||||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.forwarder != nil {
|
||||||
|
m.forwarder.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.logger != nil {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := m.logger.Stop(ctx); err != nil {
|
||||||
|
log.Errorf("failed to shutdown logger: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.nativeFirewall != nil {
|
if m.nativeFirewall != nil {
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
@ -29,17 +31,29 @@ func (m *Manager) Reset(*statemanager.Manager) error {
|
|||||||
|
|
||||||
if m.udpTracker != nil {
|
if m.udpTracker != nil {
|
||||||
m.udpTracker.Close()
|
m.udpTracker.Close()
|
||||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
|
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.icmpTracker != nil {
|
if m.icmpTracker != nil {
|
||||||
m.icmpTracker.Close()
|
m.icmpTracker.Close()
|
||||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
|
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.tcpTracker != nil {
|
if m.tcpTracker != nil {
|
||||||
m.tcpTracker.Close()
|
m.tcpTracker.Close()
|
||||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.forwarder != nil {
|
||||||
|
m.forwarder.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.logger != nil {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := m.logger.Stop(ctx); err != nil {
|
||||||
|
log.Errorf("failed to shutdown logger: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !isWindowsFirewallReachable() {
|
if !isWindowsFirewallReachable() {
|
||||||
|
16
client/firewall/uspfilter/common/iface.go
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IFaceMapper defines subset methods of interface required for manager
|
||||||
|
type IFaceMapper interface {
|
||||||
|
SetFilter(device.PacketFilter) error
|
||||||
|
Address() iface.WGAddress
|
||||||
|
GetWGDevice() *wgdevice.Device
|
||||||
|
GetDevice() *device.FilteredDevice
|
||||||
|
}
|
@ -15,7 +15,6 @@ type BaseConnTrack struct {
|
|||||||
SourcePort uint16
|
SourcePort uint16
|
||||||
DestPort uint16
|
DestPort uint16
|
||||||
lastSeen atomic.Int64 // Unix nano for atomic access
|
lastSeen atomic.Int64 // Unix nano for atomic access
|
||||||
established atomic.Bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// these small methods will be inlined by the compiler
|
// these small methods will be inlined by the compiler
|
||||||
@ -25,16 +24,6 @@ func (b *BaseConnTrack) UpdateLastSeen() {
|
|||||||
b.lastSeen.Store(time.Now().UnixNano())
|
b.lastSeen.Store(time.Now().UnixNano())
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsEstablished safely checks if connection is established
|
|
||||||
func (b *BaseConnTrack) IsEstablished() bool {
|
|
||||||
return b.established.Load()
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetEstablished safely sets the established state
|
|
||||||
func (b *BaseConnTrack) SetEstablished(state bool) {
|
|
||||||
b.established.Store(state)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetLastSeen safely gets the last seen timestamp
|
// GetLastSeen safely gets the last seen timestamp
|
||||||
func (b *BaseConnTrack) GetLastSeen() time.Time {
|
func (b *BaseConnTrack) GetLastSeen() time.Time {
|
||||||
return time.Unix(0, b.lastSeen.Load())
|
return time.Unix(0, b.lastSeen.Load())
|
||||||
|
@ -3,8 +3,14 @@ package conntrack
|
|||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||||
|
|
||||||
func BenchmarkIPOperations(b *testing.B) {
|
func BenchmarkIPOperations(b *testing.B) {
|
||||||
b.Run("MakeIPAddr", func(b *testing.B) {
|
b.Run("MakeIPAddr", func(b *testing.B) {
|
||||||
ip := net.ParseIP("192.168.1.1")
|
ip := net.ParseIP("192.168.1.1")
|
||||||
@ -34,37 +40,11 @@ func BenchmarkIPOperations(b *testing.B) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
}
|
}
|
||||||
func BenchmarkAtomicOperations(b *testing.B) {
|
|
||||||
conn := &BaseConnTrack{}
|
|
||||||
b.Run("UpdateLastSeen", func(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
conn.UpdateLastSeen()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("IsEstablished", func(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
_ = conn.IsEstablished()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("SetEstablished", func(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
conn.SetEstablished(i%2 == 0)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("GetLastSeen", func(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
_ = conn.GetLastSeen()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Memory pressure tests
|
// Memory pressure tests
|
||||||
func BenchmarkMemoryPressure(b *testing.B) {
|
func BenchmarkMemoryPressure(b *testing.B) {
|
||||||
b.Run("TCPHighLoad", func(b *testing.B) {
|
b.Run("TCPHighLoad", func(b *testing.B) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
// Generate different IPs
|
// Generate different IPs
|
||||||
@ -89,7 +69,7 @@ func BenchmarkMemoryPressure(b *testing.B) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
b.Run("UDPHighLoad", func(b *testing.B) {
|
b.Run("UDPHighLoad", func(b *testing.B) {
|
||||||
tracker := NewUDPTracker(DefaultUDPTimeout)
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
// Generate different IPs
|
// Generate different IPs
|
||||||
|
@ -6,6 +6,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
|
|
||||||
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -33,6 +35,7 @@ type ICMPConnTrack struct {
|
|||||||
|
|
||||||
// ICMPTracker manages ICMP connection states
|
// ICMPTracker manages ICMP connection states
|
||||||
type ICMPTracker struct {
|
type ICMPTracker struct {
|
||||||
|
logger *nblog.Logger
|
||||||
connections map[ICMPConnKey]*ICMPConnTrack
|
connections map[ICMPConnKey]*ICMPConnTrack
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
cleanupTicker *time.Ticker
|
cleanupTicker *time.Ticker
|
||||||
@ -42,12 +45,13 @@ type ICMPTracker struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewICMPTracker creates a new ICMP connection tracker
|
// NewICMPTracker creates a new ICMP connection tracker
|
||||||
func NewICMPTracker(timeout time.Duration) *ICMPTracker {
|
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker {
|
||||||
if timeout == 0 {
|
if timeout == 0 {
|
||||||
timeout = DefaultICMPTimeout
|
timeout = DefaultICMPTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
tracker := &ICMPTracker{
|
tracker := &ICMPTracker{
|
||||||
|
logger: logger,
|
||||||
connections: make(map[ICMPConnKey]*ICMPConnTrack),
|
connections: make(map[ICMPConnKey]*ICMPConnTrack),
|
||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
|
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
|
||||||
@ -62,7 +66,6 @@ func NewICMPTracker(timeout time.Duration) *ICMPTracker {
|
|||||||
// TrackOutbound records an outbound ICMP Echo Request
|
// TrackOutbound records an outbound ICMP Echo Request
|
||||||
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
|
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
|
||||||
key := makeICMPKey(srcIP, dstIP, id, seq)
|
key := makeICMPKey(srcIP, dstIP, id, seq)
|
||||||
now := time.Now().UnixNano()
|
|
||||||
|
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
@ -80,24 +83,19 @@ func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq u
|
|||||||
ID: id,
|
ID: id,
|
||||||
Sequence: seq,
|
Sequence: seq,
|
||||||
}
|
}
|
||||||
conn.lastSeen.Store(now)
|
conn.UpdateLastSeen()
|
||||||
conn.established.Store(true)
|
|
||||||
t.connections[key] = conn
|
t.connections[key] = conn
|
||||||
|
|
||||||
|
t.logger.Trace("New ICMP connection %v", key)
|
||||||
}
|
}
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
|
|
||||||
conn.lastSeen.Store(now)
|
conn.UpdateLastSeen()
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
||||||
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
|
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
|
||||||
switch icmpType {
|
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
|
||||||
case uint8(layers.ICMPv4TypeDestinationUnreachable),
|
|
||||||
uint8(layers.ICMPv4TypeTimeExceeded):
|
|
||||||
return true
|
|
||||||
case uint8(layers.ICMPv4TypeEchoReply):
|
|
||||||
// continue processing
|
|
||||||
default:
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -115,8 +113,7 @@ func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return conn.IsEstablished() &&
|
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
||||||
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
|
||||||
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
||||||
conn.ID == id &&
|
conn.ID == id &&
|
||||||
conn.Sequence == seq
|
conn.Sequence == seq
|
||||||
@ -141,6 +138,8 @@ func (t *ICMPTracker) cleanup() {
|
|||||||
t.ipPool.Put(conn.SourceIP)
|
t.ipPool.Put(conn.SourceIP)
|
||||||
t.ipPool.Put(conn.DestIP)
|
t.ipPool.Put(conn.DestIP)
|
||||||
delete(t.connections, key)
|
delete(t.connections, key)
|
||||||
|
|
||||||
|
t.logger.Debug("Removed ICMP connection %v (timeout)", key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -7,7 +7,7 @@ import (
|
|||||||
|
|
||||||
func BenchmarkICMPTracker(b *testing.B) {
|
func BenchmarkICMPTracker(b *testing.B) {
|
||||||
b.Run("TrackOutbound", func(b *testing.B) {
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
tracker := NewICMPTracker(DefaultICMPTimeout)
|
tracker := NewICMPTracker(DefaultICMPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
@ -20,7 +20,7 @@ func BenchmarkICMPTracker(b *testing.B) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
b.Run("IsValidInbound", func(b *testing.B) {
|
b.Run("IsValidInbound", func(b *testing.B) {
|
||||||
tracker := NewICMPTracker(DefaultICMPTimeout)
|
tracker := NewICMPTracker(DefaultICMPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
|
@ -5,7 +5,10 @@ package conntrack
|
|||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -62,11 +65,23 @@ type TCPConnKey struct {
|
|||||||
type TCPConnTrack struct {
|
type TCPConnTrack struct {
|
||||||
BaseConnTrack
|
BaseConnTrack
|
||||||
State TCPState
|
State TCPState
|
||||||
|
established atomic.Bool
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsEstablished safely checks if connection is established
|
||||||
|
func (t *TCPConnTrack) IsEstablished() bool {
|
||||||
|
return t.established.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetEstablished safely sets the established state
|
||||||
|
func (t *TCPConnTrack) SetEstablished(state bool) {
|
||||||
|
t.established.Store(state)
|
||||||
|
}
|
||||||
|
|
||||||
// TCPTracker manages TCP connection states
|
// TCPTracker manages TCP connection states
|
||||||
type TCPTracker struct {
|
type TCPTracker struct {
|
||||||
|
logger *nblog.Logger
|
||||||
connections map[ConnKey]*TCPConnTrack
|
connections map[ConnKey]*TCPConnTrack
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
cleanupTicker *time.Ticker
|
cleanupTicker *time.Ticker
|
||||||
@ -76,8 +91,9 @@ type TCPTracker struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewTCPTracker creates a new TCP connection tracker
|
// NewTCPTracker creates a new TCP connection tracker
|
||||||
func NewTCPTracker(timeout time.Duration) *TCPTracker {
|
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker {
|
||||||
tracker := &TCPTracker{
|
tracker := &TCPTracker{
|
||||||
|
logger: logger,
|
||||||
connections: make(map[ConnKey]*TCPConnTrack),
|
connections: make(map[ConnKey]*TCPConnTrack),
|
||||||
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
@ -93,7 +109,6 @@ func NewTCPTracker(timeout time.Duration) *TCPTracker {
|
|||||||
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
|
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
|
||||||
// Create key before lock
|
// Create key before lock
|
||||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||||
now := time.Now().UnixNano()
|
|
||||||
|
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
@ -113,9 +128,11 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
|
|||||||
},
|
},
|
||||||
State: TCPStateNew,
|
State: TCPStateNew,
|
||||||
}
|
}
|
||||||
conn.lastSeen.Store(now)
|
conn.UpdateLastSeen()
|
||||||
conn.established.Store(false)
|
conn.established.Store(false)
|
||||||
t.connections[key] = conn
|
t.connections[key] = conn
|
||||||
|
|
||||||
|
t.logger.Trace("New TCP connection: %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort)
|
||||||
}
|
}
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
|
|
||||||
@ -123,7 +140,7 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
|
|||||||
conn.Lock()
|
conn.Lock()
|
||||||
t.updateState(conn, flags, true)
|
t.updateState(conn, flags, true)
|
||||||
conn.Unlock()
|
conn.Unlock()
|
||||||
conn.lastSeen.Store(now)
|
conn.UpdateLastSeen()
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
||||||
@ -171,6 +188,9 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
|
|||||||
if flags&TCPRst != 0 {
|
if flags&TCPRst != 0 {
|
||||||
conn.State = TCPStateClosed
|
conn.State = TCPStateClosed
|
||||||
conn.SetEstablished(false)
|
conn.SetEstablished(false)
|
||||||
|
|
||||||
|
t.logger.Trace("TCP connection reset: %s:%d -> %s:%d",
|
||||||
|
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -227,6 +247,9 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
|
|||||||
if flags&TCPAck != 0 {
|
if flags&TCPAck != 0 {
|
||||||
conn.State = TCPStateTimeWait
|
conn.State = TCPStateTimeWait
|
||||||
// Keep established = false from previous state
|
// Keep established = false from previous state
|
||||||
|
|
||||||
|
t.logger.Trace("TCP connection closed (simultaneous) - %s:%d -> %s:%d",
|
||||||
|
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateCloseWait:
|
case TCPStateCloseWait:
|
||||||
@ -237,11 +260,17 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
|
|||||||
case TCPStateLastAck:
|
case TCPStateLastAck:
|
||||||
if flags&TCPAck != 0 {
|
if flags&TCPAck != 0 {
|
||||||
conn.State = TCPStateClosed
|
conn.State = TCPStateClosed
|
||||||
|
|
||||||
|
t.logger.Trace("TCP connection gracefully closed: %s:%d -> %s:%d",
|
||||||
|
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateTimeWait:
|
case TCPStateTimeWait:
|
||||||
// Stay in TIME-WAIT for 2MSL before transitioning to closed
|
// Stay in TIME-WAIT for 2MSL before transitioning to closed
|
||||||
// This is handled by the cleanup routine
|
// This is handled by the cleanup routine
|
||||||
|
|
||||||
|
t.logger.Trace("TCP connection completed - %s:%d -> %s:%d",
|
||||||
|
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -318,6 +347,8 @@ func (t *TCPTracker) cleanup() {
|
|||||||
t.ipPool.Put(conn.SourceIP)
|
t.ipPool.Put(conn.SourceIP)
|
||||||
t.ipPool.Put(conn.DestIP)
|
t.ipPool.Put(conn.DestIP)
|
||||||
delete(t.connections, key)
|
delete(t.connections, key)
|
||||||
|
|
||||||
|
t.logger.Trace("Cleaned up TCP connection: %s:%d -> %s:%d", conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -9,7 +9,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestTCPStateMachine(t *testing.T) {
|
func TestTCPStateMachine(t *testing.T) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("100.64.0.1")
|
srcIP := net.ParseIP("100.64.0.1")
|
||||||
@ -154,7 +154,7 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
tracker = NewTCPTracker(DefaultTCPTimeout)
|
tracker = NewTCPTracker(DefaultTCPTimeout, logger)
|
||||||
tt.test(t)
|
tt.test(t)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -162,7 +162,7 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRSTHandling(t *testing.T) {
|
func TestRSTHandling(t *testing.T) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("100.64.0.1")
|
srcIP := net.ParseIP("100.64.0.1")
|
||||||
@ -233,7 +233,7 @@ func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP,
|
|||||||
|
|
||||||
func BenchmarkTCPTracker(b *testing.B) {
|
func BenchmarkTCPTracker(b *testing.B) {
|
||||||
b.Run("TrackOutbound", func(b *testing.B) {
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
@ -246,7 +246,7 @@ func BenchmarkTCPTracker(b *testing.B) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
b.Run("IsValidInbound", func(b *testing.B) {
|
b.Run("IsValidInbound", func(b *testing.B) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
@ -264,7 +264,7 @@ func BenchmarkTCPTracker(b *testing.B) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
b.Run("ConcurrentAccess", func(b *testing.B) {
|
b.Run("ConcurrentAccess", func(b *testing.B) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
@ -287,7 +287,7 @@ func BenchmarkTCPTracker(b *testing.B) {
|
|||||||
// Benchmark connection cleanup
|
// Benchmark connection cleanup
|
||||||
func BenchmarkCleanup(b *testing.B) {
|
func BenchmarkCleanup(b *testing.B) {
|
||||||
b.Run("TCPCleanup", func(b *testing.B) {
|
b.Run("TCPCleanup", func(b *testing.B) {
|
||||||
tracker := NewTCPTracker(100 * time.Millisecond) // Short timeout for testing
|
tracker := NewTCPTracker(100*time.Millisecond, logger) // Short timeout for testing
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
// Pre-populate with expired connections
|
// Pre-populate with expired connections
|
||||||
|
@ -4,6 +4,8 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -20,6 +22,7 @@ type UDPConnTrack struct {
|
|||||||
|
|
||||||
// UDPTracker manages UDP connection states
|
// UDPTracker manages UDP connection states
|
||||||
type UDPTracker struct {
|
type UDPTracker struct {
|
||||||
|
logger *nblog.Logger
|
||||||
connections map[ConnKey]*UDPConnTrack
|
connections map[ConnKey]*UDPConnTrack
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
cleanupTicker *time.Ticker
|
cleanupTicker *time.Ticker
|
||||||
@ -29,12 +32,13 @@ type UDPTracker struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewUDPTracker creates a new UDP connection tracker
|
// NewUDPTracker creates a new UDP connection tracker
|
||||||
func NewUDPTracker(timeout time.Duration) *UDPTracker {
|
func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
|
||||||
if timeout == 0 {
|
if timeout == 0 {
|
||||||
timeout = DefaultUDPTimeout
|
timeout = DefaultUDPTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
tracker := &UDPTracker{
|
tracker := &UDPTracker{
|
||||||
|
logger: logger,
|
||||||
connections: make(map[ConnKey]*UDPConnTrack),
|
connections: make(map[ConnKey]*UDPConnTrack),
|
||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
cleanupTicker: time.NewTicker(UDPCleanupInterval),
|
cleanupTicker: time.NewTicker(UDPCleanupInterval),
|
||||||
@ -49,7 +53,6 @@ func NewUDPTracker(timeout time.Duration) *UDPTracker {
|
|||||||
// TrackOutbound records an outbound UDP connection
|
// TrackOutbound records an outbound UDP connection
|
||||||
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
|
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
|
||||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||||
now := time.Now().UnixNano()
|
|
||||||
|
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
@ -67,13 +70,14 @@ func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
|
|||||||
DestPort: dstPort,
|
DestPort: dstPort,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
conn.lastSeen.Store(now)
|
conn.UpdateLastSeen()
|
||||||
conn.established.Store(true)
|
|
||||||
t.connections[key] = conn
|
t.connections[key] = conn
|
||||||
|
|
||||||
|
t.logger.Trace("New UDP connection: %v", conn)
|
||||||
}
|
}
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
|
|
||||||
conn.lastSeen.Store(now)
|
conn.UpdateLastSeen()
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsValidInbound checks if an inbound packet matches a tracked connection
|
// IsValidInbound checks if an inbound packet matches a tracked connection
|
||||||
@ -92,8 +96,7 @@ func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return conn.IsEstablished() &&
|
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
||||||
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
|
||||||
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
||||||
conn.DestPort == srcPort &&
|
conn.DestPort == srcPort &&
|
||||||
conn.SourcePort == dstPort
|
conn.SourcePort == dstPort
|
||||||
@ -120,6 +123,8 @@ func (t *UDPTracker) cleanup() {
|
|||||||
t.ipPool.Put(conn.SourceIP)
|
t.ipPool.Put(conn.SourceIP)
|
||||||
t.ipPool.Put(conn.DestIP)
|
t.ipPool.Put(conn.DestIP)
|
||||||
delete(t.connections, key)
|
delete(t.connections, key)
|
||||||
|
|
||||||
|
t.logger.Trace("Removed UDP connection %v (timeout)", conn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -29,7 +29,7 @@ func TestNewUDPTracker(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
tracker := NewUDPTracker(tt.timeout)
|
tracker := NewUDPTracker(tt.timeout, logger)
|
||||||
assert.NotNil(t, tracker)
|
assert.NotNil(t, tracker)
|
||||||
assert.Equal(t, tt.wantTimeout, tracker.timeout)
|
assert.Equal(t, tt.wantTimeout, tracker.timeout)
|
||||||
assert.NotNil(t, tracker.connections)
|
assert.NotNil(t, tracker.connections)
|
||||||
@ -40,7 +40,7 @@ func TestNewUDPTracker(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUDPTracker_TrackOutbound(t *testing.T) {
|
func TestUDPTracker_TrackOutbound(t *testing.T) {
|
||||||
tracker := NewUDPTracker(DefaultUDPTimeout)
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.2")
|
srcIP := net.ParseIP("192.168.1.2")
|
||||||
@ -58,12 +58,11 @@ func TestUDPTracker_TrackOutbound(t *testing.T) {
|
|||||||
assert.True(t, conn.DestIP.Equal(dstIP))
|
assert.True(t, conn.DestIP.Equal(dstIP))
|
||||||
assert.Equal(t, srcPort, conn.SourcePort)
|
assert.Equal(t, srcPort, conn.SourcePort)
|
||||||
assert.Equal(t, dstPort, conn.DestPort)
|
assert.Equal(t, dstPort, conn.DestPort)
|
||||||
assert.True(t, conn.IsEstablished())
|
|
||||||
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
|
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUDPTracker_IsValidInbound(t *testing.T) {
|
func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||||
tracker := NewUDPTracker(1 * time.Second)
|
tracker := NewUDPTracker(1*time.Second, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.2")
|
srcIP := net.ParseIP("192.168.1.2")
|
||||||
@ -162,6 +161,7 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
|||||||
cleanupTicker: time.NewTicker(cleanupInterval),
|
cleanupTicker: time.NewTicker(cleanupInterval),
|
||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
ipPool: NewPreallocatedIPs(),
|
ipPool: NewPreallocatedIPs(),
|
||||||
|
logger: logger,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start cleanup routine
|
// Start cleanup routine
|
||||||
@ -211,7 +211,7 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
|||||||
|
|
||||||
func BenchmarkUDPTracker(b *testing.B) {
|
func BenchmarkUDPTracker(b *testing.B) {
|
||||||
b.Run("TrackOutbound", func(b *testing.B) {
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
tracker := NewUDPTracker(DefaultUDPTimeout)
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
@ -224,7 +224,7 @@ func BenchmarkUDPTracker(b *testing.B) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
b.Run("IsValidInbound", func(b *testing.B) {
|
b.Run("IsValidInbound", func(b *testing.B) {
|
||||||
tracker := NewUDPTracker(DefaultUDPTimeout)
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
|
81
client/firewall/uspfilter/forwarder/endpoint.go
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
package forwarder
|
||||||
|
|
||||||
|
import (
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
|
|
||||||
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// endpoint implements stack.LinkEndpoint and handles integration with the wireguard device
|
||||||
|
type endpoint struct {
|
||||||
|
logger *nblog.Logger
|
||||||
|
dispatcher stack.NetworkDispatcher
|
||||||
|
device *wgdevice.Device
|
||||||
|
mtu uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
|
||||||
|
e.dispatcher = dispatcher
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) IsAttached() bool {
|
||||||
|
return e.dispatcher != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) MTU() uint32 {
|
||||||
|
return e.mtu
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
|
||||||
|
return stack.CapabilityNone
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) MaxHeaderLength() uint16 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) LinkAddress() tcpip.LinkAddress {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
|
||||||
|
var written int
|
||||||
|
for _, pkt := range pkts.AsSlice() {
|
||||||
|
netHeader := header.IPv4(pkt.NetworkHeader().View().AsSlice())
|
||||||
|
|
||||||
|
data := stack.PayloadSince(pkt.NetworkHeader())
|
||||||
|
if data == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the packet through WireGuard
|
||||||
|
address := netHeader.DestinationAddress()
|
||||||
|
err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice())
|
||||||
|
if err != nil {
|
||||||
|
e.logger.Error("CreateOutboundPacket: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
written++
|
||||||
|
}
|
||||||
|
|
||||||
|
return written, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) Wait() {
|
||||||
|
// not required
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
|
||||||
|
return header.ARPHardwareNone
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) AddHeader(*stack.PacketBuffer) {
|
||||||
|
// not required
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
|
||||||
|
return true
|
||||||
|
}
|
166
client/firewall/uspfilter/forwarder/forwarder.go
Normal file
@ -0,0 +1,166 @@
|
|||||||
|
package forwarder
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"gvisor.dev/gvisor/pkg/buffer"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultReceiveWindow = 32768
|
||||||
|
defaultMaxInFlight = 1024
|
||||||
|
iosReceiveWindow = 16384
|
||||||
|
iosMaxInFlight = 256
|
||||||
|
)
|
||||||
|
|
||||||
|
type Forwarder struct {
|
||||||
|
logger *nblog.Logger
|
||||||
|
stack *stack.Stack
|
||||||
|
endpoint *endpoint
|
||||||
|
udpForwarder *udpForwarder
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
ip net.IP
|
||||||
|
netstack bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwarder, error) {
|
||||||
|
s := stack.New(stack.Options{
|
||||||
|
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||||
|
TransportProtocols: []stack.TransportProtocolFactory{
|
||||||
|
tcp.NewProtocol,
|
||||||
|
udp.NewProtocol,
|
||||||
|
icmp.NewProtocol4,
|
||||||
|
},
|
||||||
|
HandleLocal: false,
|
||||||
|
})
|
||||||
|
|
||||||
|
mtu, err := iface.GetDevice().MTU()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get MTU: %w", err)
|
||||||
|
}
|
||||||
|
nicID := tcpip.NICID(1)
|
||||||
|
endpoint := &endpoint{
|
||||||
|
logger: logger,
|
||||||
|
device: iface.GetWGDevice(),
|
||||||
|
mtu: uint32(mtu),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.CreateNIC(nicID, endpoint); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create NIC: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ones, _ := iface.Address().Network.Mask.Size()
|
||||||
|
protoAddr := tcpip.ProtocolAddress{
|
||||||
|
Protocol: ipv4.ProtocolNumber,
|
||||||
|
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||||
|
Address: tcpip.AddrFromSlice(iface.Address().IP.To4()),
|
||||||
|
PrefixLen: ones,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to add protocol address: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultSubnet, err := tcpip.NewSubnet(
|
||||||
|
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
|
||||||
|
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("creating default subnet: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.SetPromiscuousMode(nicID, true); err != nil {
|
||||||
|
return nil, fmt.Errorf("set promiscuous mode: %s", err)
|
||||||
|
}
|
||||||
|
if err := s.SetSpoofing(nicID, true); err != nil {
|
||||||
|
return nil, fmt.Errorf("set spoofing: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.SetRouteTable([]tcpip.Route{
|
||||||
|
{
|
||||||
|
Destination: defaultSubnet,
|
||||||
|
NIC: nicID,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
f := &Forwarder{
|
||||||
|
logger: logger,
|
||||||
|
stack: s,
|
||||||
|
endpoint: endpoint,
|
||||||
|
udpForwarder: newUDPForwarder(mtu, logger),
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
netstack: netstack,
|
||||||
|
ip: iface.Address().IP,
|
||||||
|
}
|
||||||
|
|
||||||
|
receiveWindow := defaultReceiveWindow
|
||||||
|
maxInFlight := defaultMaxInFlight
|
||||||
|
if runtime.GOOS == "ios" {
|
||||||
|
receiveWindow = iosReceiveWindow
|
||||||
|
maxInFlight = iosMaxInFlight
|
||||||
|
}
|
||||||
|
|
||||||
|
tcpForwarder := tcp.NewForwarder(s, receiveWindow, maxInFlight, f.handleTCP)
|
||||||
|
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
|
||||||
|
|
||||||
|
udpForwarder := udp.NewForwarder(s, f.handleUDP)
|
||||||
|
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
|
||||||
|
|
||||||
|
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP)
|
||||||
|
|
||||||
|
log.Debugf("forwarder: Initialization complete with NIC %d", nicID)
|
||||||
|
return f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
|
||||||
|
if len(payload) < header.IPv4MinimumSize {
|
||||||
|
return fmt.Errorf("packet too small: %d bytes", len(payload))
|
||||||
|
}
|
||||||
|
|
||||||
|
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||||
|
Payload: buffer.MakeWithData(payload),
|
||||||
|
})
|
||||||
|
defer pkt.DecRef()
|
||||||
|
|
||||||
|
if f.endpoint.dispatcher != nil {
|
||||||
|
f.endpoint.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop gracefully shuts down the forwarder
|
||||||
|
func (f *Forwarder) Stop() {
|
||||||
|
f.cancel()
|
||||||
|
|
||||||
|
if f.udpForwarder != nil {
|
||||||
|
f.udpForwarder.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
f.stack.Close()
|
||||||
|
f.stack.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
|
||||||
|
if f.netstack && f.ip.Equal(addr.AsSlice()) {
|
||||||
|
return net.IPv4(127, 0, 0, 1)
|
||||||
|
}
|
||||||
|
return addr.AsSlice()
|
||||||
|
}
|
109
client/firewall/uspfilter/forwarder/icmp.go
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
package forwarder
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
|
)
|
||||||
|
|
||||||
|
// handleICMP handles ICMP packets from the network stack
|
||||||
|
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
|
||||||
|
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
lc := net.ListenConfig{}
|
||||||
|
// TODO: support non-root
|
||||||
|
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||||
|
if err != nil {
|
||||||
|
f.logger.Error("Failed to create ICMP socket for %v: %v", id, err)
|
||||||
|
|
||||||
|
// This will make netstack reply on behalf of the original destination, that's ok for now
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
f.logger.Debug("Failed to close ICMP socket: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||||
|
dst := &net.IPAddr{IP: dstIP}
|
||||||
|
|
||||||
|
// Get the complete ICMP message (header + data)
|
||||||
|
fullPacket := stack.PayloadSince(pkt.TransportHeader())
|
||||||
|
payload := fullPacket.AsSlice()
|
||||||
|
|
||||||
|
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
|
||||||
|
|
||||||
|
// For Echo Requests, send and handle response
|
||||||
|
switch icmpHdr.Type() {
|
||||||
|
case header.ICMPv4Echo:
|
||||||
|
return f.handleEchoResponse(icmpHdr, payload, dst, conn, id)
|
||||||
|
case header.ICMPv4EchoReply:
|
||||||
|
// dont process our own replies
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
// For other ICMP types (Time Exceeded, Destination Unreachable, etc)
|
||||||
|
_, err = conn.WriteTo(payload, dst)
|
||||||
|
if err != nil {
|
||||||
|
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
|
||||||
|
id, icmpHdr.Type(), icmpHdr.Code())
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, dst *net.IPAddr, conn net.PacketConn, id stack.TransportEndpointID) bool {
|
||||||
|
if _, err := conn.WriteTo(payload, dst); err != nil {
|
||||||
|
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
|
||||||
|
id, icmpHdr.Type(), icmpHdr.Code())
|
||||||
|
|
||||||
|
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||||
|
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
response := make([]byte, f.endpoint.mtu)
|
||||||
|
n, _, err := conn.ReadFrom(response)
|
||||||
|
if err != nil {
|
||||||
|
if !isTimeout(err) {
|
||||||
|
f.logger.Error("Failed to read ICMP response: %v", err)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
ipHdr := make([]byte, header.IPv4MinimumSize)
|
||||||
|
ip := header.IPv4(ipHdr)
|
||||||
|
ip.Encode(&header.IPv4Fields{
|
||||||
|
TotalLength: uint16(header.IPv4MinimumSize + n),
|
||||||
|
TTL: 64,
|
||||||
|
Protocol: uint8(header.ICMPv4ProtocolNumber),
|
||||||
|
SrcAddr: id.LocalAddress,
|
||||||
|
DstAddr: id.RemoteAddress,
|
||||||
|
})
|
||||||
|
ip.SetChecksum(^ip.CalculateChecksum())
|
||||||
|
|
||||||
|
fullPacket := make([]byte, 0, len(ipHdr)+n)
|
||||||
|
fullPacket = append(fullPacket, ipHdr...)
|
||||||
|
fullPacket = append(fullPacket, response[:n]...)
|
||||||
|
|
||||||
|
if err := f.InjectIncomingPacket(fullPacket); err != nil {
|
||||||
|
f.logger.Error("Failed to inject ICMP response: %v", err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
f.logger.Trace("Forwarded ICMP echo reply for %v", id)
|
||||||
|
return true
|
||||||
|
}
|
90
client/firewall/uspfilter/forwarder/tcp.go
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
package forwarder
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||||
|
"gvisor.dev/gvisor/pkg/waiter"
|
||||||
|
)
|
||||||
|
|
||||||
|
// handleTCP is called by the TCP forwarder for new connections.
|
||||||
|
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||||
|
id := r.ID()
|
||||||
|
|
||||||
|
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||||
|
|
||||||
|
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
||||||
|
if err != nil {
|
||||||
|
r.Complete(true)
|
||||||
|
f.logger.Trace("forwarder: dial error for %v: %v", id, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create wait queue for blocking syscalls
|
||||||
|
wq := waiter.Queue{}
|
||||||
|
|
||||||
|
ep, epErr := r.CreateEndpoint(&wq)
|
||||||
|
if epErr != nil {
|
||||||
|
f.logger.Error("forwarder: failed to create TCP endpoint: %v", epErr)
|
||||||
|
if err := outConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: outConn close error: %v", err)
|
||||||
|
}
|
||||||
|
r.Complete(true)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Complete the handshake
|
||||||
|
r.Complete(false)
|
||||||
|
|
||||||
|
inConn := gonet.NewTCPConn(&wq, ep)
|
||||||
|
|
||||||
|
f.logger.Trace("forwarder: established TCP connection %v", id)
|
||||||
|
|
||||||
|
go f.proxyTCP(id, inConn, outConn, ep)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint) {
|
||||||
|
defer func() {
|
||||||
|
if err := inConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: inConn close error: %v", err)
|
||||||
|
}
|
||||||
|
if err := outConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: outConn close error: %v", err)
|
||||||
|
}
|
||||||
|
ep.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Create context for managing the proxy goroutines
|
||||||
|
ctx, cancel := context.WithCancel(f.ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
errChan := make(chan error, 2)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_, err := io.Copy(outConn, inConn)
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_, err := io.Copy(inConn, outConn)
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", id)
|
||||||
|
return
|
||||||
|
case err := <-errChan:
|
||||||
|
if err != nil && !isClosedError(err) {
|
||||||
|
f.logger.Error("proxyTCP: copy error: %v", err)
|
||||||
|
}
|
||||||
|
f.logger.Trace("forwarder: tearing down TCP connection %v", id)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
288
client/firewall/uspfilter/forwarder/udp.go
Normal file
@ -0,0 +1,288 @@
|
|||||||
|
package forwarder
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||||
|
"gvisor.dev/gvisor/pkg/waiter"
|
||||||
|
|
||||||
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
udpTimeout = 30 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
type udpPacketConn struct {
|
||||||
|
conn *gonet.UDPConn
|
||||||
|
outConn net.Conn
|
||||||
|
lastSeen atomic.Int64
|
||||||
|
cancel context.CancelFunc
|
||||||
|
ep tcpip.Endpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
type udpForwarder struct {
|
||||||
|
sync.RWMutex
|
||||||
|
logger *nblog.Logger
|
||||||
|
conns map[stack.TransportEndpointID]*udpPacketConn
|
||||||
|
bufPool sync.Pool
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
type idleConn struct {
|
||||||
|
id stack.TransportEndpointID
|
||||||
|
conn *udpPacketConn
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUDPForwarder(mtu int, logger *nblog.Logger) *udpForwarder {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
f := &udpForwarder{
|
||||||
|
logger: logger,
|
||||||
|
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
bufPool: sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
b := make([]byte, mtu)
|
||||||
|
return &b
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
go f.cleanup()
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the UDP forwarder and all active connections
|
||||||
|
func (f *udpForwarder) Stop() {
|
||||||
|
f.cancel()
|
||||||
|
|
||||||
|
f.Lock()
|
||||||
|
defer f.Unlock()
|
||||||
|
|
||||||
|
for id, conn := range f.conns {
|
||||||
|
conn.cancel()
|
||||||
|
if err := conn.conn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP conn close error for %v: %v", id, err)
|
||||||
|
}
|
||||||
|
if err := conn.outConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.ep.Close()
|
||||||
|
delete(f.conns, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanup periodically removes idle UDP connections
|
||||||
|
func (f *udpForwarder) cleanup() {
|
||||||
|
ticker := time.NewTicker(time.Minute)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-f.ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
var idleConns []idleConn
|
||||||
|
|
||||||
|
f.RLock()
|
||||||
|
for id, conn := range f.conns {
|
||||||
|
if conn.getIdleDuration() > udpTimeout {
|
||||||
|
idleConns = append(idleConns, idleConn{id, conn})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
f.RUnlock()
|
||||||
|
|
||||||
|
for _, idle := range idleConns {
|
||||||
|
idle.conn.cancel()
|
||||||
|
if err := idle.conn.conn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP conn close error for %v: %v", idle.id, err)
|
||||||
|
}
|
||||||
|
if err := idle.conn.outConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", idle.id, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
idle.conn.ep.Close()
|
||||||
|
|
||||||
|
f.Lock()
|
||||||
|
delete(f.conns, idle.id)
|
||||||
|
f.Unlock()
|
||||||
|
|
||||||
|
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", idle.id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleUDP is called by the UDP forwarder for new packets
|
||||||
|
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||||
|
if f.ctx.Err() != nil {
|
||||||
|
f.logger.Trace("forwarder: context done, dropping UDP packet")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
id := r.ID()
|
||||||
|
|
||||||
|
f.udpForwarder.RLock()
|
||||||
|
_, exists := f.udpForwarder.conns[id]
|
||||||
|
f.udpForwarder.RUnlock()
|
||||||
|
if exists {
|
||||||
|
f.logger.Trace("forwarder: existing UDP connection for %v", id)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||||
|
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
||||||
|
if err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)
|
||||||
|
// TODO: Send ICMP error message
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create wait queue for blocking syscalls
|
||||||
|
wq := waiter.Queue{}
|
||||||
|
ep, epErr := r.CreateEndpoint(&wq)
|
||||||
|
if epErr != nil {
|
||||||
|
f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr)
|
||||||
|
if err := outConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
inConn := gonet.NewUDPConn(f.stack, &wq, ep)
|
||||||
|
connCtx, connCancel := context.WithCancel(f.ctx)
|
||||||
|
|
||||||
|
pConn := &udpPacketConn{
|
||||||
|
conn: inConn,
|
||||||
|
outConn: outConn,
|
||||||
|
cancel: connCancel,
|
||||||
|
ep: ep,
|
||||||
|
}
|
||||||
|
pConn.updateLastSeen()
|
||||||
|
|
||||||
|
f.udpForwarder.Lock()
|
||||||
|
// Double-check no connection was created while we were setting up
|
||||||
|
if _, exists := f.udpForwarder.conns[id]; exists {
|
||||||
|
f.udpForwarder.Unlock()
|
||||||
|
pConn.cancel()
|
||||||
|
if err := inConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
|
||||||
|
}
|
||||||
|
if err := outConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f.udpForwarder.conns[id] = pConn
|
||||||
|
f.udpForwarder.Unlock()
|
||||||
|
|
||||||
|
f.logger.Trace("forwarder: established UDP connection to %v", id)
|
||||||
|
go f.proxyUDP(connCtx, pConn, id, ep)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||||
|
defer func() {
|
||||||
|
pConn.cancel()
|
||||||
|
if err := pConn.conn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
|
||||||
|
}
|
||||||
|
if err := pConn.outConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ep.Close()
|
||||||
|
|
||||||
|
f.udpForwarder.Lock()
|
||||||
|
delete(f.udpForwarder.conns, id)
|
||||||
|
f.udpForwarder.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
errChan := make(chan error, 2)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
errChan <- pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
errChan <- pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", id)
|
||||||
|
return
|
||||||
|
case err := <-errChan:
|
||||||
|
if err != nil && !isClosedError(err) {
|
||||||
|
f.logger.Error("proxyUDP: copy error: %v", err)
|
||||||
|
}
|
||||||
|
f.logger.Trace("forwarder: tearing down UDP connection %v", id)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpPacketConn) updateLastSeen() {
|
||||||
|
c.lastSeen.Store(time.Now().UnixNano())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpPacketConn) getIdleDuration() time.Duration {
|
||||||
|
lastSeen := time.Unix(0, c.lastSeen.Load())
|
||||||
|
return time.Since(lastSeen)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) error {
|
||||||
|
bufp := bufPool.Get().(*[]byte)
|
||||||
|
defer bufPool.Put(bufp)
|
||||||
|
buffer := *bufp
|
||||||
|
|
||||||
|
if err := src.SetReadDeadline(time.Now().Add(udpTimeout)); err != nil {
|
||||||
|
return fmt.Errorf("set read deadline: %w", err)
|
||||||
|
}
|
||||||
|
if err := src.SetWriteDeadline(time.Now().Add(udpTimeout)); err != nil {
|
||||||
|
return fmt.Errorf("set write deadline: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
default:
|
||||||
|
n, err := src.Read(buffer)
|
||||||
|
if err != nil {
|
||||||
|
if isTimeout(err) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return fmt.Errorf("read from %s: %w", direction, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = dst.Write(buffer[:n])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("write to %s: %w", direction, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.updateLastSeen()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isClosedError(err error) bool {
|
||||||
|
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isTimeout(err error) bool {
|
||||||
|
var netErr net.Error
|
||||||
|
if errors.As(err, &netErr) {
|
||||||
|
return netErr.Timeout()
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
134
client/firewall/uspfilter/localip.go
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
type localIPManager struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
|
||||||
|
// Use bitmap for IPv4 (32 bits * 2^16 = 256KB memory)
|
||||||
|
ipv4Bitmap [1 << 16]uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func newLocalIPManager() *localIPManager {
|
||||||
|
return &localIPManager{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) setBitmapBit(ip net.IP) {
|
||||||
|
ipv4 := ip.To4()
|
||||||
|
if ipv4 == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||||
|
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||||
|
m.ipv4Bitmap[high] |= 1 << (low % 32)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) checkBitmapBit(ip net.IP) bool {
|
||||||
|
ipv4 := ip.To4()
|
||||||
|
if ipv4 == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||||
|
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||||
|
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) processIP(ip net.IP, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
|
||||||
|
if ipv4 := ip.To4(); ipv4 != nil {
|
||||||
|
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||||
|
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||||
|
if int(high) >= len(*newIPv4Bitmap) {
|
||||||
|
return fmt.Errorf("invalid IPv4 address: %s", ip)
|
||||||
|
}
|
||||||
|
ipStr := ip.String()
|
||||||
|
if _, exists := ipv4Set[ipStr]; !exists {
|
||||||
|
ipv4Set[ipStr] = struct{}{}
|
||||||
|
*ipv4Addresses = append(*ipv4Addresses, ipStr)
|
||||||
|
newIPv4Bitmap[high] |= 1 << (low % 32)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||||
|
addrs, err := iface.Addrs()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, addr := range addrs {
|
||||||
|
var ip net.IP
|
||||||
|
switch v := addr.(type) {
|
||||||
|
case *net.IPNet:
|
||||||
|
ip = v.IP
|
||||||
|
case *net.IPAddr:
|
||||||
|
ip = v.IP
|
||||||
|
default:
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.processIP(ip, newIPv4Bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||||
|
log.Debugf("process IP failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = fmt.Errorf("panic: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var newIPv4Bitmap [1 << 16]uint32
|
||||||
|
ipv4Set := make(map[string]struct{})
|
||||||
|
var ipv4Addresses []string
|
||||||
|
|
||||||
|
// 127.0.0.0/8
|
||||||
|
high := uint16(127) << 8
|
||||||
|
for i := uint16(0); i < 256; i++ {
|
||||||
|
newIPv4Bitmap[high|i] = 0xffffffff
|
||||||
|
}
|
||||||
|
|
||||||
|
if iface != nil {
|
||||||
|
if err := m.processIP(iface.Address().IP, &newIPv4Bitmap, ipv4Set, &ipv4Addresses); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
interfaces, err := net.Interfaces()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to get interfaces: %v", err)
|
||||||
|
} else {
|
||||||
|
for _, intf := range interfaces {
|
||||||
|
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
m.ipv4Bitmap = newIPv4Bitmap
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
log.Debugf("Local IPv4 addresses: %v", ipv4Addresses)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) IsLocalIP(ip net.IP) bool {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
if ipv4 := ip.To4(); ipv4 != nil {
|
||||||
|
return m.checkBitmapBit(ipv4)
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
270
client/firewall/uspfilter/localip_test.go
Normal file
@ -0,0 +1,270 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLocalIPManager(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setupAddr iface.WGAddress
|
||||||
|
testIP net.IP
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Localhost range",
|
||||||
|
setupAddr: iface.WGAddress{
|
||||||
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
|
Mask: net.CIDRMask(24, 32),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
testIP: net.ParseIP("127.0.0.2"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Localhost standard address",
|
||||||
|
setupAddr: iface.WGAddress{
|
||||||
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
|
Mask: net.CIDRMask(24, 32),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
testIP: net.ParseIP("127.0.0.1"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Localhost range edge",
|
||||||
|
setupAddr: iface.WGAddress{
|
||||||
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
|
Mask: net.CIDRMask(24, 32),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
testIP: net.ParseIP("127.255.255.255"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Local IP matches",
|
||||||
|
setupAddr: iface.WGAddress{
|
||||||
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
|
Mask: net.CIDRMask(24, 32),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
testIP: net.ParseIP("192.168.1.1"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Local IP doesn't match",
|
||||||
|
setupAddr: iface.WGAddress{
|
||||||
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
|
Mask: net.CIDRMask(24, 32),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
testIP: net.ParseIP("192.168.1.2"),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 address",
|
||||||
|
setupAddr: iface.WGAddress{
|
||||||
|
IP: net.ParseIP("fe80::1"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("fe80::"),
|
||||||
|
Mask: net.CIDRMask(64, 128),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
testIP: net.ParseIP("fe80::1"),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
manager := newLocalIPManager()
|
||||||
|
|
||||||
|
mock := &IFaceMock{
|
||||||
|
AddressFunc: func() iface.WGAddress {
|
||||||
|
return tt.setupAddr
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := manager.UpdateLocalIPs(mock)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
result := manager.IsLocalIP(tt.testIP)
|
||||||
|
require.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLocalIPManager_AllInterfaces(t *testing.T) {
|
||||||
|
manager := newLocalIPManager()
|
||||||
|
mock := &IFaceMock{}
|
||||||
|
|
||||||
|
// Get actual local interfaces
|
||||||
|
interfaces, err := net.Interfaces()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var tests []struct {
|
||||||
|
ip string
|
||||||
|
expected bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add all local interface IPs to test cases
|
||||||
|
for _, iface := range interfaces {
|
||||||
|
addrs, err := iface.Addrs()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
for _, addr := range addrs {
|
||||||
|
var ip net.IP
|
||||||
|
switch v := addr.(type) {
|
||||||
|
case *net.IPNet:
|
||||||
|
ip = v.IP
|
||||||
|
case *net.IPAddr:
|
||||||
|
ip = v.IP
|
||||||
|
default:
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if ip4 := ip.To4(); ip4 != nil {
|
||||||
|
tests = append(tests, struct {
|
||||||
|
ip string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
ip: ip4.String(),
|
||||||
|
expected: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add some external IPs as negative test cases
|
||||||
|
externalIPs := []string{
|
||||||
|
"8.8.8.8",
|
||||||
|
"1.1.1.1",
|
||||||
|
"208.67.222.222",
|
||||||
|
}
|
||||||
|
for _, ip := range externalIPs {
|
||||||
|
tests = append(tests, struct {
|
||||||
|
ip string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
ip: ip,
|
||||||
|
expected: false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NotEmpty(t, tests, "No test cases generated")
|
||||||
|
|
||||||
|
err = manager.UpdateLocalIPs(mock)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Logf("Testing %d IPs", len(tests))
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.ip, func(t *testing.T) {
|
||||||
|
result := manager.IsLocalIP(net.ParseIP(tt.ip))
|
||||||
|
require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MapImplementation is a version using map[string]struct{}
|
||||||
|
type MapImplementation struct {
|
||||||
|
localIPs map[string]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkIPChecks(b *testing.B) {
|
||||||
|
interfaces := make([]net.IP, 16)
|
||||||
|
for i := range interfaces {
|
||||||
|
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup bitmap version
|
||||||
|
bitmapManager := &localIPManager{
|
||||||
|
ipv4Bitmap: [1 << 16]uint32{},
|
||||||
|
}
|
||||||
|
for _, ip := range interfaces[:8] { // Add half of IPs
|
||||||
|
bitmapManager.setBitmapBit(ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup map version
|
||||||
|
mapManager := &MapImplementation{
|
||||||
|
localIPs: make(map[string]struct{}),
|
||||||
|
}
|
||||||
|
for _, ip := range interfaces[:8] {
|
||||||
|
mapManager.localIPs[ip.String()] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
b.Run("Bitmap_Hit", func(b *testing.B) {
|
||||||
|
ip := interfaces[4]
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
bitmapManager.checkBitmapBit(ip)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("Bitmap_Miss", func(b *testing.B) {
|
||||||
|
ip := interfaces[12]
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
bitmapManager.checkBitmapBit(ip)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("Map_Hit", func(b *testing.B) {
|
||||||
|
ip := interfaces[4]
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
// nolint:gosimple
|
||||||
|
_, _ = mapManager.localIPs[ip.String()]
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("Map_Miss", func(b *testing.B) {
|
||||||
|
ip := interfaces[12]
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
// nolint:gosimple
|
||||||
|
_, _ = mapManager.localIPs[ip.String()]
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkWGPosition(b *testing.B) {
|
||||||
|
wgIP := net.ParseIP("10.10.0.1")
|
||||||
|
|
||||||
|
// Create two managers - one checks WG IP first, other checks it last
|
||||||
|
b.Run("WG_First", func(b *testing.B) {
|
||||||
|
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
|
||||||
|
bm.setBitmapBit(wgIP)
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
bm.checkBitmapBit(wgIP)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("WG_Last", func(b *testing.B) {
|
||||||
|
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
|
||||||
|
// Fill with other IPs first
|
||||||
|
for i := 0; i < 15; i++ {
|
||||||
|
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))
|
||||||
|
}
|
||||||
|
bm.setBitmapBit(wgIP) // Add WG IP last
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
bm.checkBitmapBit(wgIP)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
196
client/firewall/uspfilter/log/log.go
Normal file
@ -0,0 +1,196 @@
|
|||||||
|
// Package logger provides a high-performance, non-blocking logger for userspace networking
|
||||||
|
package log
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
maxBatchSize = 1024 * 16 // 16KB max batch size
|
||||||
|
maxMessageSize = 1024 * 2 // 2KB per message
|
||||||
|
bufferSize = 1024 * 256 // 256KB ring buffer
|
||||||
|
defaultFlushInterval = 2 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// Level represents log severity
|
||||||
|
type Level uint32
|
||||||
|
|
||||||
|
const (
|
||||||
|
LevelPanic Level = iota
|
||||||
|
LevelFatal
|
||||||
|
LevelError
|
||||||
|
LevelWarn
|
||||||
|
LevelInfo
|
||||||
|
LevelDebug
|
||||||
|
LevelTrace
|
||||||
|
)
|
||||||
|
|
||||||
|
var levelStrings = map[Level]string{
|
||||||
|
LevelPanic: "PANC",
|
||||||
|
LevelFatal: "FATL",
|
||||||
|
LevelError: "ERRO",
|
||||||
|
LevelWarn: "WARN",
|
||||||
|
LevelInfo: "INFO",
|
||||||
|
LevelDebug: "DEBG",
|
||||||
|
LevelTrace: "TRAC",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Logger is a high-performance, non-blocking logger
|
||||||
|
type Logger struct {
|
||||||
|
output io.Writer
|
||||||
|
level atomic.Uint32
|
||||||
|
buffer *ringBuffer
|
||||||
|
shutdown chan struct{}
|
||||||
|
closeOnce sync.Once
|
||||||
|
wg sync.WaitGroup
|
||||||
|
|
||||||
|
// Reusable buffer pool for formatting messages
|
||||||
|
bufPool sync.Pool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
||||||
|
l := &Logger{
|
||||||
|
output: logrusLogger.Out,
|
||||||
|
buffer: newRingBuffer(bufferSize),
|
||||||
|
shutdown: make(chan struct{}),
|
||||||
|
bufPool: sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
// Pre-allocate buffer for message formatting
|
||||||
|
b := make([]byte, 0, maxMessageSize)
|
||||||
|
return &b
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
logrusLevel := logrusLogger.GetLevel()
|
||||||
|
l.level.Store(uint32(logrusLevel))
|
||||||
|
level := levelStrings[Level(logrusLevel)]
|
||||||
|
log.Debugf("New uspfilter logger created with loglevel %v", level)
|
||||||
|
|
||||||
|
l.wg.Add(1)
|
||||||
|
go l.worker()
|
||||||
|
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) SetLevel(level Level) {
|
||||||
|
l.level.Store(uint32(level))
|
||||||
|
|
||||||
|
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...interface{}) {
|
||||||
|
*buf = (*buf)[:0]
|
||||||
|
|
||||||
|
// Timestamp
|
||||||
|
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
|
||||||
|
*buf = append(*buf, ' ')
|
||||||
|
|
||||||
|
// Level
|
||||||
|
*buf = append(*buf, levelStrings[level]...)
|
||||||
|
*buf = append(*buf, ' ')
|
||||||
|
|
||||||
|
// Message
|
||||||
|
if len(args) > 0 {
|
||||||
|
*buf = append(*buf, fmt.Sprintf(format, args...)...)
|
||||||
|
} else {
|
||||||
|
*buf = append(*buf, format...)
|
||||||
|
}
|
||||||
|
|
||||||
|
*buf = append(*buf, '\n')
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) log(level Level, format string, args ...interface{}) {
|
||||||
|
bufp := l.bufPool.Get().(*[]byte)
|
||||||
|
l.formatMessage(bufp, level, format, args...)
|
||||||
|
|
||||||
|
if len(*bufp) > maxMessageSize {
|
||||||
|
*bufp = (*bufp)[:maxMessageSize]
|
||||||
|
}
|
||||||
|
_, _ = l.buffer.Write(*bufp)
|
||||||
|
|
||||||
|
l.bufPool.Put(bufp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) Error(format string, args ...interface{}) {
|
||||||
|
if l.level.Load() >= uint32(LevelError) {
|
||||||
|
l.log(LevelError, format, args...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) Warn(format string, args ...interface{}) {
|
||||||
|
if l.level.Load() >= uint32(LevelWarn) {
|
||||||
|
l.log(LevelWarn, format, args...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) Info(format string, args ...interface{}) {
|
||||||
|
if l.level.Load() >= uint32(LevelInfo) {
|
||||||
|
l.log(LevelInfo, format, args...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) Debug(format string, args ...interface{}) {
|
||||||
|
if l.level.Load() >= uint32(LevelDebug) {
|
||||||
|
l.log(LevelDebug, format, args...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) Trace(format string, args ...interface{}) {
|
||||||
|
if l.level.Load() >= uint32(LevelTrace) {
|
||||||
|
l.log(LevelTrace, format, args...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// worker periodically flushes the buffer
|
||||||
|
func (l *Logger) worker() {
|
||||||
|
defer l.wg.Done()
|
||||||
|
|
||||||
|
ticker := time.NewTicker(defaultFlushInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
buf := make([]byte, 0, maxBatchSize)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-l.shutdown:
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
// Read accumulated messages
|
||||||
|
n, _ := l.buffer.Read(buf[:cap(buf)])
|
||||||
|
if n == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write batch
|
||||||
|
_, _ = l.output.Write(buf[:n])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop gracefully shuts down the logger
|
||||||
|
func (l *Logger) Stop(ctx context.Context) error {
|
||||||
|
done := make(chan struct{})
|
||||||
|
|
||||||
|
l.closeOnce.Do(func() {
|
||||||
|
close(l.shutdown)
|
||||||
|
})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
l.wg.Wait()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-done:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
85
client/firewall/uspfilter/log/ringbuffer.go
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
package log
|
||||||
|
|
||||||
|
import "sync"
|
||||||
|
|
||||||
|
// ringBuffer is a simple ring buffer implementation
|
||||||
|
type ringBuffer struct {
|
||||||
|
buf []byte
|
||||||
|
size int
|
||||||
|
r, w int64 // Read and write positions
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRingBuffer(size int) *ringBuffer {
|
||||||
|
return &ringBuffer{
|
||||||
|
buf: make([]byte, size),
|
||||||
|
size: size,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ringBuffer) Write(p []byte) (n int, err error) {
|
||||||
|
if len(p) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
if len(p) > r.size {
|
||||||
|
p = p[:r.size]
|
||||||
|
}
|
||||||
|
|
||||||
|
n = len(p)
|
||||||
|
|
||||||
|
// Write data, handling wrap-around
|
||||||
|
pos := int(r.w % int64(r.size))
|
||||||
|
writeLen := min(len(p), r.size-pos)
|
||||||
|
copy(r.buf[pos:], p[:writeLen])
|
||||||
|
|
||||||
|
// If we have more data and need to wrap around
|
||||||
|
if writeLen < len(p) {
|
||||||
|
copy(r.buf, p[writeLen:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update write position
|
||||||
|
r.w += int64(n)
|
||||||
|
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ringBuffer) Read(p []byte) (n int, err error) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
if r.w == r.r {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate available data accounting for wraparound
|
||||||
|
available := int(r.w - r.r)
|
||||||
|
if available < 0 {
|
||||||
|
available += r.size
|
||||||
|
}
|
||||||
|
available = min(available, r.size)
|
||||||
|
|
||||||
|
// Limit read to buffer size
|
||||||
|
toRead := min(available, len(p))
|
||||||
|
if toRead == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read data, handling wrap-around
|
||||||
|
pos := int(r.r % int64(r.size))
|
||||||
|
readLen := min(toRead, r.size-pos)
|
||||||
|
n = copy(p, r.buf[pos:pos+readLen])
|
||||||
|
|
||||||
|
// If we need more data and need to wrap around
|
||||||
|
if readLen < toRead {
|
||||||
|
n += copy(p[readLen:toRead], r.buf[:toRead-readLen])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update read position
|
||||||
|
r.r += int64(n)
|
||||||
|
|
||||||
|
return n, nil
|
||||||
|
}
|
@ -2,14 +2,15 @@ package uspfilter
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Rule to handle management of rules
|
// PeerRule to handle management of rules
|
||||||
type Rule struct {
|
type PeerRule struct {
|
||||||
id string
|
id string
|
||||||
ip net.IP
|
ip net.IP
|
||||||
ipLayer gopacket.LayerType
|
ipLayer gopacket.LayerType
|
||||||
@ -23,7 +24,22 @@ type Rule struct {
|
|||||||
udpHook func([]byte) bool
|
udpHook func([]byte) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
// ID returns the rule id
|
||||||
func (r *Rule) ID() string {
|
func (r *PeerRule) ID() string {
|
||||||
|
return r.id
|
||||||
|
}
|
||||||
|
|
||||||
|
type RouteRule struct {
|
||||||
|
id string
|
||||||
|
sources []netip.Prefix
|
||||||
|
destination netip.Prefix
|
||||||
|
proto firewall.Protocol
|
||||||
|
srcPort *firewall.Port
|
||||||
|
dstPort *firewall.Port
|
||||||
|
action firewall.Action
|
||||||
|
}
|
||||||
|
|
||||||
|
// ID returns the rule id
|
||||||
|
func (r *RouteRule) ID() string {
|
||||||
return r.id
|
return r.id
|
||||||
}
|
}
|
||||||
|
390
client/firewall/uspfilter/tracer.go
Normal file
@ -0,0 +1,390 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/gopacket"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
|
||||||
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PacketStage int
|
||||||
|
|
||||||
|
const (
|
||||||
|
StageReceived PacketStage = iota
|
||||||
|
StageConntrack
|
||||||
|
StagePeerACL
|
||||||
|
StageRouting
|
||||||
|
StageRouteACL
|
||||||
|
StageForwarding
|
||||||
|
StageCompleted
|
||||||
|
)
|
||||||
|
|
||||||
|
const msgProcessingCompleted = "Processing completed"
|
||||||
|
|
||||||
|
func (s PacketStage) String() string {
|
||||||
|
return map[PacketStage]string{
|
||||||
|
StageReceived: "Received",
|
||||||
|
StageConntrack: "Connection Tracking",
|
||||||
|
StagePeerACL: "Peer ACL",
|
||||||
|
StageRouting: "Routing",
|
||||||
|
StageRouteACL: "Route ACL",
|
||||||
|
StageForwarding: "Forwarding",
|
||||||
|
StageCompleted: "Completed",
|
||||||
|
}[s]
|
||||||
|
}
|
||||||
|
|
||||||
|
type ForwarderAction struct {
|
||||||
|
Action string
|
||||||
|
RemoteAddr string
|
||||||
|
Error error
|
||||||
|
}
|
||||||
|
|
||||||
|
type TraceResult struct {
|
||||||
|
Timestamp time.Time
|
||||||
|
Stage PacketStage
|
||||||
|
Message string
|
||||||
|
Allowed bool
|
||||||
|
ForwarderAction *ForwarderAction
|
||||||
|
}
|
||||||
|
|
||||||
|
type PacketTrace struct {
|
||||||
|
SourceIP net.IP
|
||||||
|
DestinationIP net.IP
|
||||||
|
Protocol string
|
||||||
|
SourcePort uint16
|
||||||
|
DestinationPort uint16
|
||||||
|
Direction fw.RuleDirection
|
||||||
|
Results []TraceResult
|
||||||
|
}
|
||||||
|
|
||||||
|
type TCPState struct {
|
||||||
|
SYN bool
|
||||||
|
ACK bool
|
||||||
|
FIN bool
|
||||||
|
RST bool
|
||||||
|
PSH bool
|
||||||
|
URG bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type PacketBuilder struct {
|
||||||
|
SrcIP net.IP
|
||||||
|
DstIP net.IP
|
||||||
|
Protocol fw.Protocol
|
||||||
|
SrcPort uint16
|
||||||
|
DstPort uint16
|
||||||
|
ICMPType uint8
|
||||||
|
ICMPCode uint8
|
||||||
|
Direction fw.RuleDirection
|
||||||
|
PayloadSize int
|
||||||
|
TCPState *TCPState
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *PacketTrace) AddResult(stage PacketStage, message string, allowed bool) {
|
||||||
|
t.Results = append(t.Results, TraceResult{
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
Stage: stage,
|
||||||
|
Message: message,
|
||||||
|
Allowed: allowed,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *PacketTrace) AddResultWithForwarder(stage PacketStage, message string, allowed bool, action *ForwarderAction) {
|
||||||
|
t.Results = append(t.Results, TraceResult{
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
Stage: stage,
|
||||||
|
Message: message,
|
||||||
|
Allowed: allowed,
|
||||||
|
ForwarderAction: action,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PacketBuilder) Build() ([]byte, error) {
|
||||||
|
ip := p.buildIPLayer()
|
||||||
|
pktLayers := []gopacket.SerializableLayer{ip}
|
||||||
|
|
||||||
|
transportLayer, err := p.buildTransportLayer(ip)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
pktLayers = append(pktLayers, transportLayer...)
|
||||||
|
|
||||||
|
if p.PayloadSize > 0 {
|
||||||
|
payload := make([]byte, p.PayloadSize)
|
||||||
|
pktLayers = append(pktLayers, gopacket.Payload(payload))
|
||||||
|
}
|
||||||
|
|
||||||
|
return serializePacket(pktLayers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
|
||||||
|
return &layers.IPv4{
|
||||||
|
Version: 4,
|
||||||
|
TTL: 64,
|
||||||
|
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
|
||||||
|
SrcIP: p.SrcIP,
|
||||||
|
DstIP: p.DstIP,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PacketBuilder) buildTransportLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
||||||
|
switch p.Protocol {
|
||||||
|
case "tcp":
|
||||||
|
return p.buildTCPLayer(ip)
|
||||||
|
case "udp":
|
||||||
|
return p.buildUDPLayer(ip)
|
||||||
|
case "icmp":
|
||||||
|
return p.buildICMPLayer()
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported protocol: %s", p.Protocol)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PacketBuilder) buildTCPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
||||||
|
tcp := &layers.TCP{
|
||||||
|
SrcPort: layers.TCPPort(p.SrcPort),
|
||||||
|
DstPort: layers.TCPPort(p.DstPort),
|
||||||
|
Window: 65535,
|
||||||
|
SYN: p.TCPState != nil && p.TCPState.SYN,
|
||||||
|
ACK: p.TCPState != nil && p.TCPState.ACK,
|
||||||
|
FIN: p.TCPState != nil && p.TCPState.FIN,
|
||||||
|
RST: p.TCPState != nil && p.TCPState.RST,
|
||||||
|
PSH: p.TCPState != nil && p.TCPState.PSH,
|
||||||
|
URG: p.TCPState != nil && p.TCPState.URG,
|
||||||
|
}
|
||||||
|
if err := tcp.SetNetworkLayerForChecksum(ip); err != nil {
|
||||||
|
return nil, fmt.Errorf("set network layer for TCP checksum: %w", err)
|
||||||
|
}
|
||||||
|
return []gopacket.SerializableLayer{tcp}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PacketBuilder) buildUDPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
||||||
|
udp := &layers.UDP{
|
||||||
|
SrcPort: layers.UDPPort(p.SrcPort),
|
||||||
|
DstPort: layers.UDPPort(p.DstPort),
|
||||||
|
}
|
||||||
|
if err := udp.SetNetworkLayerForChecksum(ip); err != nil {
|
||||||
|
return nil, fmt.Errorf("set network layer for UDP checksum: %w", err)
|
||||||
|
}
|
||||||
|
return []gopacket.SerializableLayer{udp}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PacketBuilder) buildICMPLayer() ([]gopacket.SerializableLayer, error) {
|
||||||
|
icmp := &layers.ICMPv4{
|
||||||
|
TypeCode: layers.CreateICMPv4TypeCode(p.ICMPType, p.ICMPCode),
|
||||||
|
}
|
||||||
|
if p.ICMPType == layers.ICMPv4TypeEchoRequest || p.ICMPType == layers.ICMPv4TypeEchoReply {
|
||||||
|
icmp.Id = uint16(1)
|
||||||
|
icmp.Seq = uint16(1)
|
||||||
|
}
|
||||||
|
return []gopacket.SerializableLayer{icmp}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func serializePacket(layers []gopacket.SerializableLayer) ([]byte, error) {
|
||||||
|
buf := gopacket.NewSerializeBuffer()
|
||||||
|
opts := gopacket.SerializeOptions{
|
||||||
|
ComputeChecksums: true,
|
||||||
|
FixLengths: true,
|
||||||
|
}
|
||||||
|
if err := gopacket.SerializeLayers(buf, opts, layers...); err != nil {
|
||||||
|
return nil, fmt.Errorf("serialize packet: %w", err)
|
||||||
|
}
|
||||||
|
return buf.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getIPProtocolNumber(protocol fw.Protocol) int {
|
||||||
|
switch protocol {
|
||||||
|
case fw.ProtocolTCP:
|
||||||
|
return int(layers.IPProtocolTCP)
|
||||||
|
case fw.ProtocolUDP:
|
||||||
|
return int(layers.IPProtocolUDP)
|
||||||
|
case fw.ProtocolICMP:
|
||||||
|
return int(layers.IPProtocolICMPv4)
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) TracePacketFromBuilder(builder *PacketBuilder) (*PacketTrace, error) {
|
||||||
|
packetData, err := builder.Build()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("build packet: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.TracePacket(packetData, builder.Direction), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *PacketTrace {
|
||||||
|
|
||||||
|
d := m.decoders.Get().(*decoder)
|
||||||
|
defer m.decoders.Put(d)
|
||||||
|
|
||||||
|
trace := &PacketTrace{Direction: direction}
|
||||||
|
|
||||||
|
// Initial packet decoding
|
||||||
|
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||||
|
trace.AddResult(StageReceived, fmt.Sprintf("Failed to decode packet: %v", err), false)
|
||||||
|
return trace
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract base packet info
|
||||||
|
srcIP, dstIP := m.extractIPs(d)
|
||||||
|
trace.SourceIP = srcIP
|
||||||
|
trace.DestinationIP = dstIP
|
||||||
|
|
||||||
|
// Determine protocol and ports
|
||||||
|
switch d.decoded[1] {
|
||||||
|
case layers.LayerTypeTCP:
|
||||||
|
trace.Protocol = "TCP"
|
||||||
|
trace.SourcePort = uint16(d.tcp.SrcPort)
|
||||||
|
trace.DestinationPort = uint16(d.tcp.DstPort)
|
||||||
|
case layers.LayerTypeUDP:
|
||||||
|
trace.Protocol = "UDP"
|
||||||
|
trace.SourcePort = uint16(d.udp.SrcPort)
|
||||||
|
trace.DestinationPort = uint16(d.udp.DstPort)
|
||||||
|
case layers.LayerTypeICMPv4:
|
||||||
|
trace.Protocol = "ICMP"
|
||||||
|
}
|
||||||
|
|
||||||
|
trace.AddResult(StageReceived, fmt.Sprintf("Received %s packet: %s:%d -> %s:%d",
|
||||||
|
trace.Protocol, srcIP, trace.SourcePort, dstIP, trace.DestinationPort), true)
|
||||||
|
|
||||||
|
if direction == fw.RuleDirectionOUT {
|
||||||
|
return m.traceOutbound(packetData, trace)
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.traceInbound(packetData, trace, d, srcIP, dstIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP net.IP, dstIP net.IP) *PacketTrace {
|
||||||
|
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
|
||||||
|
return trace
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
|
||||||
|
return trace
|
||||||
|
}
|
||||||
|
|
||||||
|
if !m.handleRouting(trace) {
|
||||||
|
return trace
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.nativeRouter {
|
||||||
|
return m.handleNativeRouter(trace)
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.handleRouteACLs(trace, d, srcIP, dstIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) bool {
|
||||||
|
allowed := m.isValidTrackedConnection(d, srcIP, dstIP)
|
||||||
|
msg := "No existing connection found"
|
||||||
|
if allowed {
|
||||||
|
msg = m.buildConntrackStateMessage(d)
|
||||||
|
trace.AddResult(StageConntrack, msg, true)
|
||||||
|
trace.AddResult(StageCompleted, "Packet allowed by connection tracking", true)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
trace.AddResult(StageConntrack, msg, false)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) buildConntrackStateMessage(d *decoder) string {
|
||||||
|
msg := "Matched existing connection state"
|
||||||
|
switch d.decoded[1] {
|
||||||
|
case layers.LayerTypeTCP:
|
||||||
|
flags := getTCPFlags(&d.tcp)
|
||||||
|
msg += fmt.Sprintf(" (TCP Flags: SYN=%v ACK=%v RST=%v FIN=%v)",
|
||||||
|
flags&conntrack.TCPSyn != 0,
|
||||||
|
flags&conntrack.TCPAck != 0,
|
||||||
|
flags&conntrack.TCPRst != 0,
|
||||||
|
flags&conntrack.TCPFin != 0)
|
||||||
|
case layers.LayerTypeICMPv4:
|
||||||
|
msg += fmt.Sprintf(" (ICMP ID=%d, Seq=%d)", d.icmp4.Id, d.icmp4.Seq)
|
||||||
|
}
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP net.IP) bool {
|
||||||
|
if !m.localForwarding {
|
||||||
|
trace.AddResult(StageRouting, "Local forwarding disabled", false)
|
||||||
|
trace.AddResult(StageCompleted, "Packet dropped - local forwarding disabled", false)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
trace.AddResult(StageRouting, "Packet destined for local delivery", true)
|
||||||
|
blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
|
||||||
|
|
||||||
|
msg := "Allowed by peer ACL rules"
|
||||||
|
if blocked {
|
||||||
|
msg = "Blocked by peer ACL rules"
|
||||||
|
}
|
||||||
|
trace.AddResult(StagePeerACL, msg, !blocked)
|
||||||
|
|
||||||
|
if m.netstack {
|
||||||
|
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", !blocked)
|
||||||
|
}
|
||||||
|
|
||||||
|
trace.AddResult(StageCompleted, msgProcessingCompleted, !blocked)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) handleRouting(trace *PacketTrace) bool {
|
||||||
|
if !m.routingEnabled {
|
||||||
|
trace.AddResult(StageRouting, "Routing disabled", false)
|
||||||
|
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
trace.AddResult(StageRouting, "Routing enabled, checking ACLs", true)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
|
||||||
|
trace.AddResult(StageRouteACL, "Using native router, skipping ACL checks", true)
|
||||||
|
trace.AddResult(StageForwarding, "Forwarding via native router", true)
|
||||||
|
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
|
||||||
|
return trace
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) *PacketTrace {
|
||||||
|
proto := getProtocolFromPacket(d)
|
||||||
|
srcPort, dstPort := getPortsFromPacket(d)
|
||||||
|
allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
||||||
|
|
||||||
|
msg := "Allowed by route ACLs"
|
||||||
|
if !allowed {
|
||||||
|
msg = "Blocked by route ACLs"
|
||||||
|
}
|
||||||
|
trace.AddResult(StageRouteACL, msg, allowed)
|
||||||
|
|
||||||
|
if allowed && m.forwarder != nil {
|
||||||
|
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
|
||||||
|
}
|
||||||
|
|
||||||
|
trace.AddResult(StageCompleted, msgProcessingCompleted, allowed)
|
||||||
|
return trace
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr string, allowed bool) {
|
||||||
|
fwdAction := &ForwarderAction{
|
||||||
|
Action: action,
|
||||||
|
RemoteAddr: remoteAddr,
|
||||||
|
}
|
||||||
|
trace.AddResultWithForwarder(StageForwarding,
|
||||||
|
fmt.Sprintf("Forwarding to %s", fwdAction.Action), allowed, fwdAction)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
|
||||||
|
// will create or update the connection state
|
||||||
|
dropped := m.processOutgoingHooks(packetData)
|
||||||
|
if dropped {
|
||||||
|
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
|
||||||
|
} else {
|
||||||
|
trace.AddResult(StageCompleted, "Packet allowed (outgoing)", true)
|
||||||
|
}
|
||||||
|
return trace
|
||||||
|
}
|
@ -6,7 +6,9 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
@ -15,29 +17,53 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
const layerTypeAll = 0
|
const layerTypeAll = 0
|
||||||
|
|
||||||
const EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
|
const (
|
||||||
|
// EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed.
|
||||||
|
EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
|
||||||
|
|
||||||
|
// EnvDisableUserspaceRouting disables userspace routing, to-be-routed packets will be dropped.
|
||||||
|
EnvDisableUserspaceRouting = "NB_DISABLE_USERSPACE_ROUTING"
|
||||||
|
|
||||||
|
// EnvForceUserspaceRouter forces userspace routing even if native routing is available.
|
||||||
|
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
|
||||||
|
|
||||||
|
// EnvEnableNetstackLocalForwarding enables forwarding of local traffic to the native stack when running netstack
|
||||||
|
// Leaving this on by default introduces a security risk as sockets on listening on localhost only will be accessible
|
||||||
|
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errRouteNotSupported = errors.New("route not supported with userspace firewall")
|
errRouteNotSupported = errors.New("route not supported with userspace firewall")
|
||||||
errNatNotSupported = errors.New("nat not supported with userspace firewall")
|
errNatNotSupported = errors.New("nat not supported with userspace firewall")
|
||||||
)
|
)
|
||||||
|
|
||||||
// IFaceMapper defines subset methods of interface required for manager
|
|
||||||
type IFaceMapper interface {
|
|
||||||
SetFilter(device.PacketFilter) error
|
|
||||||
Address() iface.WGAddress
|
|
||||||
}
|
|
||||||
|
|
||||||
// RuleSet is a set of rules grouped by a string key
|
// RuleSet is a set of rules grouped by a string key
|
||||||
type RuleSet map[string]Rule
|
type RuleSet map[string]PeerRule
|
||||||
|
|
||||||
|
type RouteRules []RouteRule
|
||||||
|
|
||||||
|
func (r RouteRules) Sort() {
|
||||||
|
slices.SortStableFunc(r, func(a, b RouteRule) int {
|
||||||
|
// Deny rules come first
|
||||||
|
if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
if a.action != firewall.ActionDrop && b.action == firewall.ActionDrop {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return strings.Compare(a.id, b.id)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Manager userspace firewall manager
|
// Manager userspace firewall manager
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
@ -45,17 +71,34 @@ type Manager struct {
|
|||||||
outgoingRules map[string]RuleSet
|
outgoingRules map[string]RuleSet
|
||||||
// incomingRules is used for filtering and hooks
|
// incomingRules is used for filtering and hooks
|
||||||
incomingRules map[string]RuleSet
|
incomingRules map[string]RuleSet
|
||||||
|
routeRules RouteRules
|
||||||
wgNetwork *net.IPNet
|
wgNetwork *net.IPNet
|
||||||
decoders sync.Pool
|
decoders sync.Pool
|
||||||
wgIface IFaceMapper
|
wgIface common.IFaceMapper
|
||||||
nativeFirewall firewall.Manager
|
nativeFirewall firewall.Manager
|
||||||
|
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
|
|
||||||
|
// indicates whether server routes are disabled
|
||||||
|
disableServerRoutes bool
|
||||||
|
// indicates whether we forward packets not destined for ourselves
|
||||||
|
routingEnabled bool
|
||||||
|
// indicates whether we leave forwarding and filtering to the native firewall
|
||||||
|
nativeRouter bool
|
||||||
|
// indicates whether we track outbound connections
|
||||||
stateful bool
|
stateful bool
|
||||||
|
// indicates whether wireguards runs in netstack mode
|
||||||
|
netstack bool
|
||||||
|
// indicates whether we forward local traffic to the native stack
|
||||||
|
localForwarding bool
|
||||||
|
|
||||||
|
localipmanager *localIPManager
|
||||||
|
|
||||||
udpTracker *conntrack.UDPTracker
|
udpTracker *conntrack.UDPTracker
|
||||||
icmpTracker *conntrack.ICMPTracker
|
icmpTracker *conntrack.ICMPTracker
|
||||||
tcpTracker *conntrack.TCPTracker
|
tcpTracker *conntrack.TCPTracker
|
||||||
|
forwarder *forwarder.Forwarder
|
||||||
|
logger *nblog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
// decoder for packages
|
// decoder for packages
|
||||||
@ -72,22 +115,44 @@ type decoder struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create userspace firewall manager constructor
|
// Create userspace firewall manager constructor
|
||||||
func Create(iface IFaceMapper) (*Manager, error) {
|
func Create(iface common.IFaceMapper, disableServerRoutes bool) (*Manager, error) {
|
||||||
return create(iface)
|
return create(iface, nil, disableServerRoutes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager) (*Manager, error) {
|
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
|
||||||
mgr, err := create(iface)
|
if nativeFirewall == nil {
|
||||||
|
return nil, errors.New("native firewall is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr, err := create(iface, nativeFirewall, disableServerRoutes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
mgr.nativeFirewall = nativeFirewall
|
|
||||||
return mgr, nil
|
return mgr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func create(iface IFaceMapper) (*Manager, error) {
|
func parseCreateEnv() (bool, bool) {
|
||||||
disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack))
|
var disableConntrack, enableLocalForwarding bool
|
||||||
|
var err error
|
||||||
|
if val := os.Getenv(EnvDisableConntrack); val != "" {
|
||||||
|
disableConntrack, err = strconv.ParseBool(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse %s: %v", EnvDisableConntrack, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if val := os.Getenv(EnvEnableNetstackLocalForwarding); val != "" {
|
||||||
|
enableLocalForwarding, err = strconv.ParseBool(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return disableConntrack, enableLocalForwarding
|
||||||
|
}
|
||||||
|
|
||||||
|
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
|
||||||
|
disableConntrack, enableLocalForwarding := parseCreateEnv()
|
||||||
|
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
decoders: sync.Pool{
|
decoders: sync.Pool{
|
||||||
@ -103,53 +168,184 @@ func create(iface IFaceMapper) (*Manager, error) {
|
|||||||
return d
|
return d
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
nativeFirewall: nativeFirewall,
|
||||||
outgoingRules: make(map[string]RuleSet),
|
outgoingRules: make(map[string]RuleSet),
|
||||||
incomingRules: make(map[string]RuleSet),
|
incomingRules: make(map[string]RuleSet),
|
||||||
wgIface: iface,
|
wgIface: iface,
|
||||||
|
localipmanager: newLocalIPManager(),
|
||||||
|
disableServerRoutes: disableServerRoutes,
|
||||||
|
routingEnabled: false,
|
||||||
stateful: !disableConntrack,
|
stateful: !disableConntrack,
|
||||||
|
logger: nblog.NewFromLogrus(log.StandardLogger()),
|
||||||
|
netstack: netstack.IsEnabled(),
|
||||||
|
// default true for non-netstack, for netstack only if explicitly enabled
|
||||||
|
localForwarding: !netstack.IsEnabled() || enableLocalForwarding,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
||||||
|
return nil, fmt.Errorf("update local IPs: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only initialize trackers if stateful mode is enabled
|
|
||||||
if disableConntrack {
|
if disableConntrack {
|
||||||
log.Info("conntrack is disabled")
|
log.Info("conntrack is disabled")
|
||||||
} else {
|
} else {
|
||||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
|
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
|
||||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
|
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
|
||||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
// netstack needs the forwarder for local traffic
|
||||||
|
if m.netstack && m.localForwarding {
|
||||||
|
if err := m.initForwarder(); err != nil {
|
||||||
|
log.Errorf("failed to initialize forwarder: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.blockInvalidRouted(iface); err != nil {
|
||||||
|
log.Errorf("failed to block invalid routed traffic: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := iface.SetFilter(m); err != nil {
|
if err := iface.SetFilter(m); err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("set filter: %w", err)
|
||||||
}
|
}
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
|
||||||
|
if m.forwarder == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse wireguard network: %w", err)
|
||||||
|
}
|
||||||
|
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
||||||
|
|
||||||
|
if _, err := m.AddRouteFiltering(
|
||||||
|
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
|
||||||
|
wgPrefix,
|
||||||
|
firewall.ProtocolALL,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
firewall.ActionDrop,
|
||||||
|
); err != nil {
|
||||||
|
return fmt.Errorf("block wg nte : %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Block networks that we're a client of
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) determineRouting() error {
|
||||||
|
var disableUspRouting, forceUserspaceRouter bool
|
||||||
|
var err error
|
||||||
|
if val := os.Getenv(EnvDisableUserspaceRouting); val != "" {
|
||||||
|
disableUspRouting, err = strconv.ParseBool(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse %s: %v", EnvDisableUserspaceRouting, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if val := os.Getenv(EnvForceUserspaceRouter); val != "" {
|
||||||
|
forceUserspaceRouter, err = strconv.ParseBool(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse %s: %v", EnvForceUserspaceRouter, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case disableUspRouting:
|
||||||
|
m.routingEnabled = false
|
||||||
|
m.nativeRouter = false
|
||||||
|
log.Info("userspace routing is disabled")
|
||||||
|
|
||||||
|
case m.disableServerRoutes:
|
||||||
|
// if server routes are disabled we will let packets pass to the native stack
|
||||||
|
m.routingEnabled = true
|
||||||
|
m.nativeRouter = true
|
||||||
|
|
||||||
|
log.Info("server routes are disabled")
|
||||||
|
|
||||||
|
case forceUserspaceRouter:
|
||||||
|
m.routingEnabled = true
|
||||||
|
m.nativeRouter = false
|
||||||
|
|
||||||
|
log.Info("userspace routing is forced")
|
||||||
|
|
||||||
|
case !m.netstack && m.nativeFirewall != nil && m.nativeFirewall.IsServerRouteSupported():
|
||||||
|
// if the OS supports routing natively, then we don't need to filter/route ourselves
|
||||||
|
// netstack mode won't support native routing as there is no interface
|
||||||
|
|
||||||
|
m.routingEnabled = true
|
||||||
|
m.nativeRouter = true
|
||||||
|
|
||||||
|
log.Info("native routing is enabled")
|
||||||
|
|
||||||
|
default:
|
||||||
|
m.routingEnabled = true
|
||||||
|
m.nativeRouter = false
|
||||||
|
|
||||||
|
log.Info("userspace routing enabled by default")
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.routingEnabled && !m.nativeRouter {
|
||||||
|
return m.initForwarder()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// initForwarder initializes the forwarder, it disables routing on errors
|
||||||
|
func (m *Manager) initForwarder() error {
|
||||||
|
if m.forwarder != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only supported in userspace mode as we need to inject packets back into wireguard directly
|
||||||
|
intf := m.wgIface.GetWGDevice()
|
||||||
|
if intf == nil {
|
||||||
|
m.routingEnabled = false
|
||||||
|
return errors.New("forwarding not supported")
|
||||||
|
}
|
||||||
|
|
||||||
|
forwarder, err := forwarder.New(m.wgIface, m.logger, m.netstack)
|
||||||
|
if err != nil {
|
||||||
|
m.routingEnabled = false
|
||||||
|
return fmt.Errorf("create forwarder: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.forwarder = forwarder
|
||||||
|
|
||||||
|
log.Debug("forwarder initialized")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) Init(*statemanager.Manager) error {
|
func (m *Manager) Init(*statemanager.Manager) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) IsServerRouteSupported() bool {
|
func (m *Manager) IsServerRouteSupported() bool {
|
||||||
if m.nativeFirewall == nil {
|
|
||||||
return false
|
|
||||||
} else {
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||||
if m.nativeFirewall == nil {
|
if m.nativeRouter && m.nativeFirewall != nil {
|
||||||
return errRouteNotSupported
|
|
||||||
}
|
|
||||||
return m.nativeFirewall.AddNatRule(pair)
|
return m.nativeFirewall.AddNatRule(pair)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// userspace routed packets are always SNATed to the inbound direction
|
||||||
|
// TODO: implement outbound SNAT
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// RemoveNatRule removes a routing firewall rule
|
// RemoveNatRule removes a routing firewall rule
|
||||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
if m.nativeFirewall == nil {
|
if m.nativeRouter && m.nativeFirewall != nil {
|
||||||
return errRouteNotSupported
|
|
||||||
}
|
|
||||||
return m.nativeFirewall.RemoveNatRule(pair)
|
return m.nativeFirewall.RemoveNatRule(pair)
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// AddPeerFiltering rule to the firewall
|
// AddPeerFiltering rule to the firewall
|
||||||
//
|
//
|
||||||
@ -164,7 +360,7 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
_ string,
|
_ string,
|
||||||
comment string,
|
comment string,
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
r := Rule{
|
r := PeerRule{
|
||||||
id: uuid.New().String(),
|
id: uuid.New().String(),
|
||||||
ip: ip,
|
ip: ip,
|
||||||
ipLayer: layers.LayerTypeIPv6,
|
ipLayer: layers.LayerTypeIPv6,
|
||||||
@ -207,26 +403,64 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
return []firewall.Rule{&r}, nil
|
return []firewall.Rule{&r}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
|
func (m *Manager) AddRouteFiltering(
|
||||||
if m.nativeFirewall == nil {
|
sources []netip.Prefix,
|
||||||
return nil, errRouteNotSupported
|
destination netip.Prefix,
|
||||||
}
|
proto firewall.Protocol,
|
||||||
|
sPort *firewall.Port,
|
||||||
|
dPort *firewall.Port,
|
||||||
|
action firewall.Action,
|
||||||
|
) (firewall.Rule, error) {
|
||||||
|
if m.nativeRouter && m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
m.mutex.Lock()
|
||||||
if m.nativeFirewall == nil {
|
defer m.mutex.Unlock()
|
||||||
return errRouteNotSupported
|
|
||||||
|
ruleID := uuid.New().String()
|
||||||
|
rule := RouteRule{
|
||||||
|
id: ruleID,
|
||||||
|
sources: sources,
|
||||||
|
destination: destination,
|
||||||
|
proto: proto,
|
||||||
|
srcPort: sPort,
|
||||||
|
dstPort: dPort,
|
||||||
|
action: action,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.routeRules = append(m.routeRules, rule)
|
||||||
|
m.routeRules.Sort()
|
||||||
|
|
||||||
|
return &rule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||||
|
if m.nativeRouter && m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
ruleID := rule.ID()
|
||||||
|
idx := slices.IndexFunc(m.routeRules, func(r RouteRule) bool {
|
||||||
|
return r.id == ruleID
|
||||||
|
})
|
||||||
|
if idx < 0 {
|
||||||
|
return fmt.Errorf("route rule not found: %s", ruleID)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// DeletePeerRule from the firewall by rule definition
|
// DeletePeerRule from the firewall by rule definition
|
||||||
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
r, ok := rule.(*Rule)
|
r, ok := rule.(*PeerRule)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
||||||
}
|
}
|
||||||
@ -273,10 +507,14 @@ func (m *Manager) DropOutgoing(packetData []byte) bool {
|
|||||||
|
|
||||||
// DropIncoming filter incoming packets
|
// DropIncoming filter incoming packets
|
||||||
func (m *Manager) DropIncoming(packetData []byte) bool {
|
func (m *Manager) DropIncoming(packetData []byte) bool {
|
||||||
return m.dropFilter(packetData, m.incomingRules)
|
return m.dropFilter(packetData)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateLocalIPs updates the list of local IPs
|
||||||
|
func (m *Manager) UpdateLocalIPs() error {
|
||||||
|
return m.localipmanager.UpdateLocalIPs(m.wgIface)
|
||||||
}
|
}
|
||||||
|
|
||||||
// processOutgoingHooks processes UDP hooks for outgoing packets and tracks TCP/UDP/ICMP
|
|
||||||
func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
defer m.mutex.RUnlock()
|
defer m.mutex.RUnlock()
|
||||||
@ -297,18 +535,11 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Always process UDP hooks
|
// Track all protocols if stateful mode is enabled
|
||||||
if d.decoded[1] == layers.LayerTypeUDP {
|
|
||||||
// Track UDP state only if enabled
|
|
||||||
if m.stateful {
|
|
||||||
m.trackUDPOutbound(d, srcIP, dstIP)
|
|
||||||
}
|
|
||||||
return m.checkUDPHooks(d, dstIP, packetData)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Track other protocols only if stateful mode is enabled
|
|
||||||
if m.stateful {
|
if m.stateful {
|
||||||
switch d.decoded[1] {
|
switch d.decoded[1] {
|
||||||
|
case layers.LayerTypeUDP:
|
||||||
|
m.trackUDPOutbound(d, srcIP, dstIP)
|
||||||
case layers.LayerTypeTCP:
|
case layers.LayerTypeTCP:
|
||||||
m.trackTCPOutbound(d, srcIP, dstIP)
|
m.trackTCPOutbound(d, srcIP, dstIP)
|
||||||
case layers.LayerTypeICMPv4:
|
case layers.LayerTypeICMPv4:
|
||||||
@ -316,6 +547,11 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Process UDP hooks even if stateful mode is disabled
|
||||||
|
if d.decoded[1] == layers.LayerTypeUDP {
|
||||||
|
return m.checkUDPHooks(d, dstIP, packetData)
|
||||||
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -397,10 +633,9 @@ func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// dropFilter implements filtering logic for incoming packets
|
// dropFilter implements filtering logic for incoming packets.
|
||||||
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
|
// If it returns true, the packet should be dropped.
|
||||||
// TODO: Disable router if --disable-server-router is set
|
func (m *Manager) dropFilter(packetData []byte) bool {
|
||||||
|
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
defer m.mutex.RUnlock()
|
defer m.mutex.RUnlock()
|
||||||
|
|
||||||
@ -413,39 +648,127 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
|
|||||||
|
|
||||||
srcIP, dstIP := m.extractIPs(d)
|
srcIP, dstIP := m.extractIPs(d)
|
||||||
if srcIP == nil {
|
if srcIP == nil {
|
||||||
log.Errorf("unknown layer: %v", d.decoded[0])
|
m.logger.Error("Unknown network layer: %v", d.decoded[0])
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if !m.isWireguardTraffic(srcIP, dstIP) {
|
// For all inbound traffic, first check if it matches a tracked connection.
|
||||||
return false
|
// This must happen before any other filtering because the packets are statefully tracked.
|
||||||
}
|
|
||||||
|
|
||||||
// Check connection state only if enabled
|
|
||||||
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) {
|
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.applyRules(srcIP, packetData, rules, d)
|
if m.localipmanager.IsLocalIP(dstIP) {
|
||||||
|
return m.handleLocalTraffic(d, srcIP, dstIP, packetData)
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleLocalTraffic handles local traffic.
|
||||||
|
// If it returns true, the packet should be dropped.
|
||||||
|
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
|
||||||
|
if !m.localForwarding {
|
||||||
|
m.logger.Trace("Dropping local packet (local forwarding disabled): src=%s dst=%s", srcIP, dstIP)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.peerACLsBlock(srcIP, packetData, m.incomingRules, d) {
|
||||||
|
m.logger.Trace("Dropping local packet (ACL denied): src=%s dst=%s",
|
||||||
|
srcIP, dstIP)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// if running in netstack mode we need to pass this to the forwarder
|
||||||
|
if m.netstack {
|
||||||
|
m.handleNetstackLocalTraffic(packetData)
|
||||||
|
|
||||||
|
// don't process this packet further
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
func (m *Manager) handleNetstackLocalTraffic(packetData []byte) {
|
||||||
|
if m.forwarder == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
|
||||||
|
m.logger.Error("Failed to inject local packet: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleRoutedTraffic handles routed traffic.
|
||||||
|
// If it returns true, the packet should be dropped.
|
||||||
|
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
|
||||||
|
// Drop if routing is disabled
|
||||||
|
if !m.routingEnabled {
|
||||||
|
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
|
||||||
|
srcIP, dstIP)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pass to native stack if native router is enabled or forced
|
||||||
|
if m.nativeRouter {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
proto := getProtocolFromPacket(d)
|
||||||
|
srcPort, dstPort := getPortsFromPacket(d)
|
||||||
|
|
||||||
|
if !m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) {
|
||||||
|
m.logger.Trace("Dropping routed packet (ACL denied): src=%s:%d dst=%s:%d proto=%v",
|
||||||
|
srcIP, srcPort, dstIP, dstPort, proto)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Let forwarder handle the packet if it passed route ACLs
|
||||||
|
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
|
||||||
|
m.logger.Error("Failed to inject incoming packet: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forwarded packets shouldn't reach the native stack, hence they won't be visible in a packet capture
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func getProtocolFromPacket(d *decoder) firewall.Protocol {
|
||||||
|
switch d.decoded[1] {
|
||||||
|
case layers.LayerTypeTCP:
|
||||||
|
return firewall.ProtocolTCP
|
||||||
|
case layers.LayerTypeUDP:
|
||||||
|
return firewall.ProtocolUDP
|
||||||
|
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||||
|
return firewall.ProtocolICMP
|
||||||
|
default:
|
||||||
|
return firewall.ProtocolALL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
|
||||||
|
switch d.decoded[1] {
|
||||||
|
case layers.LayerTypeTCP:
|
||||||
|
return uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort)
|
||||||
|
case layers.LayerTypeUDP:
|
||||||
|
return uint16(d.udp.SrcPort), uint16(d.udp.DstPort)
|
||||||
|
default:
|
||||||
|
return 0, 0
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
|
func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
|
||||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||||
log.Tracef("couldn't decode layer, err: %s", err)
|
m.logger.Trace("couldn't decode packet, err: %s", err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(d.decoded) < 2 {
|
if len(d.decoded) < 2 {
|
||||||
log.Tracef("not enough levels in network packet")
|
m.logger.Trace("packet doesn't have network and transport layers")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) isWireguardTraffic(srcIP, dstIP net.IP) bool {
|
|
||||||
return m.wgNetwork.Contains(srcIP) && m.wgNetwork.Contains(dstIP)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool {
|
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool {
|
||||||
switch d.decoded[1] {
|
switch d.decoded[1] {
|
||||||
case layers.LayerTypeTCP:
|
case layers.LayerTypeTCP:
|
||||||
@ -480,7 +803,22 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) applyRules(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool {
|
// isSpecialICMP returns true if the packet is a special ICMP packet that should be allowed
|
||||||
|
func (m *Manager) isSpecialICMP(d *decoder) bool {
|
||||||
|
if d.decoded[1] != layers.LayerTypeICMPv4 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
icmpType := d.icmp4.TypeCode.Type()
|
||||||
|
return icmpType == layers.ICMPv4TypeDestinationUnreachable ||
|
||||||
|
icmpType == layers.ICMPv4TypeTimeExceeded
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) peerACLsBlock(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool {
|
||||||
|
if m.isSpecialICMP(d) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
if filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok {
|
if filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok {
|
||||||
return filter
|
return filter
|
||||||
}
|
}
|
||||||
@ -514,7 +852,7 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decoder) (bool, bool) {
|
func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *decoder) (bool, bool) {
|
||||||
payloadLayer := d.decoded[1]
|
payloadLayer := d.decoded[1]
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if rule.matchByIP && !ip.Equal(rule.ip) {
|
if rule.matchByIP && !ip.Equal(rule.ip) {
|
||||||
@ -551,6 +889,51 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
|
|||||||
return false, false
|
return false, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// routeACLsPass returns treu if the packet is allowed by the route ACLs
|
||||||
|
func (m *Manager) routeACLsPass(srcIP, dstIP net.IP, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
||||||
|
m.mutex.RLock()
|
||||||
|
defer m.mutex.RUnlock()
|
||||||
|
|
||||||
|
srcAddr := netip.AddrFrom4([4]byte(srcIP.To4()))
|
||||||
|
dstAddr := netip.AddrFrom4([4]byte(dstIP.To4()))
|
||||||
|
|
||||||
|
for _, rule := range m.routeRules {
|
||||||
|
if m.ruleMatches(rule, srcAddr, dstAddr, proto, srcPort, dstPort) {
|
||||||
|
return rule.action == firewall.ActionAccept
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
||||||
|
if !rule.destination.Contains(dstAddr) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
sourceMatched := false
|
||||||
|
for _, src := range rule.sources {
|
||||||
|
if src.Contains(srcAddr) {
|
||||||
|
sourceMatched = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !sourceMatched {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if rule.proto != firewall.ProtocolALL && rule.proto != proto {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if proto == firewall.ProtocolTCP || proto == firewall.ProtocolUDP {
|
||||||
|
if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// SetNetwork of the wireguard interface to which filtering applied
|
// SetNetwork of the wireguard interface to which filtering applied
|
||||||
func (m *Manager) SetNetwork(network *net.IPNet) {
|
func (m *Manager) SetNetwork(network *net.IPNet) {
|
||||||
m.wgNetwork = network
|
m.wgNetwork = network
|
||||||
@ -562,7 +945,7 @@ func (m *Manager) SetNetwork(network *net.IPNet) {
|
|||||||
func (m *Manager) AddUDPPacketHook(
|
func (m *Manager) AddUDPPacketHook(
|
||||||
in bool, ip net.IP, dPort uint16, hook func([]byte) bool,
|
in bool, ip net.IP, dPort uint16, hook func([]byte) bool,
|
||||||
) string {
|
) string {
|
||||||
r := Rule{
|
r := PeerRule{
|
||||||
id: uuid.New().String(),
|
id: uuid.New().String(),
|
||||||
ip: ip,
|
ip: ip,
|
||||||
protoLayer: layers.LayerTypeUDP,
|
protoLayer: layers.LayerTypeUDP,
|
||||||
@ -579,12 +962,12 @@ func (m *Manager) AddUDPPacketHook(
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
if in {
|
if in {
|
||||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||||
m.incomingRules[r.ip.String()] = make(map[string]Rule)
|
m.incomingRules[r.ip.String()] = make(map[string]PeerRule)
|
||||||
}
|
}
|
||||||
m.incomingRules[r.ip.String()][r.id] = r
|
m.incomingRules[r.ip.String()][r.id] = r
|
||||||
} else {
|
} else {
|
||||||
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
||||||
m.outgoingRules[r.ip.String()] = make(map[string]Rule)
|
m.outgoingRules[r.ip.String()] = make(map[string]PeerRule)
|
||||||
}
|
}
|
||||||
m.outgoingRules[r.ip.String()][r.id] = r
|
m.outgoingRules[r.ip.String()][r.id] = r
|
||||||
}
|
}
|
||||||
@ -617,3 +1000,41 @@ func (m *Manager) RemovePacketHook(hookID string) error {
|
|||||||
}
|
}
|
||||||
return fmt.Errorf("hook with given id not found")
|
return fmt.Errorf("hook with given id not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetLogLevel sets the log level for the firewall manager
|
||||||
|
func (m *Manager) SetLogLevel(level log.Level) {
|
||||||
|
if m.logger != nil {
|
||||||
|
m.logger.SetLevel(nblog.Level(level))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) EnableRouting() error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.determineRouting()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) DisableRouting() error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
if m.forwarder == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m.routingEnabled = false
|
||||||
|
m.nativeRouter = false
|
||||||
|
|
||||||
|
// don't stop forwarder if in use by netstack
|
||||||
|
if m.netstack && m.localForwarding {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m.forwarder.Stop()
|
||||||
|
m.forwarder = nil
|
||||||
|
|
||||||
|
log.Debug("forwarder stopped")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -1,9 +1,12 @@
|
|||||||
|
//go:build uspbench
|
||||||
|
|
||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
@ -155,7 +158,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
// Create manager and basic setup
|
// Create manager and basic setup
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
})
|
}, false)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Reset(nil))
|
||||||
})
|
})
|
||||||
@ -185,7 +188,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
// Measure inbound packet processing
|
// Measure inbound packet processing
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(inbound, manager.incomingRules)
|
manager.dropFilter(inbound)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -200,7 +203,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
|||||||
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
})
|
}, false)
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Reset(nil))
|
||||||
})
|
})
|
||||||
@ -228,7 +231,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
|||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(testIn, manager.incomingRules)
|
manager.dropFilter(testIn)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -248,7 +251,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
|||||||
b.Run(sc.name, func(b *testing.B) {
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
})
|
}, false)
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Reset(nil))
|
||||||
})
|
})
|
||||||
@ -269,7 +272,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
|||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(inbound, manager.incomingRules)
|
manager.dropFilter(inbound)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -447,7 +450,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
b.Run(sc.name, func(b *testing.B) {
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
})
|
}, false)
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Reset(nil))
|
||||||
})
|
})
|
||||||
@ -472,7 +475,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
manager.processOutgoingHooks(syn)
|
manager.processOutgoingHooks(syn)
|
||||||
// SYN-ACK
|
// SYN-ACK
|
||||||
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
manager.dropFilter(synack, manager.incomingRules)
|
manager.dropFilter(synack)
|
||||||
// ACK
|
// ACK
|
||||||
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
||||||
manager.processOutgoingHooks(ack)
|
manager.processOutgoingHooks(ack)
|
||||||
@ -481,7 +484,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(inbound, manager.incomingRules)
|
manager.dropFilter(inbound)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -574,7 +577,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
})
|
}, false)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Reset(nil))
|
||||||
})
|
})
|
||||||
@ -618,7 +621,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
// SYN-ACK
|
// SYN-ACK
|
||||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
manager.dropFilter(synack, manager.incomingRules)
|
manager.dropFilter(synack)
|
||||||
|
|
||||||
// ACK
|
// ACK
|
||||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
@ -646,7 +649,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
// First outbound data
|
// First outbound data
|
||||||
manager.processOutgoingHooks(outPackets[connIdx])
|
manager.processOutgoingHooks(outPackets[connIdx])
|
||||||
// Then inbound response - this is what we're actually measuring
|
// Then inbound response - this is what we're actually measuring
|
||||||
manager.dropFilter(inPackets[connIdx], manager.incomingRules)
|
manager.dropFilter(inPackets[connIdx])
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -665,7 +668,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
})
|
}, false)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Reset(nil))
|
||||||
})
|
})
|
||||||
@ -754,17 +757,17 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
// Connection establishment
|
// Connection establishment
|
||||||
manager.processOutgoingHooks(p.syn)
|
manager.processOutgoingHooks(p.syn)
|
||||||
manager.dropFilter(p.synAck, manager.incomingRules)
|
manager.dropFilter(p.synAck)
|
||||||
manager.processOutgoingHooks(p.ack)
|
manager.processOutgoingHooks(p.ack)
|
||||||
|
|
||||||
// Data transfer
|
// Data transfer
|
||||||
manager.processOutgoingHooks(p.request)
|
manager.processOutgoingHooks(p.request)
|
||||||
manager.dropFilter(p.response, manager.incomingRules)
|
manager.dropFilter(p.response)
|
||||||
|
|
||||||
// Connection teardown
|
// Connection teardown
|
||||||
manager.processOutgoingHooks(p.finClient)
|
manager.processOutgoingHooks(p.finClient)
|
||||||
manager.dropFilter(p.ackServer, manager.incomingRules)
|
manager.dropFilter(p.ackServer)
|
||||||
manager.dropFilter(p.finServer, manager.incomingRules)
|
manager.dropFilter(p.finServer)
|
||||||
manager.processOutgoingHooks(p.ackClient)
|
manager.processOutgoingHooks(p.ackClient)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@ -784,7 +787,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
})
|
}, false)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Reset(nil))
|
||||||
})
|
})
|
||||||
@ -825,7 +828,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
manager.dropFilter(synack, manager.incomingRules)
|
manager.dropFilter(synack)
|
||||||
|
|
||||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||||
@ -852,7 +855,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
// Simulate bidirectional traffic
|
// Simulate bidirectional traffic
|
||||||
manager.processOutgoingHooks(outPackets[connIdx])
|
manager.processOutgoingHooks(outPackets[connIdx])
|
||||||
manager.dropFilter(inPackets[connIdx], manager.incomingRules)
|
manager.dropFilter(inPackets[connIdx])
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
@ -872,7 +875,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
})
|
}, false)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Reset(nil))
|
||||||
})
|
})
|
||||||
@ -949,15 +952,15 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
// Full connection lifecycle
|
// Full connection lifecycle
|
||||||
manager.processOutgoingHooks(p.syn)
|
manager.processOutgoingHooks(p.syn)
|
||||||
manager.dropFilter(p.synAck, manager.incomingRules)
|
manager.dropFilter(p.synAck)
|
||||||
manager.processOutgoingHooks(p.ack)
|
manager.processOutgoingHooks(p.ack)
|
||||||
|
|
||||||
manager.processOutgoingHooks(p.request)
|
manager.processOutgoingHooks(p.request)
|
||||||
manager.dropFilter(p.response, manager.incomingRules)
|
manager.dropFilter(p.response)
|
||||||
|
|
||||||
manager.processOutgoingHooks(p.finClient)
|
manager.processOutgoingHooks(p.finClient)
|
||||||
manager.dropFilter(p.ackServer, manager.incomingRules)
|
manager.dropFilter(p.ackServer)
|
||||||
manager.dropFilter(p.finServer, manager.incomingRules)
|
manager.dropFilter(p.finServer)
|
||||||
manager.processOutgoingHooks(p.ackClient)
|
manager.processOutgoingHooks(p.ackClient)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@ -996,3 +999,72 @@ func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstP
|
|||||||
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test")))
|
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test")))
|
||||||
return buf.Bytes()
|
return buf.Bytes()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func BenchmarkRouteACLs(b *testing.B) {
|
||||||
|
manager := setupRoutedManager(b, "10.10.0.100/16")
|
||||||
|
|
||||||
|
// Add several route rules to simulate real-world scenario
|
||||||
|
rules := []struct {
|
||||||
|
sources []netip.Prefix
|
||||||
|
dest netip.Prefix
|
||||||
|
proto fw.Protocol
|
||||||
|
port *fw.Port
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
|
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
port: &fw.Port{Values: []uint16{80, 443}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sources: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("172.16.0.0/12"),
|
||||||
|
netip.MustParsePrefix("10.0.0.0/8"),
|
||||||
|
},
|
||||||
|
dest: netip.MustParsePrefix("0.0.0.0/0"),
|
||||||
|
proto: fw.ProtocolICMP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
|
dest: netip.MustParsePrefix("192.168.0.0/16"),
|
||||||
|
proto: fw.ProtocolUDP,
|
||||||
|
port: &fw.Port{Values: []uint16{53}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, r := range rules {
|
||||||
|
_, err := manager.AddRouteFiltering(
|
||||||
|
r.sources,
|
||||||
|
r.dest,
|
||||||
|
r.proto,
|
||||||
|
nil,
|
||||||
|
r.port,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test cases that exercise different matching scenarios
|
||||||
|
cases := []struct {
|
||||||
|
srcIP string
|
||||||
|
dstIP string
|
||||||
|
proto fw.Protocol
|
||||||
|
dstPort uint16
|
||||||
|
}{
|
||||||
|
{"100.10.0.1", "192.168.1.100", fw.ProtocolTCP, 443}, // Match first rule
|
||||||
|
{"172.16.0.1", "8.8.8.8", fw.ProtocolICMP, 0}, // Match second rule
|
||||||
|
{"1.1.1.1", "192.168.1.53", fw.ProtocolUDP, 53}, // Match third rule
|
||||||
|
{"192.168.1.1", "10.0.0.1", fw.ProtocolTCP, 8080}, // No match
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
for _, tc := range cases {
|
||||||
|
srcIP := net.ParseIP(tc.srcIP)
|
||||||
|
dstIP := net.ParseIP(tc.dstIP)
|
||||||
|
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
1015
client/firewall/uspfilter/uspfilter_filter_test.go
Normal file
@ -9,17 +9,38 @@ import (
|
|||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||||
|
|
||||||
type IFaceMock struct {
|
type IFaceMock struct {
|
||||||
SetFilterFunc func(device.PacketFilter) error
|
SetFilterFunc func(device.PacketFilter) error
|
||||||
AddressFunc func() iface.WGAddress
|
AddressFunc func() iface.WGAddress
|
||||||
|
GetWGDeviceFunc func() *wgdevice.Device
|
||||||
|
GetDeviceFunc func() *device.FilteredDevice
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *IFaceMock) GetWGDevice() *wgdevice.Device {
|
||||||
|
if i.GetWGDeviceFunc == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return i.GetWGDeviceFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *IFaceMock) GetDevice() *device.FilteredDevice {
|
||||||
|
if i.GetDeviceFunc == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return i.GetDeviceFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
|
func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
|
||||||
@ -41,7 +62,7 @@ func TestManagerCreate(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock)
|
m, err := Create(ifaceMock, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@ -61,7 +82,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock)
|
m, err := Create(ifaceMock, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@ -95,7 +116,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock)
|
m, err := Create(ifaceMock, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@ -166,12 +187,12 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
})
|
}, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||||
|
|
||||||
var addedRule Rule
|
var addedRule PeerRule
|
||||||
if tt.in {
|
if tt.in {
|
||||||
if len(manager.incomingRules[tt.ip.String()]) != 1 {
|
if len(manager.incomingRules[tt.ip.String()]) != 1 {
|
||||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
||||||
@ -215,7 +236,7 @@ func TestManagerReset(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock)
|
m, err := Create(ifaceMock, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@ -247,9 +268,18 @@ func TestManagerReset(t *testing.T) {
|
|||||||
func TestNotMatchByIP(t *testing.T) {
|
func TestNotMatchByIP(t *testing.T) {
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
AddressFunc: func() iface.WGAddress {
|
||||||
|
return iface.WGAddress{
|
||||||
|
IP: net.ParseIP("100.10.0.100"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.10.0.0"),
|
||||||
|
Mask: net.CIDRMask(16, 32),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock)
|
m, err := Create(ifaceMock, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@ -298,7 +328,7 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.dropFilter(buf.Bytes(), m.incomingRules) {
|
if m.dropFilter(buf.Bytes()) {
|
||||||
t.Errorf("expected packet to be accepted")
|
t.Errorf("expected packet to be accepted")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -317,7 +347,7 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// creating manager instance
|
// creating manager instance
|
||||||
manager, err := Create(iface)
|
manager, err := Create(iface, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create Manager: %s", err)
|
t.Fatalf("Failed to create Manager: %s", err)
|
||||||
}
|
}
|
||||||
@ -363,7 +393,7 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
func TestProcessOutgoingHooks(t *testing.T) {
|
func TestProcessOutgoingHooks(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
})
|
}, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
manager.wgNetwork = &net.IPNet{
|
||||||
@ -371,7 +401,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
Mask: net.CIDRMask(16, 32),
|
Mask: net.CIDRMask(16, 32),
|
||||||
}
|
}
|
||||||
manager.udpTracker.Close()
|
manager.udpTracker.Close()
|
||||||
manager.udpTracker = conntrack.NewUDPTracker(100 * time.Millisecond)
|
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, manager.Reset(nil))
|
require.NoError(t, manager.Reset(nil))
|
||||||
}()
|
}()
|
||||||
@ -449,7 +479,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
manager, err := Create(ifaceMock)
|
manager, err := Create(ifaceMock, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
@ -476,7 +506,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
})
|
}, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
manager.wgNetwork = &net.IPNet{
|
||||||
@ -485,7 +515,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
manager.udpTracker.Close() // Close the existing tracker
|
manager.udpTracker.Close() // Close the existing tracker
|
||||||
manager.udpTracker = conntrack.NewUDPTracker(200 * time.Millisecond)
|
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger)
|
||||||
manager.decoders = sync.Pool{
|
manager.decoders = sync.Pool{
|
||||||
New: func() any {
|
New: func() any {
|
||||||
d := &decoder{
|
d := &decoder{
|
||||||
@ -606,7 +636,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
for _, cp := range checkPoints {
|
for _, cp := range checkPoints {
|
||||||
time.Sleep(cp.sleep)
|
time.Sleep(cp.sleep)
|
||||||
|
|
||||||
drop = manager.dropFilter(inboundBuf.Bytes(), manager.incomingRules)
|
drop = manager.dropFilter(inboundBuf.Bytes())
|
||||||
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
||||||
|
|
||||||
// If the connection should still be valid, verify it exists
|
// If the connection should still be valid, verify it exists
|
||||||
@ -677,7 +707,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify the invalid packet is dropped
|
// Verify the invalid packet is dropped
|
||||||
drop = manager.dropFilter(testBuf.Bytes(), manager.incomingRules)
|
drop = manager.dropFilter(testBuf.Bytes())
|
||||||
require.True(t, drop, tc.description)
|
require.True(t, drop, tc.description)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -362,7 +362,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getFwmark() int {
|
func getFwmark() int {
|
||||||
if runtime.GOOS == "linux" && !nbnet.CustomRoutingDisabled() {
|
if nbnet.AdvancedRouting() {
|
||||||
return nbnet.NetbirdFwmark
|
return nbnet.NetbirdFwmark
|
||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
|
@ -3,6 +3,8 @@
|
|||||||
package iface
|
package iface
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
)
|
)
|
||||||
@ -15,4 +17,5 @@ type WGTunDevice interface {
|
|||||||
DeviceName() string
|
DeviceName() string
|
||||||
Close() error
|
Close() error
|
||||||
FilteredDevice() *device.FilteredDevice
|
FilteredDevice() *device.FilteredDevice
|
||||||
|
Device() *wgdevice.Device
|
||||||
}
|
}
|
||||||
|
@ -117,6 +117,11 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
|
|||||||
return t.filteredDevice
|
return t.filteredDevice
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Device returns the wireguard device
|
||||||
|
func (t *TunDevice) Device() *device.Device {
|
||||||
|
return t.device
|
||||||
|
}
|
||||||
|
|
||||||
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
|
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
|
||||||
func (t *TunDevice) assignAddr() error {
|
func (t *TunDevice) assignAddr() error {
|
||||||
cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String())
|
cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String())
|
||||||
|
@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
"github.com/pion/transport/v3"
|
"github.com/pion/transport/v3"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
@ -151,6 +152,11 @@ func (t *TunKernelDevice) DeviceName() string {
|
|||||||
return t.name
|
return t.name
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Device returns the wireguard device, not applicable for kernel devices
|
||||||
|
func (t *TunKernelDevice) Device() *device.Device {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
|
func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -117,3 +117,8 @@ func (t *TunNetstackDevice) DeviceName() string {
|
|||||||
func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice {
|
func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice {
|
||||||
return t.filteredDevice
|
return t.filteredDevice
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Device returns the wireguard device
|
||||||
|
func (t *TunNetstackDevice) Device() *device.Device {
|
||||||
|
return t.device
|
||||||
|
}
|
||||||
|
@ -124,6 +124,11 @@ func (t *USPDevice) FilteredDevice() *FilteredDevice {
|
|||||||
return t.filteredDevice
|
return t.filteredDevice
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Device returns the wireguard device
|
||||||
|
func (t *USPDevice) Device() *device.Device {
|
||||||
|
return t.device
|
||||||
|
}
|
||||||
|
|
||||||
// assignAddr Adds IP address to the tunnel interface
|
// assignAddr Adds IP address to the tunnel interface
|
||||||
func (t *USPDevice) assignAddr() error {
|
func (t *USPDevice) assignAddr() error {
|
||||||
link := newWGLink(t.name)
|
link := newWGLink(t.name)
|
||||||
|
@ -150,6 +150,11 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
|
|||||||
return t.filteredDevice
|
return t.filteredDevice
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Device returns the wireguard device
|
||||||
|
func (t *TunDevice) Device() *device.Device {
|
||||||
|
return t.device
|
||||||
|
}
|
||||||
|
|
||||||
func (t *TunDevice) GetInterfaceGUIDString() (string, error) {
|
func (t *TunDevice) GetInterfaceGUIDString() (string, error) {
|
||||||
if t.nativeTunDevice == nil {
|
if t.nativeTunDevice == nil {
|
||||||
return "", fmt.Errorf("interface has not been initialized yet")
|
return "", fmt.Errorf("interface has not been initialized yet")
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
package iface
|
package iface
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
)
|
)
|
||||||
@ -13,4 +15,5 @@ type WGTunDevice interface {
|
|||||||
DeviceName() string
|
DeviceName() string
|
||||||
Close() error
|
Close() error
|
||||||
FilteredDevice() *device.FilteredDevice
|
FilteredDevice() *device.FilteredDevice
|
||||||
|
Device() *wgdevice.Device
|
||||||
}
|
}
|
||||||
|
@ -11,6 +11,8 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/errors"
|
"github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
@ -203,6 +205,11 @@ func (w *WGIface) GetDevice() *device.FilteredDevice {
|
|||||||
return w.tun.FilteredDevice()
|
return w.tun.FilteredDevice()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetWGDevice returns the WireGuard device
|
||||||
|
func (w *WGIface) GetWGDevice() *wgdevice.Device {
|
||||||
|
return w.tun.Device()
|
||||||
|
}
|
||||||
|
|
||||||
// GetStats returns the last handshake time, rx and tx bytes for the given peer
|
// GetStats returns the last handshake time, rx and tx bytes for the given peer
|
||||||
func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
||||||
return w.configurer.GetStats(peerKey)
|
return w.configurer.GetStats(peerKey)
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
@ -29,6 +30,7 @@ type MockWGIface struct {
|
|||||||
SetFilterFunc func(filter device.PacketFilter) error
|
SetFilterFunc func(filter device.PacketFilter) error
|
||||||
GetFilterFunc func() device.PacketFilter
|
GetFilterFunc func() device.PacketFilter
|
||||||
GetDeviceFunc func() *device.FilteredDevice
|
GetDeviceFunc func() *device.FilteredDevice
|
||||||
|
GetWGDeviceFunc func() *wgdevice.Device
|
||||||
GetStatsFunc func(peerKey string) (configurer.WGStats, error)
|
GetStatsFunc func(peerKey string) (configurer.WGStats, error)
|
||||||
GetInterfaceGUIDStringFunc func() (string, error)
|
GetInterfaceGUIDStringFunc func() (string, error)
|
||||||
GetProxyFunc func() wgproxy.Proxy
|
GetProxyFunc func() wgproxy.Proxy
|
||||||
@ -102,11 +104,14 @@ func (m *MockWGIface) GetDevice() *device.FilteredDevice {
|
|||||||
return m.GetDeviceFunc()
|
return m.GetDeviceFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MockWGIface) GetWGDevice() *wgdevice.Device {
|
||||||
|
return m.GetWGDeviceFunc()
|
||||||
|
}
|
||||||
|
|
||||||
func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
||||||
return m.GetStatsFunc(peerKey)
|
return m.GetStatsFunc(peerKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockWGIface) GetProxy() wgproxy.Proxy {
|
func (m *MockWGIface) GetProxy() wgproxy.Proxy {
|
||||||
//TODO implement me
|
return m.GetProxyFunc()
|
||||||
panic("implement me")
|
|
||||||
}
|
}
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
@ -32,5 +33,6 @@ type IWGIface interface {
|
|||||||
SetFilter(filter device.PacketFilter) error
|
SetFilter(filter device.PacketFilter) error
|
||||||
GetFilter() device.PacketFilter
|
GetFilter() device.PacketFilter
|
||||||
GetDevice() *device.FilteredDevice
|
GetDevice() *device.FilteredDevice
|
||||||
|
GetWGDevice() *wgdevice.Device
|
||||||
GetStats(peerKey string) (configurer.WGStats, error)
|
GetStats(peerKey string) (configurer.WGStats, error)
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
@ -30,6 +31,7 @@ type IWGIface interface {
|
|||||||
SetFilter(filter device.PacketFilter) error
|
SetFilter(filter device.PacketFilter) error
|
||||||
GetFilter() device.PacketFilter
|
GetFilter() device.PacketFilter
|
||||||
GetDevice() *device.FilteredDevice
|
GetDevice() *device.FilteredDevice
|
||||||
|
GetWGDevice() *wgdevice.Device
|
||||||
GetStats(peerKey string) (configurer.WGStats, error)
|
GetStats(peerKey string) (configurer.WGStats, error)
|
||||||
GetInterfaceGUIDString() (string, error)
|
GetInterfaceGUIDString() (string, error)
|
||||||
}
|
}
|
||||||
|
@ -49,9 +49,10 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
IP: ip,
|
IP: ip,
|
||||||
Network: network,
|
Network: network,
|
||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
// we receive one rule from the management so for testing purposes ignore it
|
||||||
fw, err := firewall.NewFirewall(ifaceMock, nil)
|
fw, err := firewall.NewFirewall(ifaceMock, nil, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create firewall: %v", err)
|
t.Errorf("create firewall: %v", err)
|
||||||
return
|
return
|
||||||
@ -342,9 +343,10 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
|||||||
IP: ip,
|
IP: ip,
|
||||||
Network: network,
|
Network: network,
|
||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
// we receive one rule from the management so for testing purposes ignore it
|
||||||
fw, err := firewall.NewFirewall(ifaceMock, nil)
|
fw, err := firewall.NewFirewall(ifaceMock, nil, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create firewall: %v", err)
|
t.Errorf("create firewall: %v", err)
|
||||||
return
|
return
|
||||||
|
@ -8,6 +8,8 @@ import (
|
|||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
iface "github.com/netbirdio/netbird/client/iface"
|
iface "github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
)
|
)
|
||||||
@ -90,3 +92,31 @@ func (mr *MockIFaceMapperMockRecorder) SetFilter(arg0 interface{}) *gomock.Call
|
|||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetFilter", reflect.TypeOf((*MockIFaceMapper)(nil).SetFilter), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetFilter", reflect.TypeOf((*MockIFaceMapper)(nil).SetFilter), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetDevice mocks base method.
|
||||||
|
func (m *MockIFaceMapper) GetDevice() *device.FilteredDevice {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetDevice")
|
||||||
|
ret0, _ := ret[0].(*device.FilteredDevice)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDevice indicates an expected call of GetDevice.
|
||||||
|
func (mr *MockIFaceMapperMockRecorder) GetDevice() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDevice", reflect.TypeOf((*MockIFaceMapper)(nil).GetDevice))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetWGDevice mocks base method.
|
||||||
|
func (m *MockIFaceMapper) GetWGDevice() *wgdevice.Device {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetWGDevice")
|
||||||
|
ret0, _ := ret[0].(*wgdevice.Device)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetWGDevice indicates an expected call of GetWGDevice.
|
||||||
|
func (mr *MockIFaceMapperMockRecorder) GetWGDevice() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWGDevice", reflect.TypeOf((*MockIFaceMapper)(nil).GetWGDevice))
|
||||||
|
}
|
||||||
|
@ -68,6 +68,8 @@ type ConfigInput struct {
|
|||||||
DisableFirewall *bool
|
DisableFirewall *bool
|
||||||
|
|
||||||
BlockLANAccess *bool
|
BlockLANAccess *bool
|
||||||
|
|
||||||
|
DisableNotifications *bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Config Configuration type
|
// Config Configuration type
|
||||||
@ -93,6 +95,8 @@ type Config struct {
|
|||||||
|
|
||||||
BlockLANAccess bool
|
BlockLANAccess bool
|
||||||
|
|
||||||
|
DisableNotifications bool
|
||||||
|
|
||||||
// SSHKey is a private SSH key in a PEM format
|
// SSHKey is a private SSH key in a PEM format
|
||||||
SSHKey string
|
SSHKey string
|
||||||
|
|
||||||
@ -469,6 +473,16 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
updated = true
|
updated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if input.DisableNotifications != nil && *input.DisableNotifications != config.DisableNotifications {
|
||||||
|
if *input.DisableNotifications {
|
||||||
|
log.Infof("disabling notifications")
|
||||||
|
} else {
|
||||||
|
log.Infof("enabling notifications")
|
||||||
|
}
|
||||||
|
config.DisableNotifications = *input.DisableNotifications
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
if input.ClientCertKeyPath != "" {
|
if input.ClientCertKeyPath != "" {
|
||||||
config.ClientCertKeyPath = input.ClientCertKeyPath
|
config.ClientCertKeyPath = input.ClientCertKeyPath
|
||||||
updated = true
|
updated = true
|
||||||
|
@ -31,6 +31,7 @@ import (
|
|||||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||||
signal "github.com/netbirdio/netbird/signal/client"
|
signal "github.com/netbirdio/netbird/signal/client"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -109,6 +110,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
|
|
||||||
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
|
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
|
||||||
|
|
||||||
|
nbnet.Init()
|
||||||
|
|
||||||
backOff := &backoff.ExponentialBackOff{
|
backOff := &backoff.ExponentialBackOff{
|
||||||
InitialInterval: time.Second,
|
InitialInterval: time.Second,
|
||||||
RandomizationFactor: 1,
|
RandomizationFactor: 1,
|
||||||
|
@ -12,7 +12,7 @@ import (
|
|||||||
const (
|
const (
|
||||||
PriorityDNSRoute = 100
|
PriorityDNSRoute = 100
|
||||||
PriorityMatchDomain = 50
|
PriorityMatchDomain = 50
|
||||||
PriorityDefault = 0
|
PriorityDefault = 1
|
||||||
)
|
)
|
||||||
|
|
||||||
type SubdomainMatcher interface {
|
type SubdomainMatcher interface {
|
||||||
@ -26,7 +26,6 @@ type HandlerEntry struct {
|
|||||||
Pattern string
|
Pattern string
|
||||||
OrigPattern string
|
OrigPattern string
|
||||||
IsWildcard bool
|
IsWildcard bool
|
||||||
StopHandler handlerWithStop
|
|
||||||
MatchSubdomains bool
|
MatchSubdomains bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -64,7 +63,7 @@ func (w *ResponseWriterChain) GetOrigPattern() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
|
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
|
||||||
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) {
|
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
@ -78,9 +77,6 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
|||||||
// First remove any existing handler with same pattern (case-insensitive) and priority
|
// First remove any existing handler with same pattern (case-insensitive) and priority
|
||||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||||
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
|
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
|
||||||
if c.handlers[i].StopHandler != nil {
|
|
||||||
c.handlers[i].StopHandler.stop()
|
|
||||||
}
|
|
||||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -101,7 +97,6 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
|||||||
Pattern: pattern,
|
Pattern: pattern,
|
||||||
OrigPattern: origPattern,
|
OrigPattern: origPattern,
|
||||||
IsWildcard: isWildcard,
|
IsWildcard: isWildcard,
|
||||||
StopHandler: stopHandler,
|
|
||||||
MatchSubdomains: matchSubdomains,
|
MatchSubdomains: matchSubdomains,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -142,9 +137,6 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
|
|||||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||||
entry := c.handlers[i]
|
entry := c.handlers[i]
|
||||||
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
||||||
if entry.StopHandler != nil {
|
|
||||||
entry.StopHandler.stop()
|
|
||||||
}
|
|
||||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -180,8 +172,8 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
if log.IsLevelEnabled(log.TraceLevel) {
|
if log.IsLevelEnabled(log.TraceLevel) {
|
||||||
log.Tracef("current handlers (%d):", len(handlers))
|
log.Tracef("current handlers (%d):", len(handlers))
|
||||||
for _, h := range handlers {
|
for _, h := range handlers {
|
||||||
log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v priority=%d",
|
log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d",
|
||||||
h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority)
|
h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -206,13 +198,13 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !matched {
|
if !matched {
|
||||||
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v matched=false",
|
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d matched=false",
|
||||||
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard)
|
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard, entry.Priority)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v",
|
log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d",
|
||||||
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains)
|
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
|
||||||
|
|
||||||
chainWriter := &ResponseWriterChain{
|
chainWriter := &ResponseWriterChain{
|
||||||
ResponseWriter: w,
|
ResponseWriter: w,
|
||||||
|
@ -21,9 +21,9 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
|
|||||||
dnsRouteHandler := &nbdns.MockHandler{}
|
dnsRouteHandler := &nbdns.MockHandler{}
|
||||||
|
|
||||||
// Setup handlers with different priorities
|
// Setup handlers with different priorities
|
||||||
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil)
|
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault)
|
||||||
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain, nil)
|
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain)
|
||||||
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute, nil)
|
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute)
|
||||||
|
|
||||||
// Create test request
|
// Create test request
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
@ -138,7 +138,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
|||||||
pattern = "*." + tt.handlerDomain[2:]
|
pattern = "*." + tt.handlerDomain[2:]
|
||||||
}
|
}
|
||||||
|
|
||||||
chain.AddHandler(pattern, handler, nbdns.PriorityDefault, nil)
|
chain.AddHandler(pattern, handler, nbdns.PriorityDefault)
|
||||||
|
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
||||||
@ -253,7 +253,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
|||||||
handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe()
|
handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe()
|
||||||
}
|
}
|
||||||
|
|
||||||
chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority, nil)
|
chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create and execute request
|
// Create and execute request
|
||||||
@ -280,9 +280,9 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
|
|||||||
handler3 := &nbdns.MockHandler{}
|
handler3 := &nbdns.MockHandler{}
|
||||||
|
|
||||||
// Add handlers in priority order
|
// Add handlers in priority order
|
||||||
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil)
|
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute)
|
||||||
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain, nil)
|
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain)
|
||||||
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault, nil)
|
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault)
|
||||||
|
|
||||||
// Create test request
|
// Create test request
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
@ -416,7 +416,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
|||||||
if op.action == "add" {
|
if op.action == "add" {
|
||||||
handler := &nbdns.MockHandler{}
|
handler := &nbdns.MockHandler{}
|
||||||
handlers[op.priority] = handler
|
handlers[op.priority] = handler
|
||||||
chain.AddHandler(op.pattern, handler, op.priority, nil)
|
chain.AddHandler(op.pattern, handler, op.priority)
|
||||||
} else {
|
} else {
|
||||||
chain.RemoveHandler(op.pattern, op.priority)
|
chain.RemoveHandler(op.pattern, op.priority)
|
||||||
}
|
}
|
||||||
@ -471,9 +471,9 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
|||||||
r.SetQuestion(testQuery, dns.TypeA)
|
r.SetQuestion(testQuery, dns.TypeA)
|
||||||
|
|
||||||
// Add handlers in mixed order
|
// Add handlers in mixed order
|
||||||
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault, nil)
|
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
|
||||||
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute, nil)
|
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
|
||||||
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain, nil)
|
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
|
||||||
|
|
||||||
// Test 1: Initial state with all three handlers
|
// Test 1: Initial state with all three handlers
|
||||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
@ -653,7 +653,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
|||||||
handler = mockHandler
|
handler = mockHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
chain.AddHandler(pattern, handler, h.priority, nil)
|
chain.AddHandler(pattern, handler, h.priority)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute request
|
// Execute request
|
||||||
@ -795,7 +795,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
if op.action == "add" {
|
if op.action == "add" {
|
||||||
handler := &nbdns.MockSubdomainHandler{Subdomains: op.subdomain}
|
handler := &nbdns.MockSubdomainHandler{Subdomains: op.subdomain}
|
||||||
handlers[op.pattern] = handler
|
handlers[op.pattern] = handler
|
||||||
chain.AddHandler(op.pattern, handler, op.priority, nil)
|
chain.AddHandler(op.pattern, handler, op.priority)
|
||||||
} else {
|
} else {
|
||||||
chain.RemoveHandler(op.pattern, op.priority)
|
chain.RemoveHandler(op.pattern, op.priority)
|
||||||
}
|
}
|
||||||
|
@ -1,35 +1,51 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/windows/registry"
|
"golang.org/x/sys/windows/registry"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
userenv = syscall.NewLazyDLL("userenv.dll")
|
||||||
|
|
||||||
|
// https://learn.microsoft.com/en-us/windows/win32/api/userenv/nf-userenv-refreshpolicyex
|
||||||
|
refreshPolicyExFn = userenv.NewProc("RefreshPolicyEx")
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
dnsPolicyConfigMatchPath = `SYSTEM\CurrentControlSet\Services\Dnscache\Parameters\DnsPolicyConfig\NetBird-Match`
|
dnsPolicyConfigMatchPath = `SYSTEM\CurrentControlSet\Services\Dnscache\Parameters\DnsPolicyConfig\NetBird-Match`
|
||||||
|
gpoDnsPolicyRoot = `SOFTWARE\Policies\Microsoft\Windows NT\DNSClient`
|
||||||
|
gpoDnsPolicyConfigMatchPath = gpoDnsPolicyRoot + `\DnsPolicyConfig\NetBird-Match`
|
||||||
|
|
||||||
dnsPolicyConfigVersionKey = "Version"
|
dnsPolicyConfigVersionKey = "Version"
|
||||||
dnsPolicyConfigVersionValue = 2
|
dnsPolicyConfigVersionValue = 2
|
||||||
dnsPolicyConfigNameKey = "Name"
|
dnsPolicyConfigNameKey = "Name"
|
||||||
dnsPolicyConfigGenericDNSServersKey = "GenericDNSServers"
|
dnsPolicyConfigGenericDNSServersKey = "GenericDNSServers"
|
||||||
dnsPolicyConfigConfigOptionsKey = "ConfigOptions"
|
dnsPolicyConfigConfigOptionsKey = "ConfigOptions"
|
||||||
dnsPolicyConfigConfigOptionsValue = 0x8
|
dnsPolicyConfigConfigOptionsValue = 0x8
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
|
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
|
||||||
interfaceConfigNameServerKey = "NameServer"
|
interfaceConfigNameServerKey = "NameServer"
|
||||||
interfaceConfigSearchListKey = "SearchList"
|
interfaceConfigSearchListKey = "SearchList"
|
||||||
|
|
||||||
|
// RP_FORCE: Reapply all policies even if no policy change was detected
|
||||||
|
rpForce = 0x1
|
||||||
)
|
)
|
||||||
|
|
||||||
type registryConfigurator struct {
|
type registryConfigurator struct {
|
||||||
guid string
|
guid string
|
||||||
routingAll bool
|
routingAll bool
|
||||||
|
gpo bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
||||||
@ -37,12 +53,20 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return newHostManagerWithGuid(guid)
|
|
||||||
|
var useGPO bool
|
||||||
|
k, err := registry.OpenKey(registry.LOCAL_MACHINE, gpoDnsPolicyRoot, registry.QUERY_VALUE)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to open GPO DNS policy root: %v", err)
|
||||||
|
} else {
|
||||||
|
closer(k)
|
||||||
|
useGPO = true
|
||||||
|
log.Infof("detected GPO DNS policy configuration, using policy store")
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManagerWithGuid(guid string) (*registryConfigurator, error) {
|
|
||||||
return ®istryConfigurator{
|
return ®istryConfigurator{
|
||||||
guid: guid,
|
guid: guid,
|
||||||
|
gpo: useGPO,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -51,30 +75,23 @@ func (r *registryConfigurator) supportCustomPort() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||||
var err error
|
|
||||||
if config.RouteAll {
|
if config.RouteAll {
|
||||||
err = r.addDNSSetupForAll(config.ServerIP)
|
if err := r.addDNSSetupForAll(config.ServerIP); err != nil {
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("add dns setup: %w", err)
|
return fmt.Errorf("add dns setup: %w", err)
|
||||||
}
|
}
|
||||||
} else if r.routingAll {
|
} else if r.routingAll {
|
||||||
err = r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey)
|
if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey); err != nil {
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("delete interface registry key property: %w", err)
|
return fmt.Errorf("delete interface registry key property: %w", err)
|
||||||
}
|
}
|
||||||
r.routingAll = false
|
r.routingAll = false
|
||||||
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := stateManager.UpdateState(&ShutdownState{Guid: r.guid}); err != nil {
|
if err := stateManager.UpdateState(&ShutdownState{Guid: r.guid, GPO: r.gpo}); err != nil {
|
||||||
log.Errorf("failed to update shutdown state: %s", err)
|
log.Errorf("failed to update shutdown state: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var searchDomains, matchDomains []string
|
||||||
searchDomains []string
|
|
||||||
matchDomains []string
|
|
||||||
)
|
|
||||||
|
|
||||||
for _, dConf := range config.Domains {
|
for _, dConf := range config.Domains {
|
||||||
if dConf.Disabled {
|
if dConf.Disabled {
|
||||||
continue
|
continue
|
||||||
@ -86,16 +103,16 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(matchDomains) != 0 {
|
if len(matchDomains) != 0 {
|
||||||
err = r.addDNSMatchPolicy(matchDomains, config.ServerIP)
|
if err := r.addDNSMatchPolicy(matchDomains, config.ServerIP); err != nil {
|
||||||
} else {
|
|
||||||
err = removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("add dns match policy: %w", err)
|
return fmt.Errorf("add dns match policy: %w", err)
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
if err := r.removeDNSMatchPolicies(); err != nil {
|
||||||
|
return fmt.Errorf("remove dns match policies: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
err = r.updateSearchDomains(searchDomains)
|
if err := r.updateSearchDomains(searchDomains); err != nil {
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("update search domains: %w", err)
|
return fmt.Errorf("update search domains: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -103,9 +120,8 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) addDNSSetupForAll(ip string) error {
|
func (r *registryConfigurator) addDNSSetupForAll(ip string) error {
|
||||||
err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip)
|
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("adding dns setup for all failed: %w", err)
|
||||||
return fmt.Errorf("adding dns setup for all failed with error: %w", err)
|
|
||||||
}
|
}
|
||||||
r.routingAll = true
|
r.routingAll = true
|
||||||
log.Infof("configured %s:53 as main DNS forwarder for this peer", ip)
|
log.Infof("configured %s:53 as main DNS forwarder for this peer", ip)
|
||||||
@ -113,64 +129,66 @@ func (r *registryConfigurator) addDNSSetupForAll(ip string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip string) error {
|
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip string) error {
|
||||||
_, err := registry.OpenKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath, registry.QUERY_VALUE)
|
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
|
||||||
if err == nil {
|
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
|
||||||
err = registry.DeleteKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath)
|
if r.gpo {
|
||||||
if err != nil {
|
if err := r.configureDNSPolicy(gpoDnsPolicyConfigMatchPath, domains, ip); err != nil {
|
||||||
return fmt.Errorf("unable to remove existing key from registry, key: HKEY_LOCAL_MACHINE\\%s, error: %w", dnsPolicyConfigMatchPath, err)
|
return fmt.Errorf("configure GPO DNS policy: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.configureDNSPolicy(dnsPolicyConfigMatchPath, domains, ip); err != nil {
|
||||||
|
return fmt.Errorf("configure local DNS policy: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := refreshGroupPolicy(); err != nil {
|
||||||
|
log.Warnf("failed to refresh group policy: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := r.configureDNSPolicy(dnsPolicyConfigMatchPath, domains, ip); err != nil {
|
||||||
|
return fmt.Errorf("configure local DNS policy: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath, registry.SET_VALUE)
|
log.Infof("added %d match domains. Domain list: %s", len(domains), domains)
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to create registry key, key: HKEY_LOCAL_MACHINE\\%s, error: %w", dnsPolicyConfigMatchPath, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = regKey.SetDWordValue(dnsPolicyConfigVersionKey, dnsPolicyConfigVersionValue)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigVersionKey, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = regKey.SetStringsValue(dnsPolicyConfigNameKey, domains)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigNameKey, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = regKey.SetStringValue(dnsPolicyConfigGenericDNSServersKey, ip)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigGenericDNSServersKey, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = regKey.SetDWordValue(dnsPolicyConfigConfigOptionsKey, dnsPolicyConfigConfigOptionsValue)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigConfigOptionsKey, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("added %d match domains to the state. Domain list: %s", len(domains), domains)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) restoreHostDNS() error {
|
// configureDNSPolicy handles the actual configuration of a DNS policy at the specified path
|
||||||
if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {
|
func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip string) error {
|
||||||
log.Errorf("remove registry key from dns policy config: %s", err)
|
if err := removeRegistryKeyFromDNSPolicyConfig(policyPath); err != nil {
|
||||||
|
return fmt.Errorf("remove existing dns policy: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigSearchListKey); err != nil {
|
regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE)
|
||||||
return fmt.Errorf("remove interface registry key: %w", err)
|
if err != nil {
|
||||||
|
return fmt.Errorf("create registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err)
|
||||||
|
}
|
||||||
|
defer closer(regKey)
|
||||||
|
|
||||||
|
if err := regKey.SetDWordValue(dnsPolicyConfigVersionKey, dnsPolicyConfigVersionValue); err != nil {
|
||||||
|
return fmt.Errorf("set %s: %w", dnsPolicyConfigVersionKey, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := regKey.SetStringsValue(dnsPolicyConfigNameKey, domains); err != nil {
|
||||||
|
return fmt.Errorf("set %s: %w", dnsPolicyConfigNameKey, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := regKey.SetStringValue(dnsPolicyConfigGenericDNSServersKey, ip); err != nil {
|
||||||
|
return fmt.Errorf("set %s: %w", dnsPolicyConfigGenericDNSServersKey, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := regKey.SetDWordValue(dnsPolicyConfigConfigOptionsKey, dnsPolicyConfigConfigOptionsValue); err != nil {
|
||||||
|
return fmt.Errorf("set %s: %w", dnsPolicyConfigConfigOptionsKey, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) updateSearchDomains(domains []string) error {
|
func (r *registryConfigurator) updateSearchDomains(domains []string) error {
|
||||||
err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ","))
|
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ",")); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("update search domains: %w", err)
|
||||||
return fmt.Errorf("adding search domain failed with error: %w", err)
|
|
||||||
}
|
}
|
||||||
|
log.Infof("updated search domains: %s", domains)
|
||||||
log.Infof("updated the search domains in the registry with %d domains. Domain list: %s", len(domains), domains)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -181,11 +199,9 @@ func (r *registryConfigurator) setInterfaceRegistryKeyStringValue(key, value str
|
|||||||
}
|
}
|
||||||
defer closer(regKey)
|
defer closer(regKey)
|
||||||
|
|
||||||
err = regKey.SetStringValue(key, value)
|
if err := regKey.SetStringValue(key, value); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("set key %s=%s: %w", key, value, err)
|
||||||
return fmt.Errorf("applying key %s with value \"%s\" for interface failed with error: %w", key, value, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -196,43 +212,91 @@ func (r *registryConfigurator) deleteInterfaceRegistryKeyProperty(propertyKey st
|
|||||||
}
|
}
|
||||||
defer closer(regKey)
|
defer closer(regKey)
|
||||||
|
|
||||||
err = regKey.DeleteValue(propertyKey)
|
if err := regKey.DeleteValue(propertyKey); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("delete registry key %s: %w", propertyKey, err)
|
||||||
return fmt.Errorf("deleting registry key %s for interface failed with error: %w", propertyKey, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) getInterfaceRegistryKey() (registry.Key, error) {
|
func (r *registryConfigurator) getInterfaceRegistryKey() (registry.Key, error) {
|
||||||
var regKey registry.Key
|
|
||||||
|
|
||||||
regKeyPath := interfaceConfigPath + "\\" + r.guid
|
regKeyPath := interfaceConfigPath + "\\" + r.guid
|
||||||
|
|
||||||
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.SET_VALUE)
|
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.SET_VALUE)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return regKey, fmt.Errorf("unable to open the interface registry key, key: HKEY_LOCAL_MACHINE\\%s, error: %w", regKeyPath, err)
|
return regKey, fmt.Errorf("open HKEY_LOCAL_MACHINE\\%s: %w", regKeyPath, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return regKey, nil
|
return regKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) restoreUncleanShutdownDNS() error {
|
func (r *registryConfigurator) restoreHostDNS() error {
|
||||||
if err := r.restoreHostDNS(); err != nil {
|
if err := r.removeDNSMatchPolicies(); err != nil {
|
||||||
return fmt.Errorf("restoring dns via registry: %w", err)
|
log.Errorf("remove dns match policies: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigSearchListKey); err != nil {
|
||||||
|
return fmt.Errorf("remove interface registry key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *registryConfigurator) removeDNSMatchPolicies() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove local registry key: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove GPO registry key: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := refreshGroupPolicy(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("refresh group policy: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *registryConfigurator) restoreUncleanShutdownDNS() error {
|
||||||
|
return r.restoreHostDNS()
|
||||||
|
}
|
||||||
|
|
||||||
func removeRegistryKeyFromDNSPolicyConfig(regKeyPath string) error {
|
func removeRegistryKeyFromDNSPolicyConfig(regKeyPath string) error {
|
||||||
k, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.QUERY_VALUE)
|
k, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.QUERY_VALUE)
|
||||||
if err == nil {
|
|
||||||
defer closer(k)
|
|
||||||
err = registry.DeleteKey(registry.LOCAL_MACHINE, regKeyPath)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to remove existing key from registry, key: HKEY_LOCAL_MACHINE\\%s, error: %w", regKeyPath, err)
|
log.Debugf("failed to open HKEY_LOCAL_MACHINE\\%s: %v", regKeyPath, err)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
closer(k)
|
||||||
|
if err := registry.DeleteKey(registry.LOCAL_MACHINE, regKeyPath); err != nil {
|
||||||
|
return fmt.Errorf("delete HKEY_LOCAL_MACHINE\\%s: %w", regKeyPath, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func refreshGroupPolicy() error {
|
||||||
|
// refreshPolicyExFn.Call() panics if the func is not found
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
log.Errorf("Recovered from panic: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
ret, _, err := refreshPolicyExFn.Call(
|
||||||
|
// bMachine = TRUE (computer policy)
|
||||||
|
uintptr(1),
|
||||||
|
// dwOptions = RP_FORCE
|
||||||
|
uintptr(rpForce),
|
||||||
|
)
|
||||||
|
|
||||||
|
if ret == 0 {
|
||||||
|
if err != nil && !errors.Is(err, syscall.Errno(0)) {
|
||||||
|
return fmt.Errorf("RefreshPolicyEx failed: %w", err)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("RefreshPolicyEx failed")
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
@ -29,10 +30,15 @@ func (d *localResolver) String() string {
|
|||||||
return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap))
|
return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ID returns the unique handler ID
|
||||||
|
func (d *localResolver) id() handlerID {
|
||||||
|
return "local-resolver"
|
||||||
|
}
|
||||||
|
|
||||||
// ServeDNS handles a DNS request
|
// ServeDNS handles a DNS request
|
||||||
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
if len(r.Question) > 0 {
|
if len(r.Question) > 0 {
|
||||||
log.Tracef("received question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||||
}
|
}
|
||||||
|
|
||||||
replyMessage := &dns.Msg{}
|
replyMessage := &dns.Msg{}
|
||||||
@ -55,6 +61,7 @@ func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
|
|
||||||
func (d *localResolver) lookupRecord(r *dns.Msg) dns.RR {
|
func (d *localResolver) lookupRecord(r *dns.Msg) dns.RR {
|
||||||
question := r.Question[0]
|
question := r.Question[0]
|
||||||
|
question.Name = strings.ToLower(question.Name)
|
||||||
record, found := d.records.Load(buildRecordKey(question.Name, question.Qclass, question.Qtype))
|
record, found := d.records.Load(buildRecordKey(question.Name, question.Qclass, question.Qtype))
|
||||||
if !found {
|
if !found {
|
||||||
return nil
|
return nil
|
||||||
|
@ -5,7 +5,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
@ -42,7 +41,12 @@ type Server interface {
|
|||||||
ProbeAvailability()
|
ProbeAvailability()
|
||||||
}
|
}
|
||||||
|
|
||||||
type registeredHandlerMap map[string]handlerWithStop
|
type handlerID string
|
||||||
|
|
||||||
|
type nsGroupsByDomain struct {
|
||||||
|
domain string
|
||||||
|
groups []*nbdns.NameServerGroup
|
||||||
|
}
|
||||||
|
|
||||||
// DefaultServer dns server object
|
// DefaultServer dns server object
|
||||||
type DefaultServer struct {
|
type DefaultServer struct {
|
||||||
@ -52,7 +56,6 @@ type DefaultServer struct {
|
|||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
service service
|
service service
|
||||||
dnsMuxMap registeredHandlerMap
|
dnsMuxMap registeredHandlerMap
|
||||||
handlerPriorities map[string]int
|
|
||||||
localResolver *localResolver
|
localResolver *localResolver
|
||||||
wgInterface WGIface
|
wgInterface WGIface
|
||||||
hostManager hostManager
|
hostManager hostManager
|
||||||
@ -77,14 +80,17 @@ type handlerWithStop interface {
|
|||||||
dns.Handler
|
dns.Handler
|
||||||
stop()
|
stop()
|
||||||
probeAvailability()
|
probeAvailability()
|
||||||
|
id() handlerID
|
||||||
}
|
}
|
||||||
|
|
||||||
type muxUpdate struct {
|
type handlerWrapper struct {
|
||||||
domain string
|
domain string
|
||||||
handler handlerWithStop
|
handler handlerWithStop
|
||||||
priority int
|
priority int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type registeredHandlerMap map[handlerID]handlerWrapper
|
||||||
|
|
||||||
// NewDefaultServer returns a new dns server
|
// NewDefaultServer returns a new dns server
|
||||||
func NewDefaultServer(
|
func NewDefaultServer(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
@ -164,7 +170,6 @@ func newDefaultServer(
|
|||||||
service: dnsService,
|
service: dnsService,
|
||||||
handlerChain: NewHandlerChain(),
|
handlerChain: NewHandlerChain(),
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
dnsMuxMap: make(registeredHandlerMap),
|
||||||
handlerPriorities: make(map[string]int),
|
|
||||||
localResolver: &localResolver{
|
localResolver: &localResolver{
|
||||||
registeredMap: make(registrationMap),
|
registeredMap: make(registrationMap),
|
||||||
},
|
},
|
||||||
@ -192,8 +197,7 @@ func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, p
|
|||||||
log.Warn("skipping empty domain")
|
log.Warn("skipping empty domain")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
s.handlerChain.AddHandler(domain, handler, priority, nil)
|
s.handlerChain.AddHandler(domain, handler, priority)
|
||||||
s.handlerPriorities[domain] = priority
|
|
||||||
s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain)
|
s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -209,14 +213,15 @@ func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
|||||||
log.Debugf("deregistering handler %v with priority %d", domains, priority)
|
log.Debugf("deregistering handler %v with priority %d", domains, priority)
|
||||||
|
|
||||||
for _, domain := range domains {
|
for _, domain := range domains {
|
||||||
s.handlerChain.RemoveHandler(domain, priority)
|
|
||||||
|
|
||||||
// Only deregister from service if no handlers remain
|
|
||||||
if !s.handlerChain.HasHandlers(domain) {
|
|
||||||
if domain == "" {
|
if domain == "" {
|
||||||
log.Warn("skipping empty domain")
|
log.Warn("skipping empty domain")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.handlerChain.RemoveHandler(domain, priority)
|
||||||
|
|
||||||
|
// Only deregister from service if no handlers remain
|
||||||
|
if !s.handlerChain.HasHandlers(domain) {
|
||||||
s.service.DeregisterMux(nbdns.NormalizeZone(domain))
|
s.service.DeregisterMux(nbdns.NormalizeZone(domain))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -283,14 +288,24 @@ func (s *DefaultServer) Stop() {
|
|||||||
|
|
||||||
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
||||||
// It will be applied if the mgm server do not enforce DNS settings for root zone
|
// It will be applied if the mgm server do not enforce DNS settings for root zone
|
||||||
|
|
||||||
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
|
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
|
||||||
s.hostsDNSHolder.set(hostsDnsList)
|
s.hostsDNSHolder.set(hostsDnsList)
|
||||||
|
|
||||||
_, ok := s.dnsMuxMap[nbdns.RootZone]
|
// Check if there's any root handler
|
||||||
if ok {
|
var hasRootHandler bool
|
||||||
|
for _, handler := range s.dnsMuxMap {
|
||||||
|
if handler.domain == nbdns.RootZone {
|
||||||
|
hasRootHandler = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasRootHandler {
|
||||||
log.Debugf("on new host DNS config but skip to apply it")
|
log.Debugf("on new host DNS config but skip to apply it")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("update host DNS settings: %+v", hostsDnsList)
|
log.Debugf("update host DNS settings: %+v", hostsDnsList)
|
||||||
s.addHostRootZone()
|
s.addHostRootZone()
|
||||||
}
|
}
|
||||||
@ -364,7 +379,7 @@ func (s *DefaultServer) ProbeAvailability() {
|
|||||||
go func(mux handlerWithStop) {
|
go func(mux handlerWithStop) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
mux.probeAvailability()
|
mux.probeAvailability()
|
||||||
}(mux)
|
}(mux.handler)
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
@ -419,8 +434,8 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) {
|
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, map[string]nbdns.SimpleRecord, error) {
|
||||||
var muxUpdates []muxUpdate
|
var muxUpdates []handlerWrapper
|
||||||
localRecords := make(map[string]nbdns.SimpleRecord, 0)
|
localRecords := make(map[string]nbdns.SimpleRecord, 0)
|
||||||
|
|
||||||
for _, customZone := range customZones {
|
for _, customZone := range customZones {
|
||||||
@ -428,7 +443,7 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
|
|||||||
return nil, nil, fmt.Errorf("received an empty list of records")
|
return nil, nil, fmt.Errorf("received an empty list of records")
|
||||||
}
|
}
|
||||||
|
|
||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
muxUpdates = append(muxUpdates, handlerWrapper{
|
||||||
domain: customZone.Domain,
|
domain: customZone.Domain,
|
||||||
handler: s.localResolver,
|
handler: s.localResolver,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityMatchDomain,
|
||||||
@ -439,6 +454,7 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
|
|||||||
if record.Class != nbdns.DefaultClass {
|
if record.Class != nbdns.DefaultClass {
|
||||||
return nil, nil, fmt.Errorf("received an invalid class type: %s", record.Class)
|
return nil, nil, fmt.Errorf("received an invalid class type: %s", record.Class)
|
||||||
}
|
}
|
||||||
|
|
||||||
key := buildRecordKey(record.Name, class, uint16(record.Type))
|
key := buildRecordKey(record.Name, class, uint16(record.Type))
|
||||||
localRecords[key] = record
|
localRecords[key] = record
|
||||||
}
|
}
|
||||||
@ -446,15 +462,59 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
|
|||||||
return muxUpdates, localRecords, nil
|
return muxUpdates, localRecords, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) {
|
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]handlerWrapper, error) {
|
||||||
|
var muxUpdates []handlerWrapper
|
||||||
|
|
||||||
var muxUpdates []muxUpdate
|
|
||||||
for _, nsGroup := range nameServerGroups {
|
for _, nsGroup := range nameServerGroups {
|
||||||
if len(nsGroup.NameServers) == 0 {
|
if len(nsGroup.NameServers) == 0 {
|
||||||
log.Warn("received a nameserver group with empty nameserver list")
|
log.Warn("received a nameserver group with empty nameserver list")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !nsGroup.Primary && len(nsGroup.Domains) == 0 {
|
||||||
|
return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, domain := range nsGroup.Domains {
|
||||||
|
if domain == "" {
|
||||||
|
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
groupedNS := groupNSGroupsByDomain(nameServerGroups)
|
||||||
|
|
||||||
|
for _, domainGroup := range groupedNS {
|
||||||
|
basePriority := PriorityMatchDomain
|
||||||
|
if domainGroup.domain == nbdns.RootZone {
|
||||||
|
basePriority = PriorityDefault
|
||||||
|
}
|
||||||
|
|
||||||
|
updates, err := s.createHandlersForDomainGroup(domainGroup, basePriority)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
muxUpdates = append(muxUpdates, updates...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return muxUpdates, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomain, basePriority int) ([]handlerWrapper, error) {
|
||||||
|
var muxUpdates []handlerWrapper
|
||||||
|
|
||||||
|
for i, nsGroup := range domainGroup.groups {
|
||||||
|
// Decrement priority by handler index (0, 1, 2, ...) to avoid conflicts
|
||||||
|
priority := basePriority - i
|
||||||
|
|
||||||
|
// Check if we're about to overlap with the next priority tier
|
||||||
|
if basePriority == PriorityMatchDomain && priority <= PriorityDefault {
|
||||||
|
log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers",
|
||||||
|
domainGroup.domain, PriorityMatchDomain-PriorityDefault)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("creating handler for domain=%s with priority=%d", domainGroup.domain, priority)
|
||||||
handler, err := newUpstreamResolver(
|
handler, err := newUpstreamResolver(
|
||||||
s.ctx,
|
s.ctx,
|
||||||
s.wgInterface.Name(),
|
s.wgInterface.Name(),
|
||||||
@ -462,10 +522,12 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
|||||||
s.wgInterface.Address().Network,
|
s.wgInterface.Address().Network,
|
||||||
s.statusRecorder,
|
s.statusRecorder,
|
||||||
s.hostsDNSHolder,
|
s.hostsDNSHolder,
|
||||||
|
domainGroup.domain,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to create a new upstream resolver, error: %v", err)
|
return nil, fmt.Errorf("create upstream resolver: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, ns := range nsGroup.NameServers {
|
for _, ns := range nsGroup.NameServers {
|
||||||
if ns.NSType != nbdns.UDPNameServerType {
|
if ns.NSType != nbdns.UDPNameServerType {
|
||||||
log.Warnf("skipping nameserver %s with type %s, this peer supports only %s",
|
log.Warnf("skipping nameserver %s with type %s, this peer supports only %s",
|
||||||
@ -489,78 +551,47 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
|||||||
// after some period defined by upstream it tries to reactivate self by calling this hook
|
// after some period defined by upstream it tries to reactivate self by calling this hook
|
||||||
// everything we need here is just to re-apply current configuration because it already
|
// everything we need here is just to re-apply current configuration because it already
|
||||||
// contains this upstream settings (temporal deactivation not removed it)
|
// contains this upstream settings (temporal deactivation not removed it)
|
||||||
handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler)
|
handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler, priority)
|
||||||
|
|
||||||
if nsGroup.Primary {
|
muxUpdates = append(muxUpdates, handlerWrapper{
|
||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
domain: domainGroup.domain,
|
||||||
domain: nbdns.RootZone,
|
|
||||||
handler: handler,
|
handler: handler,
|
||||||
priority: PriorityDefault,
|
priority: priority,
|
||||||
})
|
})
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(nsGroup.Domains) == 0 {
|
|
||||||
handler.stop()
|
|
||||||
return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list")
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, domain := range nsGroup.Domains {
|
|
||||||
if domain == "" {
|
|
||||||
handler.stop()
|
|
||||||
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
|
|
||||||
}
|
|
||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
|
||||||
domain: domain,
|
|
||||||
handler: handler,
|
|
||||||
priority: PriorityMatchDomain,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return muxUpdates, nil
|
return muxUpdates, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
|
||||||
|
// this will introduce a short period of time when the server is not able to handle DNS requests
|
||||||
|
for _, existing := range s.dnsMuxMap {
|
||||||
|
s.deregisterHandler([]string{existing.domain}, existing.priority)
|
||||||
|
existing.handler.stop()
|
||||||
|
}
|
||||||
|
|
||||||
muxUpdateMap := make(registeredHandlerMap)
|
muxUpdateMap := make(registeredHandlerMap)
|
||||||
handlersByPriority := make(map[string]int)
|
var containsRootUpdate bool
|
||||||
|
|
||||||
var isContainRootUpdate bool
|
|
||||||
|
|
||||||
// First register new handlers
|
|
||||||
for _, update := range muxUpdates {
|
for _, update := range muxUpdates {
|
||||||
s.registerHandler([]string{update.domain}, update.handler, update.priority)
|
|
||||||
muxUpdateMap[update.domain] = update.handler
|
|
||||||
handlersByPriority[update.domain] = update.priority
|
|
||||||
|
|
||||||
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
|
|
||||||
existingHandler.stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
if update.domain == nbdns.RootZone {
|
if update.domain == nbdns.RootZone {
|
||||||
isContainRootUpdate = true
|
containsRootUpdate = true
|
||||||
}
|
}
|
||||||
|
s.registerHandler([]string{update.domain}, update.handler, update.priority)
|
||||||
|
muxUpdateMap[update.handler.id()] = update
|
||||||
}
|
}
|
||||||
|
|
||||||
// Then deregister old handlers not in the update
|
// If there's no root update and we had a root handler, restore it
|
||||||
for key, existingHandler := range s.dnsMuxMap {
|
if !containsRootUpdate {
|
||||||
_, found := muxUpdateMap[key]
|
for _, existing := range s.dnsMuxMap {
|
||||||
if !found {
|
if existing.domain == nbdns.RootZone {
|
||||||
if !isContainRootUpdate && key == nbdns.RootZone {
|
|
||||||
s.addHostRootZone()
|
s.addHostRootZone()
|
||||||
existingHandler.stop()
|
break
|
||||||
} else {
|
|
||||||
existingHandler.stop()
|
|
||||||
// Deregister with the priority that was used to register
|
|
||||||
if oldPriority, ok := s.handlerPriorities[key]; ok {
|
|
||||||
s.deregisterHandler([]string{key}, oldPriority)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s.dnsMuxMap = muxUpdateMap
|
s.dnsMuxMap = muxUpdateMap
|
||||||
s.handlerPriorities = handlersByPriority
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
|
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
|
||||||
@ -593,6 +624,7 @@ func getNSHostPort(ns nbdns.NameServer) string {
|
|||||||
func (s *DefaultServer) upstreamCallbacks(
|
func (s *DefaultServer) upstreamCallbacks(
|
||||||
nsGroup *nbdns.NameServerGroup,
|
nsGroup *nbdns.NameServerGroup,
|
||||||
handler dns.Handler,
|
handler dns.Handler,
|
||||||
|
priority int,
|
||||||
) (deactivate func(error), reactivate func()) {
|
) (deactivate func(error), reactivate func()) {
|
||||||
var removeIndex map[string]int
|
var removeIndex map[string]int
|
||||||
deactivate = func(err error) {
|
deactivate = func(err error) {
|
||||||
@ -609,13 +641,13 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
if nsGroup.Primary {
|
if nsGroup.Primary {
|
||||||
removeIndex[nbdns.RootZone] = -1
|
removeIndex[nbdns.RootZone] = -1
|
||||||
s.currentConfig.RouteAll = false
|
s.currentConfig.RouteAll = false
|
||||||
s.deregisterHandler([]string{nbdns.RootZone}, PriorityDefault)
|
s.deregisterHandler([]string{nbdns.RootZone}, priority)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, item := range s.currentConfig.Domains {
|
for i, item := range s.currentConfig.Domains {
|
||||||
if _, found := removeIndex[item.Domain]; found {
|
if _, found := removeIndex[item.Domain]; found {
|
||||||
s.currentConfig.Domains[i].Disabled = true
|
s.currentConfig.Domains[i].Disabled = true
|
||||||
s.deregisterHandler([]string{item.Domain}, PriorityMatchDomain)
|
s.deregisterHandler([]string{item.Domain}, priority)
|
||||||
removeIndex[item.Domain] = i
|
removeIndex[item.Domain] = i
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -635,8 +667,8 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.updateNSState(nsGroup, err, false)
|
s.updateNSState(nsGroup, err, false)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
reactivate = func() {
|
reactivate = func() {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
@ -646,7 +678,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
s.currentConfig.Domains[i].Disabled = false
|
s.currentConfig.Domains[i].Disabled = false
|
||||||
s.registerHandler([]string{domain}, handler, PriorityMatchDomain)
|
s.registerHandler([]string{domain}, handler, priority)
|
||||||
}
|
}
|
||||||
|
|
||||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||||
@ -654,7 +686,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
|
|
||||||
if nsGroup.Primary {
|
if nsGroup.Primary {
|
||||||
s.currentConfig.RouteAll = true
|
s.currentConfig.RouteAll = true
|
||||||
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault)
|
s.registerHandler([]string{nbdns.RootZone}, handler, priority)
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.hostManager != nil {
|
if s.hostManager != nil {
|
||||||
@ -676,6 +708,7 @@ func (s *DefaultServer) addHostRootZone() {
|
|||||||
s.wgInterface.Address().Network,
|
s.wgInterface.Address().Network,
|
||||||
s.statusRecorder,
|
s.statusRecorder,
|
||||||
s.hostsDNSHolder,
|
s.hostsDNSHolder,
|
||||||
|
nbdns.RootZone,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("unable to create a new upstream resolver, error: %v", err)
|
log.Errorf("unable to create a new upstream resolver, error: %v", err)
|
||||||
@ -732,5 +765,34 @@ func generateGroupKey(nsGroup *nbdns.NameServerGroup) string {
|
|||||||
for _, ns := range nsGroup.NameServers {
|
for _, ns := range nsGroup.NameServers {
|
||||||
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port))
|
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port))
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("%s_%s_%s", nsGroup.ID, nsGroup.Name, strings.Join(servers, ","))
|
return fmt.Sprintf("%v_%v", servers, nsGroup.Domains)
|
||||||
|
}
|
||||||
|
|
||||||
|
// groupNSGroupsByDomain groups nameserver groups by their match domains
|
||||||
|
func groupNSGroupsByDomain(nsGroups []*nbdns.NameServerGroup) []nsGroupsByDomain {
|
||||||
|
domainMap := make(map[string][]*nbdns.NameServerGroup)
|
||||||
|
|
||||||
|
for _, group := range nsGroups {
|
||||||
|
if group.Primary {
|
||||||
|
domainMap[nbdns.RootZone] = append(domainMap[nbdns.RootZone], group)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, domain := range group.Domains {
|
||||||
|
if domain == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
domainMap[domain] = append(domainMap[domain], group)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var result []nsGroupsByDomain
|
||||||
|
for domain, groups := range domainMap {
|
||||||
|
result = append(result, nsGroupsByDomain{
|
||||||
|
domain: domain,
|
||||||
|
groups: groups,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
}
|
}
|
||||||
|
@ -13,6 +13,7 @@ import (
|
|||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/mock"
|
"github.com/stretchr/testify/mock"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
@ -88,6 +89,18 @@ func init() {
|
|||||||
formatter.SetTextFormatter(log.StandardLogger())
|
formatter.SetTextFormatter(log.StandardLogger())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase {
|
||||||
|
var srvs []string
|
||||||
|
for _, srv := range servers {
|
||||||
|
srvs = append(srvs, getNSHostPort(srv))
|
||||||
|
}
|
||||||
|
return &upstreamResolverBase{
|
||||||
|
domain: domain,
|
||||||
|
upstreamServers: srvs,
|
||||||
|
cancel: func() {},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestUpdateDNSServer(t *testing.T) {
|
func TestUpdateDNSServer(t *testing.T) {
|
||||||
nameServers := []nbdns.NameServer{
|
nameServers := []nbdns.NameServer{
|
||||||
{
|
{
|
||||||
@ -140,13 +153,35 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedUpstreamMap: registeredHandlerMap{"netbird.io": dummyHandler, "netbird.cloud": dummyHandler, nbdns.RootZone: dummyHandler},
|
expectedUpstreamMap: registeredHandlerMap{
|
||||||
|
generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{
|
||||||
|
domain: "netbird.io",
|
||||||
|
handler: dummyHandler,
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
dummyHandler.id(): handlerWrapper{
|
||||||
|
domain: "netbird.cloud",
|
||||||
|
handler: dummyHandler,
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
generateDummyHandler(".", nameServers).id(): handlerWrapper{
|
||||||
|
domain: nbdns.RootZone,
|
||||||
|
handler: dummyHandler,
|
||||||
|
priority: PriorityDefault,
|
||||||
|
},
|
||||||
|
},
|
||||||
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "New Config Should Succeed",
|
name: "New Config Should Succeed",
|
||||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
||||||
initUpstreamMap: registeredHandlerMap{buildRecordKey(zoneRecords[0].Name, 1, 1): dummyHandler},
|
initUpstreamMap: registeredHandlerMap{
|
||||||
|
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
|
||||||
|
domain: buildRecordKey(zoneRecords[0].Name, 1, 1),
|
||||||
|
handler: dummyHandler,
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
},
|
||||||
initSerial: 0,
|
initSerial: 0,
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
inputUpdate: nbdns.Config{
|
inputUpdate: nbdns.Config{
|
||||||
@ -164,7 +199,18 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedUpstreamMap: registeredHandlerMap{"netbird.io": dummyHandler, "netbird.cloud": dummyHandler},
|
expectedUpstreamMap: registeredHandlerMap{
|
||||||
|
generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{
|
||||||
|
domain: "netbird.io",
|
||||||
|
handler: dummyHandler,
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
"local-resolver": handlerWrapper{
|
||||||
|
domain: "netbird.cloud",
|
||||||
|
handler: dummyHandler,
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
},
|
||||||
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -244,7 +290,13 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Empty Config Should Succeed and Clean Maps",
|
name: "Empty Config Should Succeed and Clean Maps",
|
||||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
||||||
initUpstreamMap: registeredHandlerMap{zoneRecords[0].Name: dummyHandler},
|
initUpstreamMap: registeredHandlerMap{
|
||||||
|
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
|
||||||
|
domain: zoneRecords[0].Name,
|
||||||
|
handler: dummyHandler,
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
},
|
||||||
initSerial: 0,
|
initSerial: 0,
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
inputUpdate: nbdns.Config{ServiceEnable: true},
|
inputUpdate: nbdns.Config{ServiceEnable: true},
|
||||||
@ -254,7 +306,13 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Disabled Service Should clean map",
|
name: "Disabled Service Should clean map",
|
||||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
||||||
initUpstreamMap: registeredHandlerMap{zoneRecords[0].Name: dummyHandler},
|
initUpstreamMap: registeredHandlerMap{
|
||||||
|
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
|
||||||
|
domain: zoneRecords[0].Name,
|
||||||
|
handler: dummyHandler,
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
},
|
||||||
initSerial: 0,
|
initSerial: 0,
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
inputUpdate: nbdns.Config{ServiceEnable: false},
|
inputUpdate: nbdns.Config{ServiceEnable: false},
|
||||||
@ -421,7 +479,13 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
dnsServer.dnsMuxMap = registeredHandlerMap{zoneRecords[0].Name: &localResolver{}}
|
dnsServer.dnsMuxMap = registeredHandlerMap{
|
||||||
|
"id1": handlerWrapper{
|
||||||
|
domain: zoneRecords[0].Name,
|
||||||
|
handler: &localResolver{},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
}
|
||||||
dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}}
|
dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}}
|
||||||
dnsServer.updateSerial = 0
|
dnsServer.updateSerial = 0
|
||||||
|
|
||||||
@ -563,7 +627,6 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
|||||||
registeredMap: make(registrationMap),
|
registeredMap: make(registrationMap),
|
||||||
},
|
},
|
||||||
handlerChain: NewHandlerChain(),
|
handlerChain: NewHandlerChain(),
|
||||||
handlerPriorities: make(map[string]int),
|
|
||||||
hostManager: hostManager,
|
hostManager: hostManager,
|
||||||
currentConfig: HostDNSConfig{
|
currentConfig: HostDNSConfig{
|
||||||
Domains: []DomainConfig{
|
Domains: []DomainConfig{
|
||||||
@ -593,7 +656,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
|||||||
NameServers: []nbdns.NameServer{
|
NameServers: []nbdns.NameServer{
|
||||||
{IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53},
|
{IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||||
},
|
},
|
||||||
}, nil)
|
}, nil, 0)
|
||||||
|
|
||||||
deactivate(nil)
|
deactivate(nil)
|
||||||
expected := "domain0,domain2"
|
expected := "domain0,domain2"
|
||||||
@ -849,7 +912,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
pf, err := uspfilter.Create(wgIface)
|
pf, err := uspfilter.Create(wgIface, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create uspfilter: %v", err)
|
t.Fatalf("failed to create uspfilter: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -903,8 +966,8 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
|
|||||||
Subdomains: true,
|
Subdomains: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute, nil)
|
chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute)
|
||||||
chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain, nil)
|
chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain)
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
@ -959,3 +1022,421 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type mockHandler struct {
|
||||||
|
Id string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
|
||||||
|
func (m *mockHandler) stop() {}
|
||||||
|
func (m *mockHandler) probeAvailability() {}
|
||||||
|
func (m *mockHandler) id() handlerID { return handlerID(m.Id) }
|
||||||
|
|
||||||
|
type mockService struct{}
|
||||||
|
|
||||||
|
func (m *mockService) Listen() error { return nil }
|
||||||
|
func (m *mockService) Stop() {}
|
||||||
|
func (m *mockService) RuntimeIP() string { return "127.0.0.1" }
|
||||||
|
func (m *mockService) RuntimePort() int { return 53 }
|
||||||
|
func (m *mockService) RegisterMux(string, dns.Handler) {}
|
||||||
|
func (m *mockService) DeregisterMux(string) {}
|
||||||
|
|
||||||
|
func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||||
|
baseMatchHandlers := registeredHandlerMap{
|
||||||
|
"upstream-group1": {
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group1",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
"upstream-group2": {
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group2",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain - 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
baseRootHandlers := registeredHandlerMap{
|
||||||
|
"upstream-root1": {
|
||||||
|
domain: ".",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-root1",
|
||||||
|
},
|
||||||
|
priority: PriorityDefault,
|
||||||
|
},
|
||||||
|
"upstream-root2": {
|
||||||
|
domain: ".",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-root2",
|
||||||
|
},
|
||||||
|
priority: PriorityDefault - 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
baseMixedHandlers := registeredHandlerMap{
|
||||||
|
"upstream-group1": {
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group1",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
"upstream-group2": {
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group2",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain - 1,
|
||||||
|
},
|
||||||
|
"upstream-other": {
|
||||||
|
domain: "other.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-other",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
initialHandlers registeredHandlerMap
|
||||||
|
updates []handlerWrapper
|
||||||
|
expectedHandlers map[string]string // map[handlerID]domain
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Remove group1 from update",
|
||||||
|
initialHandlers: baseMatchHandlers,
|
||||||
|
updates: []handlerWrapper{
|
||||||
|
// Only group2 remains
|
||||||
|
{
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group2",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain - 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedHandlers: map[string]string{
|
||||||
|
"upstream-group2": "example.com",
|
||||||
|
},
|
||||||
|
description: "When group1 is not included in the update, it should be removed while group2 remains",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Remove group2 from update",
|
||||||
|
initialHandlers: baseMatchHandlers,
|
||||||
|
updates: []handlerWrapper{
|
||||||
|
// Only group1 remains
|
||||||
|
{
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group1",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedHandlers: map[string]string{
|
||||||
|
"upstream-group1": "example.com",
|
||||||
|
},
|
||||||
|
description: "When group2 is not included in the update, it should be removed while group1 remains",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Add group3 in first position",
|
||||||
|
initialHandlers: baseMatchHandlers,
|
||||||
|
updates: []handlerWrapper{
|
||||||
|
// Add group3 with highest priority
|
||||||
|
{
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group3",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain + 1,
|
||||||
|
},
|
||||||
|
// Keep existing groups with their original priorities
|
||||||
|
{
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group1",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group2",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain - 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedHandlers: map[string]string{
|
||||||
|
"upstream-group1": "example.com",
|
||||||
|
"upstream-group2": "example.com",
|
||||||
|
"upstream-group3": "example.com",
|
||||||
|
},
|
||||||
|
description: "When adding group3 with highest priority, it should be first in chain while maintaining existing groups",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Add group3 in last position",
|
||||||
|
initialHandlers: baseMatchHandlers,
|
||||||
|
updates: []handlerWrapper{
|
||||||
|
// Keep existing groups with their original priorities
|
||||||
|
{
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group1",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group2",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain - 1,
|
||||||
|
},
|
||||||
|
// Add group3 with lowest priority
|
||||||
|
{
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group3",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain - 2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedHandlers: map[string]string{
|
||||||
|
"upstream-group1": "example.com",
|
||||||
|
"upstream-group2": "example.com",
|
||||||
|
"upstream-group3": "example.com",
|
||||||
|
},
|
||||||
|
description: "When adding group3 with lowest priority, it should be last in chain while maintaining existing groups",
|
||||||
|
},
|
||||||
|
// Root zone tests
|
||||||
|
{
|
||||||
|
name: "Remove root1 from update",
|
||||||
|
initialHandlers: baseRootHandlers,
|
||||||
|
updates: []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: ".",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-root2",
|
||||||
|
},
|
||||||
|
priority: PriorityDefault - 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedHandlers: map[string]string{
|
||||||
|
"upstream-root2": ".",
|
||||||
|
},
|
||||||
|
description: "When root1 is not included in the update, it should be removed while root2 remains",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Remove root2 from update",
|
||||||
|
initialHandlers: baseRootHandlers,
|
||||||
|
updates: []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: ".",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-root1",
|
||||||
|
},
|
||||||
|
priority: PriorityDefault,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedHandlers: map[string]string{
|
||||||
|
"upstream-root1": ".",
|
||||||
|
},
|
||||||
|
description: "When root2 is not included in the update, it should be removed while root1 remains",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Add root3 in first position",
|
||||||
|
initialHandlers: baseRootHandlers,
|
||||||
|
updates: []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: ".",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-root3",
|
||||||
|
},
|
||||||
|
priority: PriorityDefault + 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: ".",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-root1",
|
||||||
|
},
|
||||||
|
priority: PriorityDefault,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: ".",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-root2",
|
||||||
|
},
|
||||||
|
priority: PriorityDefault - 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedHandlers: map[string]string{
|
||||||
|
"upstream-root1": ".",
|
||||||
|
"upstream-root2": ".",
|
||||||
|
"upstream-root3": ".",
|
||||||
|
},
|
||||||
|
description: "When adding root3 with highest priority, it should be first in chain while maintaining existing root handlers",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Add root3 in last position",
|
||||||
|
initialHandlers: baseRootHandlers,
|
||||||
|
updates: []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: ".",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-root1",
|
||||||
|
},
|
||||||
|
priority: PriorityDefault,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: ".",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-root2",
|
||||||
|
},
|
||||||
|
priority: PriorityDefault - 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: ".",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-root3",
|
||||||
|
},
|
||||||
|
priority: PriorityDefault - 2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedHandlers: map[string]string{
|
||||||
|
"upstream-root1": ".",
|
||||||
|
"upstream-root2": ".",
|
||||||
|
"upstream-root3": ".",
|
||||||
|
},
|
||||||
|
description: "When adding root3 with lowest priority, it should be last in chain while maintaining existing root handlers",
|
||||||
|
},
|
||||||
|
// Mixed domain tests
|
||||||
|
{
|
||||||
|
name: "Update with mixed domains - remove one of duplicate domain",
|
||||||
|
initialHandlers: baseMixedHandlers,
|
||||||
|
updates: []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group1",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "other.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-other",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedHandlers: map[string]string{
|
||||||
|
"upstream-group1": "example.com",
|
||||||
|
"upstream-other": "other.com",
|
||||||
|
},
|
||||||
|
description: "When updating mixed domains, should correctly handle removal of one duplicate while maintaining other domains",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Update with mixed domains - add new domain",
|
||||||
|
initialHandlers: baseMixedHandlers,
|
||||||
|
updates: []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group1",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group2",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain - 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "other.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-other",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "new.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-new",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedHandlers: map[string]string{
|
||||||
|
"upstream-group1": "example.com",
|
||||||
|
"upstream-group2": "example.com",
|
||||||
|
"upstream-other": "other.com",
|
||||||
|
"upstream-new": "new.com",
|
||||||
|
},
|
||||||
|
description: "When updating mixed domains, should maintain existing duplicates and add new domain",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
server := &DefaultServer{
|
||||||
|
dnsMuxMap: tt.initialHandlers,
|
||||||
|
handlerChain: NewHandlerChain(),
|
||||||
|
service: &mockService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform the update
|
||||||
|
server.updateMux(tt.updates)
|
||||||
|
|
||||||
|
// Verify the results
|
||||||
|
assert.Equal(t, len(tt.expectedHandlers), len(server.dnsMuxMap),
|
||||||
|
"Number of handlers after update doesn't match expected")
|
||||||
|
|
||||||
|
// Check each expected handler
|
||||||
|
for id, expectedDomain := range tt.expectedHandlers {
|
||||||
|
handler, exists := server.dnsMuxMap[handlerID(id)]
|
||||||
|
assert.True(t, exists, "Expected handler %s not found", id)
|
||||||
|
if exists {
|
||||||
|
assert.Equal(t, expectedDomain, handler.domain,
|
||||||
|
"Domain mismatch for handler %s", id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify no unexpected handlers exist
|
||||||
|
for handlerID := range server.dnsMuxMap {
|
||||||
|
_, expected := tt.expectedHandlers[string(handlerID)]
|
||||||
|
assert.True(t, expected, "Unexpected handler found: %s", handlerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the handlerChain state and order
|
||||||
|
previousPriority := 0
|
||||||
|
for _, chainEntry := range server.handlerChain.handlers {
|
||||||
|
// Verify priority order
|
||||||
|
if previousPriority > 0 {
|
||||||
|
assert.True(t, chainEntry.Priority <= previousPriority,
|
||||||
|
"Handlers in chain not properly ordered by priority")
|
||||||
|
}
|
||||||
|
previousPriority = chainEntry.Priority
|
||||||
|
|
||||||
|
// Verify handler exists in mux
|
||||||
|
foundInMux := false
|
||||||
|
for _, muxEntry := range server.dnsMuxMap {
|
||||||
|
if chainEntry.Handler == muxEntry.handler &&
|
||||||
|
chainEntry.Priority == muxEntry.priority &&
|
||||||
|
chainEntry.Pattern == dns.Fqdn(muxEntry.domain) {
|
||||||
|
foundInMux = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.True(t, foundInMux,
|
||||||
|
"Handler in chain not found in dnsMuxMap")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
type ShutdownState struct {
|
type ShutdownState struct {
|
||||||
Guid string
|
Guid string
|
||||||
|
GPO bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ShutdownState) Name() string {
|
func (s *ShutdownState) Name() string {
|
||||||
@ -13,9 +14,9 @@ func (s *ShutdownState) Name() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ShutdownState) Cleanup() error {
|
func (s *ShutdownState) Cleanup() error {
|
||||||
manager, err := newHostManagerWithGuid(s.Guid)
|
manager := ®istryConfigurator{
|
||||||
if err != nil {
|
guid: s.Guid,
|
||||||
return fmt.Errorf("create host manager: %w", err)
|
gpo: s.GPO,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := manager.restoreUncleanShutdownDNS(); err != nil {
|
if err := manager.restoreUncleanShutdownDNS(); err != nil {
|
||||||
|
@ -2,9 +2,13 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@ -15,6 +19,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -40,6 +45,7 @@ type upstreamResolverBase struct {
|
|||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
upstreamClient upstreamClient
|
upstreamClient upstreamClient
|
||||||
upstreamServers []string
|
upstreamServers []string
|
||||||
|
domain string
|
||||||
disabled bool
|
disabled bool
|
||||||
failsCount atomic.Int32
|
failsCount atomic.Int32
|
||||||
successCount atomic.Int32
|
successCount atomic.Int32
|
||||||
@ -53,12 +59,13 @@ type upstreamResolverBase struct {
|
|||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) *upstreamResolverBase {
|
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase {
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
|
||||||
return &upstreamResolverBase{
|
return &upstreamResolverBase{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
|
domain: domain,
|
||||||
upstreamTimeout: upstreamTimeout,
|
upstreamTimeout: upstreamTimeout,
|
||||||
reactivatePeriod: reactivatePeriod,
|
reactivatePeriod: reactivatePeriod,
|
||||||
failsTillDeact: failsTillDeact,
|
failsTillDeact: failsTillDeact,
|
||||||
@ -71,6 +78,17 @@ func (u *upstreamResolverBase) String() string {
|
|||||||
return fmt.Sprintf("upstream %v", u.upstreamServers)
|
return fmt.Sprintf("upstream %v", u.upstreamServers)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ID returns the unique handler ID
|
||||||
|
func (u *upstreamResolverBase) id() handlerID {
|
||||||
|
servers := slices.Clone(u.upstreamServers)
|
||||||
|
slices.Sort(servers)
|
||||||
|
|
||||||
|
hash := sha256.New()
|
||||||
|
hash.Write([]byte(u.domain + ":"))
|
||||||
|
hash.Write([]byte(strings.Join(servers, ",")))
|
||||||
|
return handlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
|
||||||
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverBase) MatchSubdomains() bool {
|
func (u *upstreamResolverBase) MatchSubdomains() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@ -87,7 +105,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
u.checkUpstreamFails(err)
|
u.checkUpstreamFails(err)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
log.WithField("question", r.Question[0]).Trace("received an upstream question")
|
log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||||
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
|
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
|
||||||
if r.Extra == nil {
|
if r.Extra == nil {
|
||||||
r.SetEdns0(4096, false)
|
r.SetEdns0(4096, false)
|
||||||
@ -96,6 +114,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-u.ctx.Done():
|
case <-u.ctx.Done():
|
||||||
|
log.Tracef("%s has been stopped", u)
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
@ -112,41 +131,36 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
|
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
|
||||||
log.WithError(err).WithField("upstream", upstream).
|
log.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
|
||||||
Warn("got an error while connecting to upstream")
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
u.failsCount.Add(1)
|
log.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
|
||||||
log.WithError(err).WithField("upstream", upstream).
|
continue
|
||||||
Error("got other error while querying the upstream")
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if rm == nil {
|
if rm == nil || !rm.Response {
|
||||||
log.WithError(err).WithField("upstream", upstream).
|
log.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
|
||||||
Warn("no response from upstream")
|
continue
|
||||||
return
|
|
||||||
}
|
|
||||||
// those checks need to be independent of each other due to memory address issues
|
|
||||||
if !rm.Response {
|
|
||||||
log.WithError(err).WithField("upstream", upstream).
|
|
||||||
Warn("no response from upstream")
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
u.successCount.Add(1)
|
u.successCount.Add(1)
|
||||||
log.Tracef("took %s to query the upstream %s", t, upstream)
|
log.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name)
|
||||||
|
|
||||||
err = w.WriteMsg(rm)
|
if err = w.WriteMsg(rm); err != nil {
|
||||||
if err != nil {
|
log.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err)
|
||||||
log.WithError(err).Error("got an error while writing the upstream resolver response")
|
|
||||||
}
|
}
|
||||||
// count the fails only if they happen sequentially
|
// count the fails only if they happen sequentially
|
||||||
u.failsCount.Store(0)
|
u.failsCount.Store(0)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
u.failsCount.Add(1)
|
u.failsCount.Add(1)
|
||||||
log.Error("all queries to the upstream nameservers failed with timeout")
|
log.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
|
||||||
|
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetRcode(r, dns.RcodeServerFailure)
|
||||||
|
if err := w.WriteMsg(m); err != nil {
|
||||||
|
log.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkUpstreamFails counts fails and disables or enables upstream resolving
|
// checkUpstreamFails counts fails and disables or enables upstream resolving
|
||||||
@ -217,6 +231,14 @@ func (u *upstreamResolverBase) probeAvailability() {
|
|||||||
// didn't find a working upstream server, let's disable and try later
|
// didn't find a working upstream server, let's disable and try later
|
||||||
if !success {
|
if !success {
|
||||||
u.disable(errors.ErrorOrNil())
|
u.disable(errors.ErrorOrNil())
|
||||||
|
|
||||||
|
u.statusRecorder.PublishEvent(
|
||||||
|
proto.SystemEvent_WARNING,
|
||||||
|
proto.SystemEvent_DNS,
|
||||||
|
"All upstream servers failed",
|
||||||
|
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
|
||||||
|
map[string]string{"upstreams": strings.Join(u.upstreamServers, ", ")},
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -27,8 +27,9 @@ func newUpstreamResolver(
|
|||||||
_ *net.IPNet,
|
_ *net.IPNet,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
hostsDNSHolder *hostsDNSHolder,
|
hostsDNSHolder *hostsDNSHolder,
|
||||||
|
domain string,
|
||||||
) (*upstreamResolver, error) {
|
) (*upstreamResolver, error) {
|
||||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
|
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
||||||
c := &upstreamResolver{
|
c := &upstreamResolver{
|
||||||
upstreamResolverBase: upstreamResolverBase,
|
upstreamResolverBase: upstreamResolverBase,
|
||||||
hostsDNSHolder: hostsDNSHolder,
|
hostsDNSHolder: hostsDNSHolder,
|
||||||
|
@ -23,8 +23,9 @@ func newUpstreamResolver(
|
|||||||
_ *net.IPNet,
|
_ *net.IPNet,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
_ *hostsDNSHolder,
|
_ *hostsDNSHolder,
|
||||||
|
domain string,
|
||||||
) (*upstreamResolver, error) {
|
) (*upstreamResolver, error) {
|
||||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
|
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
||||||
nonIOS := &upstreamResolver{
|
nonIOS := &upstreamResolver{
|
||||||
upstreamResolverBase: upstreamResolverBase,
|
upstreamResolverBase: upstreamResolverBase,
|
||||||
}
|
}
|
||||||
|
@ -30,8 +30,9 @@ func newUpstreamResolver(
|
|||||||
net *net.IPNet,
|
net *net.IPNet,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
_ *hostsDNSHolder,
|
_ *hostsDNSHolder,
|
||||||
|
domain string,
|
||||||
) (*upstreamResolverIOS, error) {
|
) (*upstreamResolverIOS, error) {
|
||||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
|
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
||||||
|
|
||||||
ios := &upstreamResolverIOS{
|
ios := &upstreamResolverIOS{
|
||||||
upstreamResolverBase: upstreamResolverBase,
|
upstreamResolverBase: upstreamResolverBase,
|
||||||
|
@ -20,6 +20,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
cancelCTX bool
|
cancelCTX bool
|
||||||
expectedAnswer string
|
expectedAnswer string
|
||||||
|
acceptNXDomain bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Should Resolve A Record",
|
name: "Should Resolve A Record",
|
||||||
@ -40,7 +41,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
||||||
InputServers: []string{"8.0.0.0:53", "8.0.0.1:53"},
|
InputServers: []string{"8.0.0.0:53", "8.0.0.1:53"},
|
||||||
timeout: 200 * time.Millisecond,
|
timeout: 200 * time.Millisecond,
|
||||||
responseShouldBeNil: true,
|
acceptNXDomain: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Should Not Resolve If Parent Context Is Canceled",
|
name: "Should Not Resolve If Parent Context Is Canceled",
|
||||||
@ -51,14 +52,11 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
responseShouldBeNil: true,
|
responseShouldBeNil: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
// should resolve if first upstream times out
|
|
||||||
// should not write when both fails
|
|
||||||
// should not resolve if parent context is canceled
|
|
||||||
|
|
||||||
for _, testCase := range testCases {
|
for _, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.TODO())
|
ctx, cancel := context.WithCancel(context.TODO())
|
||||||
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil)
|
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil, ".")
|
||||||
resolver.upstreamServers = testCase.InputServers
|
resolver.upstreamServers = testCase.InputServers
|
||||||
resolver.upstreamTimeout = testCase.timeout
|
resolver.upstreamTimeout = testCase.timeout
|
||||||
if testCase.cancelCTX {
|
if testCase.cancelCTX {
|
||||||
@ -84,6 +82,11 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
t.Fatalf("should write a response message")
|
t.Fatalf("should write a response message")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if testCase.acceptNXDomain && responseMSG.Rcode == dns.RcodeNameError {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if testCase.expectedAnswer != "" {
|
||||||
foundAnswer := false
|
foundAnswer := false
|
||||||
for _, answer := range responseMSG.Answer {
|
for _, answer := range responseMSG.Answer {
|
||||||
if strings.Contains(answer.String(), testCase.expectedAnswer) {
|
if strings.Contains(answer.String(), testCase.expectedAnswer) {
|
||||||
@ -95,6 +98,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
if !foundAnswer {
|
if !foundAnswer {
|
||||||
t.Errorf("couldn't find the required answer, %s, in the dns response", testCase.expectedAnswer)
|
t.Errorf("couldn't find the required answer, %s, in the dns response", testCase.expectedAnswer)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -43,13 +43,13 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||||
|
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
mgm "github.com/netbirdio/netbird/management/client"
|
mgm "github.com/netbirdio/netbird/management/client"
|
||||||
"github.com/netbirdio/netbird/management/domain"
|
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
||||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||||
@ -195,6 +195,10 @@ type Peer struct {
|
|||||||
WgAllowedIps string
|
WgAllowedIps string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type localIpUpdater interface {
|
||||||
|
UpdateLocalIPs() error
|
||||||
|
}
|
||||||
|
|
||||||
// NewEngine creates a new Connection Engine with probes attached
|
// NewEngine creates a new Connection Engine with probes attached
|
||||||
func NewEngine(
|
func NewEngine(
|
||||||
clientCtx context.Context,
|
clientCtx context.Context,
|
||||||
@ -442,7 +446,7 @@ func (e *Engine) createFirewall() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager)
|
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.config.DisableServerRoutes)
|
||||||
if err != nil || e.firewall == nil {
|
if err != nil || e.firewall == nil {
|
||||||
log.Errorf("failed creating firewall manager: %s", err)
|
log.Errorf("failed creating firewall manager: %s", err)
|
||||||
return nil
|
return nil
|
||||||
@ -892,6 +896,14 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
e.acl.ApplyFiltering(networkMap)
|
e.acl.ApplyFiltering(networkMap)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if e.firewall != nil {
|
||||||
|
if localipfw, ok := e.firewall.(localIpUpdater); ok {
|
||||||
|
if err := localipfw.UpdateLocalIPs(); err != nil {
|
||||||
|
log.Errorf("failed to update local IPs: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// DNS forwarder
|
// DNS forwarder
|
||||||
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
|
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
|
||||||
dnsRouteDomains := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), networkMap.GetRoutes())
|
dnsRouteDomains := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), networkMap.GetRoutes())
|
||||||
@ -1460,6 +1472,11 @@ func (e *Engine) GetRouteManager() routemanager.Manager {
|
|||||||
return e.routeManager
|
return e.routeManager
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetFirewallManager returns the firewall manager
|
||||||
|
func (e *Engine) GetFirewallManager() firewallManager.Manager {
|
||||||
|
return e.firewall
|
||||||
|
}
|
||||||
|
|
||||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||||
iface, err := net.InterfaceByName(ifaceName)
|
iface, err := net.InterfaceByName(ifaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -1671,6 +1688,14 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
|
|||||||
return nm, nil
|
return nm, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetWgAddr returns the wireguard address
|
||||||
|
func (e *Engine) GetWgAddr() net.IP {
|
||||||
|
if e.wgInterface == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return e.wgInterface.Address().IP
|
||||||
|
}
|
||||||
|
|
||||||
// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag
|
// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag
|
||||||
func (e *Engine) updateDNSForwarder(enabled bool, domains []string) {
|
func (e *Engine) updateDNSForwarder(enabled bool, domains []string) {
|
||||||
if !enabled {
|
if !enabled {
|
||||||
|
@ -2,6 +2,7 @@ package peer
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
@ -28,12 +29,28 @@ import (
|
|||||||
|
|
||||||
type ConnPriority int
|
type ConnPriority int
|
||||||
|
|
||||||
|
func (cp ConnPriority) String() string {
|
||||||
|
switch cp {
|
||||||
|
case connPriorityNone:
|
||||||
|
return "None"
|
||||||
|
case connPriorityRelay:
|
||||||
|
return "PriorityRelay"
|
||||||
|
case connPriorityICETurn:
|
||||||
|
return "PriorityICETurn"
|
||||||
|
case connPriorityICEP2P:
|
||||||
|
return "PriorityICEP2P"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("ConnPriority(%d)", cp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
defaultWgKeepAlive = 25 * time.Second
|
defaultWgKeepAlive = 25 * time.Second
|
||||||
|
|
||||||
|
connPriorityNone ConnPriority = 0
|
||||||
connPriorityRelay ConnPriority = 1
|
connPriorityRelay ConnPriority = 1
|
||||||
connPriorityICETurn ConnPriority = 1
|
connPriorityICETurn ConnPriority = 2
|
||||||
connPriorityICEP2P ConnPriority = 2
|
connPriorityICEP2P ConnPriority = 3
|
||||||
)
|
)
|
||||||
|
|
||||||
type WgConfig struct {
|
type WgConfig struct {
|
||||||
@ -66,14 +83,6 @@ type ConnConfig struct {
|
|||||||
ICEConfig icemaker.Config
|
ICEConfig icemaker.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
type WorkerCallbacks struct {
|
|
||||||
OnRelayReadyCallback func(info RelayConnInfo)
|
|
||||||
OnRelayStatusChanged func(ConnStatus)
|
|
||||||
|
|
||||||
OnICEConnReadyCallback func(ConnPriority, ICEConnInfo)
|
|
||||||
OnICEStatusChanged func(ConnStatus)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
log *log.Entry
|
log *log.Entry
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
@ -135,21 +144,11 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
|
|||||||
semaphore: semaphore,
|
semaphore: semaphore,
|
||||||
}
|
}
|
||||||
|
|
||||||
rFns := WorkerRelayCallbacks{
|
|
||||||
OnConnReady: conn.relayConnectionIsReady,
|
|
||||||
OnDisconnected: conn.onWorkerRelayStateDisconnected,
|
|
||||||
}
|
|
||||||
|
|
||||||
wFns := WorkerICECallbacks{
|
|
||||||
OnConnReady: conn.iCEConnectionIsReady,
|
|
||||||
OnStatusChanged: conn.onWorkerICEStateDisconnected,
|
|
||||||
}
|
|
||||||
|
|
||||||
ctrl := isController(config)
|
ctrl := isController(config)
|
||||||
conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, relayManager, rFns)
|
conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, conn, relayManager)
|
||||||
|
|
||||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||||
conn.workerICE, err = NewWorkerICE(ctx, connLog, config, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally, wFns)
|
conn.workerICE, err = NewWorkerICE(ctx, connLog, config, conn, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -304,7 +303,7 @@ func (conn *Conn) GetKey() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
|
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
|
||||||
func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) {
|
func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) {
|
||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
defer conn.mu.Unlock()
|
defer conn.mu.Unlock()
|
||||||
|
|
||||||
@ -317,9 +316,10 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.log.Debugf("ICE connection is ready")
|
// this never should happen, because Relay is the lower priority and ICE always close the deprecated connection before upgrade
|
||||||
|
// todo consider to remove this check
|
||||||
if conn.currentConnPriority > priority {
|
if conn.currentConnPriority > priority {
|
||||||
|
conn.log.Infof("current connection priority (%s) is higher than the new one (%s), do not upgrade connection", conn.currentConnPriority, priority)
|
||||||
conn.statusICE.Set(StatusConnected)
|
conn.statusICE.Set(StatusConnected)
|
||||||
conn.updateIceState(iceConnInfo)
|
conn.updateIceState(iceConnInfo)
|
||||||
return
|
return
|
||||||
@ -375,8 +375,7 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
|
|||||||
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
|
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo review to make sense to handle connecting and disconnected status also?
|
func (conn *Conn) onICEStateDisconnected() {
|
||||||
func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
|
|
||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
defer conn.mu.Unlock()
|
defer conn.mu.Unlock()
|
||||||
|
|
||||||
@ -384,7 +383,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.log.Tracef("ICE connection state changed to %s", newState)
|
conn.log.Tracef("ICE connection state changed to disconnected")
|
||||||
|
|
||||||
if conn.wgProxyICE != nil {
|
if conn.wgProxyICE != nil {
|
||||||
if err := conn.wgProxyICE.CloseConn(); err != nil {
|
if err := conn.wgProxyICE.CloseConn(); err != nil {
|
||||||
@ -394,7 +393,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
|
|||||||
|
|
||||||
// switch back to relay connection
|
// switch back to relay connection
|
||||||
if conn.isReadyToUpgrade() {
|
if conn.isReadyToUpgrade() {
|
||||||
conn.log.Debugf("ICE disconnected, set Relay to active connection")
|
conn.log.Infof("ICE disconnected, set Relay to active connection")
|
||||||
conn.wgProxyRelay.Work()
|
conn.wgProxyRelay.Work()
|
||||||
|
|
||||||
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil {
|
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil {
|
||||||
@ -402,12 +401,16 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
|
|||||||
}
|
}
|
||||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||||
conn.currentConnPriority = connPriorityRelay
|
conn.currentConnPriority = connPriorityRelay
|
||||||
|
} else {
|
||||||
|
conn.log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", connPriorityNone.String())
|
||||||
|
conn.currentConnPriority = connPriorityNone
|
||||||
}
|
}
|
||||||
|
|
||||||
changed := conn.statusICE.Get() != newState && newState != StatusConnecting
|
changed := conn.statusICE.Get() != StatusDisconnected
|
||||||
conn.statusICE.Set(newState)
|
if changed {
|
||||||
|
conn.guard.SetICEConnDisconnected()
|
||||||
conn.guard.SetICEConnDisconnected(changed)
|
}
|
||||||
|
conn.statusICE.Set(StatusDisconnected)
|
||||||
|
|
||||||
peerState := State{
|
peerState := State{
|
||||||
PubKey: conn.config.Key,
|
PubKey: conn.config.Key,
|
||||||
@ -422,7 +425,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
defer conn.mu.Unlock()
|
defer conn.mu.Unlock()
|
||||||
|
|
||||||
@ -444,7 +447,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
|
conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
|
||||||
|
|
||||||
if conn.iceP2PIsActive() {
|
if conn.iceP2PIsActive() {
|
||||||
conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
|
conn.log.Debugf("do not switch to relay because current priority is: %s", conn.currentConnPriority.String())
|
||||||
conn.setRelayedProxy(wgProxy)
|
conn.setRelayedProxy(wgProxy)
|
||||||
conn.statusRelay.Set(StatusConnected)
|
conn.statusRelay.Set(StatusConnected)
|
||||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||||
@ -474,7 +477,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
|
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) onWorkerRelayStateDisconnected() {
|
func (conn *Conn) onRelayDisconnected() {
|
||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
defer conn.mu.Unlock()
|
defer conn.mu.Unlock()
|
||||||
|
|
||||||
@ -497,8 +500,10 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
changed := conn.statusRelay.Get() != StatusDisconnected
|
changed := conn.statusRelay.Get() != StatusDisconnected
|
||||||
|
if changed {
|
||||||
|
conn.guard.SetRelayedConnDisconnected()
|
||||||
|
}
|
||||||
conn.statusRelay.Set(StatusDisconnected)
|
conn.statusRelay.Set(StatusDisconnected)
|
||||||
conn.guard.SetRelayedConnDisconnected(changed)
|
|
||||||
|
|
||||||
peerState := State{
|
peerState := State{
|
||||||
PubKey: conn.config.Key,
|
PubKey: conn.config.Key,
|
||||||
|
@ -29,8 +29,8 @@ type Guard struct {
|
|||||||
isConnectedOnAllWay isConnectedFunc
|
isConnectedOnAllWay isConnectedFunc
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
srWatcher *SRWatcher
|
srWatcher *SRWatcher
|
||||||
relayedConnDisconnected chan bool
|
relayedConnDisconnected chan struct{}
|
||||||
iCEConnDisconnected chan bool
|
iCEConnDisconnected chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
|
func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
|
||||||
@ -41,8 +41,8 @@ func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc,
|
|||||||
isConnectedOnAllWay: isConnectedFn,
|
isConnectedOnAllWay: isConnectedFn,
|
||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
srWatcher: srWatcher,
|
srWatcher: srWatcher,
|
||||||
relayedConnDisconnected: make(chan bool, 1),
|
relayedConnDisconnected: make(chan struct{}, 1),
|
||||||
iCEConnDisconnected: make(chan bool, 1),
|
iCEConnDisconnected: make(chan struct{}, 1),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -54,16 +54,16 @@ func (g *Guard) Start(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Guard) SetRelayedConnDisconnected(changed bool) {
|
func (g *Guard) SetRelayedConnDisconnected() {
|
||||||
select {
|
select {
|
||||||
case g.relayedConnDisconnected <- changed:
|
case g.relayedConnDisconnected <- struct{}{}:
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Guard) SetICEConnDisconnected(changed bool) {
|
func (g *Guard) SetICEConnDisconnected() {
|
||||||
select {
|
select {
|
||||||
case g.iCEConnDisconnected <- changed:
|
case g.iCEConnDisconnected <- struct{}{}:
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -96,19 +96,13 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context) {
|
|||||||
g.triggerOfferSending()
|
g.triggerOfferSending()
|
||||||
}
|
}
|
||||||
|
|
||||||
case changed := <-g.relayedConnDisconnected:
|
case <-g.relayedConnDisconnected:
|
||||||
if !changed {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
g.log.Debugf("Relay connection changed, reset reconnection ticker")
|
g.log.Debugf("Relay connection changed, reset reconnection ticker")
|
||||||
ticker.Stop()
|
ticker.Stop()
|
||||||
ticker = g.prepareExponentTicker(ctx)
|
ticker = g.prepareExponentTicker(ctx)
|
||||||
tickerChannel = ticker.C
|
tickerChannel = ticker.C
|
||||||
|
|
||||||
case changed := <-g.iCEConnDisconnected:
|
case <-g.iCEConnDisconnected:
|
||||||
if !changed {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
g.log.Debugf("ICE connection changed, reset reconnection ticker")
|
g.log.Debugf("ICE connection changed, reset reconnection ticker")
|
||||||
ticker.Stop()
|
ticker.Stop()
|
||||||
ticker = g.prepareExponentTicker(ctx)
|
ticker = g.prepareExponentTicker(ctx)
|
||||||
@ -138,16 +132,10 @@ func (g *Guard) listenForDisconnectEvents(ctx context.Context) {
|
|||||||
g.log.Infof("start listen for reconnect events...")
|
g.log.Infof("start listen for reconnect events...")
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case changed := <-g.relayedConnDisconnected:
|
case <-g.relayedConnDisconnected:
|
||||||
if !changed {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
g.log.Debugf("Relay connection changed, triggering reconnect")
|
g.log.Debugf("Relay connection changed, triggering reconnect")
|
||||||
g.triggerOfferSending()
|
g.triggerOfferSending()
|
||||||
case changed := <-g.iCEConnDisconnected:
|
case <-g.iCEConnDisconnected:
|
||||||
if !changed {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
g.log.Debugf("ICE state changed, try to send new offer")
|
g.log.Debugf("ICE state changed, try to send new offer")
|
||||||
g.triggerOfferSending()
|
g.triggerOfferSending()
|
||||||
case <-srReconnectedChan:
|
case <-srReconnectedChan:
|
||||||
|
@ -7,23 +7,33 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
||||||
"github.com/netbirdio/netbird/client/internal/relay"
|
"github.com/netbirdio/netbird/client/internal/relay"
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/management/domain"
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const eventQueueSize = 10
|
||||||
|
|
||||||
type ResolvedDomainInfo struct {
|
type ResolvedDomainInfo struct {
|
||||||
Prefixes []netip.Prefix
|
Prefixes []netip.Prefix
|
||||||
ParentDomain domain.Domain
|
ParentDomain domain.Domain
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type EventListener interface {
|
||||||
|
OnEvent(event *proto.SystemEvent)
|
||||||
|
}
|
||||||
|
|
||||||
// State contains the latest state of a peer
|
// State contains the latest state of a peer
|
||||||
type State struct {
|
type State struct {
|
||||||
Mux *sync.RWMutex
|
Mux *sync.RWMutex
|
||||||
@ -161,6 +171,10 @@ type Status struct {
|
|||||||
|
|
||||||
relayMgr *relayClient.Manager
|
relayMgr *relayClient.Manager
|
||||||
|
|
||||||
|
eventMux sync.RWMutex
|
||||||
|
eventStreams map[string]chan *proto.SystemEvent
|
||||||
|
eventQueue *EventQueue
|
||||||
|
|
||||||
ingressGwMgr *ingressgw.Manager
|
ingressGwMgr *ingressgw.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -169,6 +183,8 @@ func NewRecorder(mgmAddress string) *Status {
|
|||||||
return &Status{
|
return &Status{
|
||||||
peers: make(map[string]State),
|
peers: make(map[string]State),
|
||||||
changeNotify: make(map[string]chan struct{}),
|
changeNotify: make(map[string]chan struct{}),
|
||||||
|
eventStreams: make(map[string]chan *proto.SystemEvent),
|
||||||
|
eventQueue: NewEventQueue(eventQueueSize),
|
||||||
offlinePeers: make([]State, 0),
|
offlinePeers: make([]State, 0),
|
||||||
notifier: newNotifier(),
|
notifier: newNotifier(),
|
||||||
mgmAddress: mgmAddress,
|
mgmAddress: mgmAddress,
|
||||||
@ -754,7 +770,9 @@ func (d *Status) ForwardingRules() []firewall.ForwardRule {
|
|||||||
func (d *Status) GetDNSStates() []NSGroupState {
|
func (d *Status) GetDNSStates() []NSGroupState {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
return d.nsGroupStates
|
|
||||||
|
// shallow copy is good enough, as slices fields are currently not updated
|
||||||
|
return slices.Clone(d.nsGroupStates)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo {
|
func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo {
|
||||||
@ -838,3 +856,112 @@ func (d *Status) notifyAddressChanged() {
|
|||||||
func (d *Status) numOfPeers() int {
|
func (d *Status) numOfPeers() int {
|
||||||
return len(d.peers) + len(d.offlinePeers)
|
return len(d.peers) + len(d.offlinePeers)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PublishEvent adds an event to the queue and distributes it to all subscribers
|
||||||
|
func (d *Status) PublishEvent(
|
||||||
|
severity proto.SystemEvent_Severity,
|
||||||
|
category proto.SystemEvent_Category,
|
||||||
|
msg string,
|
||||||
|
userMsg string,
|
||||||
|
metadata map[string]string,
|
||||||
|
) {
|
||||||
|
event := &proto.SystemEvent{
|
||||||
|
Id: uuid.New().String(),
|
||||||
|
Severity: severity,
|
||||||
|
Category: category,
|
||||||
|
Message: msg,
|
||||||
|
UserMessage: userMsg,
|
||||||
|
Metadata: metadata,
|
||||||
|
Timestamp: timestamppb.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
d.eventMux.Lock()
|
||||||
|
defer d.eventMux.Unlock()
|
||||||
|
|
||||||
|
d.eventQueue.Add(event)
|
||||||
|
|
||||||
|
for _, stream := range d.eventStreams {
|
||||||
|
select {
|
||||||
|
case stream <- event:
|
||||||
|
default:
|
||||||
|
log.Debugf("event stream buffer full, skipping event: %v", event)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("event published: %v", event)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SubscribeToEvents returns a new event subscription
|
||||||
|
func (d *Status) SubscribeToEvents() *EventSubscription {
|
||||||
|
d.eventMux.Lock()
|
||||||
|
defer d.eventMux.Unlock()
|
||||||
|
|
||||||
|
id := uuid.New().String()
|
||||||
|
stream := make(chan *proto.SystemEvent, 10)
|
||||||
|
d.eventStreams[id] = stream
|
||||||
|
|
||||||
|
return &EventSubscription{
|
||||||
|
id: id,
|
||||||
|
events: stream,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnsubscribeFromEvents removes an event subscription
|
||||||
|
func (d *Status) UnsubscribeFromEvents(sub *EventSubscription) {
|
||||||
|
if sub == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
d.eventMux.Lock()
|
||||||
|
defer d.eventMux.Unlock()
|
||||||
|
|
||||||
|
if stream, exists := d.eventStreams[sub.id]; exists {
|
||||||
|
close(stream)
|
||||||
|
delete(d.eventStreams, sub.id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetEventHistory returns all events in the queue
|
||||||
|
func (d *Status) GetEventHistory() []*proto.SystemEvent {
|
||||||
|
return d.eventQueue.GetAll()
|
||||||
|
}
|
||||||
|
|
||||||
|
type EventQueue struct {
|
||||||
|
maxSize int
|
||||||
|
events []*proto.SystemEvent
|
||||||
|
mutex sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewEventQueue(size int) *EventQueue {
|
||||||
|
return &EventQueue{
|
||||||
|
maxSize: size,
|
||||||
|
events: make([]*proto.SystemEvent, 0, size),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *EventQueue) Add(event *proto.SystemEvent) {
|
||||||
|
q.mutex.Lock()
|
||||||
|
defer q.mutex.Unlock()
|
||||||
|
|
||||||
|
q.events = append(q.events, event)
|
||||||
|
|
||||||
|
if len(q.events) > q.maxSize {
|
||||||
|
q.events = q.events[len(q.events)-q.maxSize:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *EventQueue) GetAll() []*proto.SystemEvent {
|
||||||
|
q.mutex.RLock()
|
||||||
|
defer q.mutex.RUnlock()
|
||||||
|
|
||||||
|
return slices.Clone(q.events)
|
||||||
|
}
|
||||||
|
|
||||||
|
type EventSubscription struct {
|
||||||
|
id string
|
||||||
|
events chan *proto.SystemEvent
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *EventSubscription) Events() <-chan *proto.SystemEvent {
|
||||||
|
return s.events
|
||||||
|
}
|
||||||
|
154
client/internal/peer/wg_watcher.go
Normal file
@ -0,0 +1,154 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
wgHandshakePeriod = 3 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
wgHandshakeOvertime = 30 * time.Second // allowed delay in network
|
||||||
|
checkPeriod = wgHandshakePeriod + wgHandshakeOvertime
|
||||||
|
)
|
||||||
|
|
||||||
|
type WGInterfaceStater interface {
|
||||||
|
GetStats(key string) (configurer.WGStats, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type WGWatcher struct {
|
||||||
|
log *log.Entry
|
||||||
|
wgIfaceStater WGInterfaceStater
|
||||||
|
peerKey string
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
ctxCancel context.CancelFunc
|
||||||
|
ctxLock sync.Mutex
|
||||||
|
waitGroup sync.WaitGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string) *WGWatcher {
|
||||||
|
return &WGWatcher{
|
||||||
|
log: log,
|
||||||
|
wgIfaceStater: wgIfaceStater,
|
||||||
|
peerKey: peerKey,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnableWgWatcher starts the WireGuard watcher. If it is already enabled, it will return immediately and do nothing.
|
||||||
|
func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) {
|
||||||
|
w.log.Debugf("enable WireGuard watcher")
|
||||||
|
w.ctxLock.Lock()
|
||||||
|
defer w.ctxLock.Unlock()
|
||||||
|
|
||||||
|
if w.ctx != nil && w.ctx.Err() == nil {
|
||||||
|
w.log.Errorf("WireGuard watcher already enabled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, ctxCancel := context.WithCancel(parentCtx)
|
||||||
|
w.ctx = ctx
|
||||||
|
w.ctxCancel = ctxCancel
|
||||||
|
|
||||||
|
initialHandshake, err := w.wgState()
|
||||||
|
if err != nil {
|
||||||
|
w.log.Warnf("failed to read initial wg stats: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.waitGroup.Add(1)
|
||||||
|
go w.periodicHandshakeCheck(ctx, ctxCancel, onDisconnectedFn, initialHandshake)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisableWgWatcher stops the WireGuard watcher and wait for the watcher to exit
|
||||||
|
func (w *WGWatcher) DisableWgWatcher() {
|
||||||
|
w.ctxLock.Lock()
|
||||||
|
defer w.ctxLock.Unlock()
|
||||||
|
|
||||||
|
if w.ctxCancel == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.log.Debugf("disable WireGuard watcher")
|
||||||
|
|
||||||
|
w.ctxCancel()
|
||||||
|
w.ctxCancel = nil
|
||||||
|
w.waitGroup.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
|
||||||
|
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func(), initialHandshake time.Time) {
|
||||||
|
w.log.Infof("WireGuard watcher started")
|
||||||
|
defer w.waitGroup.Done()
|
||||||
|
|
||||||
|
timer := time.NewTimer(wgHandshakeOvertime)
|
||||||
|
defer timer.Stop()
|
||||||
|
defer ctxCancel()
|
||||||
|
|
||||||
|
lastHandshake := initialHandshake
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
handshake, ok := w.handshakeCheck(lastHandshake)
|
||||||
|
if !ok {
|
||||||
|
onDisconnectedFn()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
lastHandshake = *handshake
|
||||||
|
|
||||||
|
resetTime := time.Until(handshake.Add(checkPeriod))
|
||||||
|
timer.Reset(resetTime)
|
||||||
|
|
||||||
|
w.log.Debugf("WireGuard watcher reset timer: %v", resetTime)
|
||||||
|
case <-ctx.Done():
|
||||||
|
w.log.Infof("WireGuard watcher stopped")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handshakeCheck checks the WireGuard handshake and return the new handshake time if it is different from the previous one
|
||||||
|
func (w *WGWatcher) handshakeCheck(lastHandshake time.Time) (*time.Time, bool) {
|
||||||
|
handshake, err := w.wgState()
|
||||||
|
if err != nil {
|
||||||
|
w.log.Errorf("failed to read wg stats: %v", err)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake)
|
||||||
|
|
||||||
|
// the current know handshake did not change
|
||||||
|
if handshake.Equal(lastHandshake) {
|
||||||
|
w.log.Warnf("WireGuard handshake timed out, closing relay connection: %v", handshake)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// in case if the machine is suspended, the handshake time will be in the past
|
||||||
|
if handshake.Add(checkPeriod).Before(time.Now()) {
|
||||||
|
w.log.Warnf("WireGuard handshake timed out, closing relay connection: %v", handshake)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// error handling for handshake time in the future
|
||||||
|
if handshake.After(time.Now()) {
|
||||||
|
w.log.Warnf("WireGuard handshake is in the future, closing relay connection: %v", handshake)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return &handshake, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *WGWatcher) wgState() (time.Time, error) {
|
||||||
|
wgState, err := w.wgIfaceStater.GetStats(w.peerKey)
|
||||||
|
if err != nil {
|
||||||
|
return time.Time{}, err
|
||||||
|
}
|
||||||
|
return wgState.LastHandshake, nil
|
||||||
|
}
|
98
client/internal/peer/wg_watcher_test.go
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MocWgIface struct {
|
||||||
|
initial bool
|
||||||
|
lastHandshake time.Time
|
||||||
|
stop bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MocWgIface) GetStats(key string) (configurer.WGStats, error) {
|
||||||
|
if !m.initial {
|
||||||
|
m.initial = true
|
||||||
|
return configurer.WGStats{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !m.stop {
|
||||||
|
m.lastHandshake = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
stats := configurer.WGStats{
|
||||||
|
LastHandshake: m.lastHandshake,
|
||||||
|
}
|
||||||
|
|
||||||
|
return stats, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MocWgIface) disconnect() {
|
||||||
|
m.stop = true
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWGWatcher_EnableWgWatcher(t *testing.T) {
|
||||||
|
checkPeriod = 5 * time.Second
|
||||||
|
wgHandshakeOvertime = 1 * time.Second
|
||||||
|
|
||||||
|
mlog := log.WithField("peer", "tet")
|
||||||
|
mocWgIface := &MocWgIface{}
|
||||||
|
watcher := NewWGWatcher(mlog, mocWgIface, "")
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
onDisconnected := make(chan struct{}, 1)
|
||||||
|
watcher.EnableWgWatcher(ctx, func() {
|
||||||
|
mlog.Infof("onDisconnectedFn")
|
||||||
|
onDisconnected <- struct{}{}
|
||||||
|
})
|
||||||
|
|
||||||
|
// wait for initial reading
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
mocWgIface.disconnect()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-onDisconnected:
|
||||||
|
case <-time.After(10 * time.Second):
|
||||||
|
t.Errorf("timeout")
|
||||||
|
}
|
||||||
|
watcher.DisableWgWatcher()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWGWatcher_ReEnable(t *testing.T) {
|
||||||
|
checkPeriod = 5 * time.Second
|
||||||
|
wgHandshakeOvertime = 1 * time.Second
|
||||||
|
|
||||||
|
mlog := log.WithField("peer", "tet")
|
||||||
|
mocWgIface := &MocWgIface{}
|
||||||
|
watcher := NewWGWatcher(mlog, mocWgIface, "")
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
onDisconnected := make(chan struct{}, 1)
|
||||||
|
|
||||||
|
watcher.EnableWgWatcher(ctx, func() {})
|
||||||
|
watcher.DisableWgWatcher()
|
||||||
|
|
||||||
|
watcher.EnableWgWatcher(ctx, func() {
|
||||||
|
onDisconnected <- struct{}{}
|
||||||
|
})
|
||||||
|
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
mocWgIface.disconnect()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-onDisconnected:
|
||||||
|
case <-time.After(10 * time.Second):
|
||||||
|
t.Errorf("timeout")
|
||||||
|
}
|
||||||
|
watcher.DisableWgWatcher()
|
||||||
|
}
|
@ -31,20 +31,15 @@ type ICEConnInfo struct {
|
|||||||
RelayedOnLocal bool
|
RelayedOnLocal bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type WorkerICECallbacks struct {
|
|
||||||
OnConnReady func(ConnPriority, ICEConnInfo)
|
|
||||||
OnStatusChanged func(ConnStatus)
|
|
||||||
}
|
|
||||||
|
|
||||||
type WorkerICE struct {
|
type WorkerICE struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
log *log.Entry
|
log *log.Entry
|
||||||
config ConnConfig
|
config ConnConfig
|
||||||
|
conn *Conn
|
||||||
signaler *Signaler
|
signaler *Signaler
|
||||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||||
statusRecorder *Status
|
statusRecorder *Status
|
||||||
hasRelayOnLocally bool
|
hasRelayOnLocally bool
|
||||||
conn WorkerICECallbacks
|
|
||||||
|
|
||||||
agent *ice.Agent
|
agent *ice.Agent
|
||||||
muxAgent sync.Mutex
|
muxAgent sync.Mutex
|
||||||
@ -60,16 +55,16 @@ type WorkerICE struct {
|
|||||||
lastKnownState ice.ConnectionState
|
lastKnownState ice.ConnectionState
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) {
|
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *Conn, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) {
|
||||||
w := &WorkerICE{
|
w := &WorkerICE{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
log: log,
|
log: log,
|
||||||
config: config,
|
config: config,
|
||||||
|
conn: conn,
|
||||||
signaler: signaler,
|
signaler: signaler,
|
||||||
iFaceDiscover: ifaceDiscover,
|
iFaceDiscover: ifaceDiscover,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
hasRelayOnLocally: hasRelayOnLocally,
|
hasRelayOnLocally: hasRelayOnLocally,
|
||||||
conn: callBacks,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
localUfrag, localPwd, err := icemaker.GenerateICECredentials()
|
localUfrag, localPwd, err := icemaker.GenerateICECredentials()
|
||||||
@ -154,8 +149,8 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
Relayed: isRelayed(pair),
|
Relayed: isRelayed(pair),
|
||||||
RelayedOnLocal: isRelayCandidate(pair.Local),
|
RelayedOnLocal: isRelayCandidate(pair.Local),
|
||||||
}
|
}
|
||||||
w.log.Debugf("on ICE conn read to use ready")
|
w.log.Debugf("on ICE conn is ready to use")
|
||||||
go w.conn.OnConnReady(selectedPriority(pair), ci)
|
go w.conn.onICEConnectionIsReady(selectedPriority(pair), ci)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
|
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
|
||||||
@ -220,7 +215,7 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
|
|||||||
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected:
|
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected:
|
||||||
if w.lastKnownState != ice.ConnectionStateDisconnected {
|
if w.lastKnownState != ice.ConnectionStateDisconnected {
|
||||||
w.lastKnownState = ice.ConnectionStateDisconnected
|
w.lastKnownState = ice.ConnectionStateDisconnected
|
||||||
w.conn.OnStatusChanged(StatusDisconnected)
|
w.conn.onICEStateDisconnected()
|
||||||
}
|
}
|
||||||
w.closeAgent(agentCancel)
|
w.closeAgent(agentCancel)
|
||||||
default:
|
default:
|
||||||
|
@ -6,52 +6,41 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
wgHandshakePeriod = 3 * time.Minute
|
|
||||||
wgHandshakeOvertime = 30 * time.Second
|
|
||||||
)
|
|
||||||
|
|
||||||
type RelayConnInfo struct {
|
type RelayConnInfo struct {
|
||||||
relayedConn net.Conn
|
relayedConn net.Conn
|
||||||
rosenpassPubKey []byte
|
rosenpassPubKey []byte
|
||||||
rosenpassAddr string
|
rosenpassAddr string
|
||||||
}
|
}
|
||||||
|
|
||||||
type WorkerRelayCallbacks struct {
|
|
||||||
OnConnReady func(RelayConnInfo)
|
|
||||||
OnDisconnected func()
|
|
||||||
}
|
|
||||||
|
|
||||||
type WorkerRelay struct {
|
type WorkerRelay struct {
|
||||||
log *log.Entry
|
log *log.Entry
|
||||||
isController bool
|
isController bool
|
||||||
config ConnConfig
|
config ConnConfig
|
||||||
|
conn *Conn
|
||||||
relayManager relayClient.ManagerService
|
relayManager relayClient.ManagerService
|
||||||
callBacks WorkerRelayCallbacks
|
|
||||||
|
|
||||||
relayedConn net.Conn
|
relayedConn net.Conn
|
||||||
relayLock sync.Mutex
|
relayLock sync.Mutex
|
||||||
ctxWgWatch context.Context
|
|
||||||
ctxCancelWgWatch context.CancelFunc
|
|
||||||
ctxLock sync.Mutex
|
|
||||||
|
|
||||||
relaySupportedOnRemotePeer atomic.Bool
|
relaySupportedOnRemotePeer atomic.Bool
|
||||||
|
|
||||||
|
wgWatcher *WGWatcher
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, relayManager relayClient.ManagerService, callbacks WorkerRelayCallbacks) *WorkerRelay {
|
func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService) *WorkerRelay {
|
||||||
r := &WorkerRelay{
|
r := &WorkerRelay{
|
||||||
log: log,
|
log: log,
|
||||||
isController: ctrl,
|
isController: ctrl,
|
||||||
config: config,
|
config: config,
|
||||||
|
conn: conn,
|
||||||
relayManager: relayManager,
|
relayManager: relayManager,
|
||||||
callBacks: callbacks,
|
wgWatcher: NewWGWatcher(log, config.WgConfig.WgInterface, config.Key),
|
||||||
}
|
}
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
@ -87,7 +76,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
w.relayedConn = relayedConn
|
w.relayedConn = relayedConn
|
||||||
w.relayLock.Unlock()
|
w.relayLock.Unlock()
|
||||||
|
|
||||||
err = w.relayManager.AddCloseListener(srv, w.onRelayMGDisconnected)
|
err = w.relayManager.AddCloseListener(srv, w.onRelayClientDisconnected)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to add close listener: %s", err)
|
log.Errorf("failed to add close listener: %s", err)
|
||||||
_ = relayedConn.Close()
|
_ = relayedConn.Close()
|
||||||
@ -95,7 +84,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
w.log.Debugf("peer conn opened via Relay: %s", srv)
|
w.log.Debugf("peer conn opened via Relay: %s", srv)
|
||||||
go w.callBacks.OnConnReady(RelayConnInfo{
|
go w.conn.onRelayConnectionIsReady(RelayConnInfo{
|
||||||
relayedConn: relayedConn,
|
relayedConn: relayedConn,
|
||||||
rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey,
|
rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey,
|
||||||
rosenpassAddr: remoteOfferAnswer.RosenpassAddr,
|
rosenpassAddr: remoteOfferAnswer.RosenpassAddr,
|
||||||
@ -103,32 +92,11 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) {
|
func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) {
|
||||||
w.log.Debugf("enable WireGuard watcher")
|
w.wgWatcher.EnableWgWatcher(ctx, w.onWGDisconnected)
|
||||||
w.ctxLock.Lock()
|
|
||||||
defer w.ctxLock.Unlock()
|
|
||||||
|
|
||||||
if w.ctxWgWatch != nil && w.ctxWgWatch.Err() == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, ctxCancel := context.WithCancel(ctx)
|
|
||||||
w.ctxWgWatch = ctx
|
|
||||||
w.ctxCancelWgWatch = ctxCancel
|
|
||||||
|
|
||||||
w.wgStateCheck(ctx, ctxCancel)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerRelay) DisableWgWatcher() {
|
func (w *WorkerRelay) DisableWgWatcher() {
|
||||||
w.ctxLock.Lock()
|
w.wgWatcher.DisableWgWatcher()
|
||||||
defer w.ctxLock.Unlock()
|
|
||||||
|
|
||||||
if w.ctxCancelWgWatch == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w.log.Debugf("disable WireGuard watcher")
|
|
||||||
|
|
||||||
w.ctxCancelWgWatch()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerRelay) RelayInstanceAddress() (string, error) {
|
func (w *WorkerRelay) RelayInstanceAddress() (string, error) {
|
||||||
@ -150,57 +118,17 @@ func (w *WorkerRelay) CloseConn() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err := w.relayedConn.Close()
|
if err := w.relayedConn.Close(); err != nil {
|
||||||
if err != nil {
|
|
||||||
w.log.Warnf("failed to close relay connection: %v", err)
|
w.log.Warnf("failed to close relay connection: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
|
func (w *WorkerRelay) onWGDisconnected() {
|
||||||
func (w *WorkerRelay) wgStateCheck(ctx context.Context, ctxCancel context.CancelFunc) {
|
|
||||||
w.log.Debugf("WireGuard watcher started")
|
|
||||||
lastHandshake, err := w.wgState()
|
|
||||||
if err != nil {
|
|
||||||
w.log.Warnf("failed to read wg stats: %v", err)
|
|
||||||
lastHandshake = time.Time{}
|
|
||||||
}
|
|
||||||
|
|
||||||
go func(lastHandshake time.Time) {
|
|
||||||
timer := time.NewTimer(wgHandshakeOvertime)
|
|
||||||
defer timer.Stop()
|
|
||||||
defer ctxCancel()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-timer.C:
|
|
||||||
handshake, err := w.wgState()
|
|
||||||
if err != nil {
|
|
||||||
w.log.Errorf("failed to read wg stats: %v", err)
|
|
||||||
timer.Reset(wgHandshakeOvertime)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake)
|
|
||||||
|
|
||||||
if handshake.Equal(lastHandshake) {
|
|
||||||
w.log.Infof("WireGuard handshake timed out, closing relay connection: %v", handshake)
|
|
||||||
w.relayLock.Lock()
|
w.relayLock.Lock()
|
||||||
_ = w.relayedConn.Close()
|
_ = w.relayedConn.Close()
|
||||||
w.relayLock.Unlock()
|
w.relayLock.Unlock()
|
||||||
w.callBacks.OnDisconnected()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
resetTime := time.Until(handshake.Add(wgHandshakePeriod + wgHandshakeOvertime))
|
|
||||||
lastHandshake = handshake
|
|
||||||
timer.Reset(resetTime)
|
|
||||||
case <-ctx.Done():
|
|
||||||
w.log.Debugf("WireGuard watcher stopped")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}(lastHandshake)
|
|
||||||
|
|
||||||
|
w.conn.onRelayDisconnected()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool {
|
func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool {
|
||||||
@ -217,20 +145,7 @@ func (w *WorkerRelay) preferredRelayServer(myRelayAddress, remoteRelayAddress st
|
|||||||
return remoteRelayAddress
|
return remoteRelayAddress
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerRelay) wgState() (time.Time, error) {
|
func (w *WorkerRelay) onRelayClientDisconnected() {
|
||||||
wgState, err := w.config.WgConfig.WgInterface.GetStats(w.config.Key)
|
w.wgWatcher.DisableWgWatcher()
|
||||||
if err != nil {
|
go w.conn.onRelayDisconnected()
|
||||||
return time.Time{}, err
|
|
||||||
}
|
|
||||||
return wgState.LastHandshake, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *WorkerRelay) onRelayMGDisconnected() {
|
|
||||||
w.ctxLock.Lock()
|
|
||||||
defer w.ctxLock.Unlock()
|
|
||||||
|
|
||||||
if w.ctxCancelWgWatch != nil {
|
|
||||||
w.ctxCancelWgWatch()
|
|
||||||
}
|
|
||||||
go w.callBacks.OnDisconnected()
|
|
||||||
}
|
}
|
||||||
|
@ -19,6 +19,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -28,6 +29,15 @@ const (
|
|||||||
handlerTypeStatic
|
handlerTypeStatic
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type reason int
|
||||||
|
|
||||||
|
const (
|
||||||
|
reasonUnknown reason = iota
|
||||||
|
reasonRouteUpdate
|
||||||
|
reasonPeerUpdate
|
||||||
|
reasonShutdown
|
||||||
|
)
|
||||||
|
|
||||||
type routerPeerStatus struct {
|
type routerPeerStatus struct {
|
||||||
connected bool
|
connected bool
|
||||||
relayed bool
|
relayed bool
|
||||||
@ -255,7 +265,7 @@ func (c *clientNetwork) removeRouteFromWireGuardPeer() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
|
func (c *clientNetwork) removeRouteFromPeerAndSystem(rsn reason) error {
|
||||||
if c.currentChosen == nil {
|
if c.currentChosen == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -269,17 +279,19 @@ func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
|
|||||||
merr = multierror.Append(merr, fmt.Errorf("remove route: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove route: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.disconnectEvent(rsn)
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem(rsn reason) error {
|
||||||
routerPeerStatuses := c.getRouterPeerStatuses()
|
routerPeerStatuses := c.getRouterPeerStatuses()
|
||||||
|
|
||||||
newChosenID := c.getBestRouteFromStatuses(routerPeerStatuses)
|
newChosenID := c.getBestRouteFromStatuses(routerPeerStatuses)
|
||||||
|
|
||||||
// If no route is chosen, remove the route from the peer and system
|
// If no route is chosen, remove the route from the peer and system
|
||||||
if newChosenID == "" {
|
if newChosenID == "" {
|
||||||
if err := c.removeRouteFromPeerAndSystem(); err != nil {
|
if err := c.removeRouteFromPeerAndSystem(rsn); err != nil {
|
||||||
return fmt.Errorf("remove route for peer %s: %w", c.currentChosen.Peer, err)
|
return fmt.Errorf("remove route for peer %s: %w", c.currentChosen.Peer, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -319,6 +331,58 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *clientNetwork) disconnectEvent(rsn reason) {
|
||||||
|
var defaultRoute bool
|
||||||
|
for _, r := range c.routes {
|
||||||
|
if r.Network.Bits() == 0 {
|
||||||
|
defaultRoute = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !defaultRoute {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var severity proto.SystemEvent_Severity
|
||||||
|
var message string
|
||||||
|
var userMessage string
|
||||||
|
meta := make(map[string]string)
|
||||||
|
|
||||||
|
switch rsn {
|
||||||
|
case reasonShutdown:
|
||||||
|
severity = proto.SystemEvent_INFO
|
||||||
|
message = "Default route removed"
|
||||||
|
userMessage = "Exit node disconnected."
|
||||||
|
meta["network"] = c.handler.String()
|
||||||
|
case reasonRouteUpdate:
|
||||||
|
severity = proto.SystemEvent_INFO
|
||||||
|
message = "Default route updated due to configuration change"
|
||||||
|
meta["network"] = c.handler.String()
|
||||||
|
case reasonPeerUpdate:
|
||||||
|
severity = proto.SystemEvent_WARNING
|
||||||
|
message = "Default route disconnected due to peer unreachability"
|
||||||
|
userMessage = "Exit node connection lost. Your internet access might be affected."
|
||||||
|
if c.currentChosen != nil {
|
||||||
|
meta["peer"] = c.currentChosen.Peer
|
||||||
|
meta["network"] = c.handler.String()
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
severity = proto.SystemEvent_ERROR
|
||||||
|
message = "Default route disconnected for unknown reason"
|
||||||
|
userMessage = "Exit node disconnected for unknown reasons."
|
||||||
|
meta["network"] = c.handler.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
c.statusRecorder.PublishEvent(
|
||||||
|
severity,
|
||||||
|
proto.SystemEvent_NETWORK,
|
||||||
|
message,
|
||||||
|
userMessage,
|
||||||
|
meta,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
|
func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
|
||||||
go func() {
|
go func() {
|
||||||
c.routeUpdate <- update
|
c.routeUpdate <- update
|
||||||
@ -361,12 +425,12 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
|||||||
select {
|
select {
|
||||||
case <-c.ctx.Done():
|
case <-c.ctx.Done():
|
||||||
log.Debugf("Stopping watcher for network [%v]", c.handler)
|
log.Debugf("Stopping watcher for network [%v]", c.handler)
|
||||||
if err := c.removeRouteFromPeerAndSystem(); err != nil {
|
if err := c.removeRouteFromPeerAndSystem(reasonShutdown); err != nil {
|
||||||
log.Errorf("Failed to remove routes for [%v]: %v", c.handler, err)
|
log.Errorf("Failed to remove routes for [%v]: %v", c.handler, err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
case <-c.peerStateUpdate:
|
case <-c.peerStateUpdate:
|
||||||
err := c.recalculateRouteAndUpdatePeerAndSystem()
|
err := c.recalculateRouteAndUpdatePeerAndSystem(reasonPeerUpdate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
|
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
|
||||||
}
|
}
|
||||||
@ -385,7 +449,7 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
|||||||
|
|
||||||
if isTrueRouteUpdate {
|
if isTrueRouteUpdate {
|
||||||
log.Debug("Client network update contains different routes, recalculating routes")
|
log.Debug("Client network update contains different routes, recalculating routes")
|
||||||
err := c.recalculateRouteAndUpdatePeerAndSystem()
|
err := c.recalculateRouteAndUpdatePeerAndSystem(reasonRouteUpdate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
|
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
|
||||||
}
|
}
|
||||||
|
@ -113,13 +113,14 @@ func NewManager(config ManagerConfig) *DefaultManager {
|
|||||||
disableServerRoutes: config.DisableServerRoutes,
|
disableServerRoutes: config.DisableServerRoutes,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
useNoop := netstack.IsEnabled() || config.DisableClientRoutes
|
||||||
|
dm.setupRefCounters(useNoop)
|
||||||
|
|
||||||
// don't proceed with client routes if it is disabled
|
// don't proceed with client routes if it is disabled
|
||||||
if config.DisableClientRoutes {
|
if config.DisableClientRoutes {
|
||||||
return dm
|
return dm
|
||||||
}
|
}
|
||||||
|
|
||||||
dm.setupRefCounters()
|
|
||||||
|
|
||||||
if runtime.GOOS == "android" {
|
if runtime.GOOS == "android" {
|
||||||
cr := dm.initialClientRoutes(config.InitialRoutes)
|
cr := dm.initialClientRoutes(config.InitialRoutes)
|
||||||
dm.notifier.SetInitialClientRoutes(cr)
|
dm.notifier.SetInitialClientRoutes(cr)
|
||||||
@ -127,7 +128,7 @@ func NewManager(config ManagerConfig) *DefaultManager {
|
|||||||
return dm
|
return dm
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) setupRefCounters() {
|
func (m *DefaultManager) setupRefCounters(useNoop bool) {
|
||||||
m.routeRefCounter = refcounter.New(
|
m.routeRefCounter = refcounter.New(
|
||||||
func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
|
func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
|
||||||
return struct{}{}, m.sysOps.AddVPNRoute(prefix, m.wgInterface.ToInterface())
|
return struct{}{}, m.sysOps.AddVPNRoute(prefix, m.wgInterface.ToInterface())
|
||||||
@ -137,7 +138,7 @@ func (m *DefaultManager) setupRefCounters() {
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
if netstack.IsEnabled() {
|
if useNoop {
|
||||||
m.routeRefCounter = refcounter.New(
|
m.routeRefCounter = refcounter.New(
|
||||||
func(netip.Prefix, struct{}) (struct{}, error) {
|
func(netip.Prefix, struct{}) (struct{}, error) {
|
||||||
return struct{}{}, refcounter.ErrIgnore
|
return struct{}{}, refcounter.ErrIgnore
|
||||||
@ -285,16 +286,16 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
|
|||||||
m.updateClientNetworks(updateSerial, filteredClientRoutes)
|
m.updateClientNetworks(updateSerial, filteredClientRoutes)
|
||||||
m.notifier.OnNewRoutes(filteredClientRoutes)
|
m.notifier.OnNewRoutes(filteredClientRoutes)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.serverRouter != nil {
|
|
||||||
err := m.serverRouter.updateRoutes(newServerRoutesMap)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
m.clientRoutes = newClientRoutesIDMap
|
m.clientRoutes = newClientRoutesIDMap
|
||||||
|
|
||||||
|
if m.serverRouter == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.serverRouter.updateRoutes(newServerRoutesMap); err != nil {
|
||||||
|
return fmt.Errorf("update routes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -422,11 +423,6 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
|
|||||||
haID := newRoute.GetHAUniqueID()
|
haID := newRoute.GetHAUniqueID()
|
||||||
if newRoute.Peer == m.pubKey {
|
if newRoute.Peer == m.pubKey {
|
||||||
ownNetworkIDs[haID] = true
|
ownNetworkIDs[haID] = true
|
||||||
// only linux is supported for now
|
|
||||||
if runtime.GOOS != "linux" {
|
|
||||||
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
newServerRoutesMap[newRoute.ID] = newRoute
|
newServerRoutesMap[newRoute.ID] = newRoute
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -454,7 +450,7 @@ func (m *DefaultManager) initialClientRoutes(initialRoutes []*route.Route) []*ro
|
|||||||
}
|
}
|
||||||
|
|
||||||
func isRouteSupported(route *route.Route) bool {
|
func isRouteSupported(route *route.Route) bool {
|
||||||
if !nbnet.CustomRoutingDisabled() || route.IsDynamic() {
|
if netstack.IsEnabled() || !nbnet.CustomRoutingDisabled() || route.IsDynamic() {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -69,6 +69,16 @@ func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error {
|
|||||||
m.routes[id] = newRoute
|
m.routes[id] = newRoute
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(m.routes) > 0 {
|
||||||
|
if err := m.firewall.EnableRouting(); err != nil {
|
||||||
|
return fmt.Errorf("enable routing: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := m.firewall.DisableRouting(); err != nil {
|
||||||
|
return fmt.Errorf("disable routing: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -53,20 +53,6 @@ type ruleParams struct {
|
|||||||
description string
|
description string
|
||||||
}
|
}
|
||||||
|
|
||||||
// isLegacy determines whether to use the legacy routing setup
|
|
||||||
func isLegacy() bool {
|
|
||||||
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || nbnet.SkipSocketMark()
|
|
||||||
}
|
|
||||||
|
|
||||||
// setIsLegacy sets the legacy routing setup
|
|
||||||
func setIsLegacy(b bool) {
|
|
||||||
if b {
|
|
||||||
os.Setenv("NB_USE_LEGACY_ROUTING", "true")
|
|
||||||
} else {
|
|
||||||
os.Unsetenv("NB_USE_LEGACY_ROUTING")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getSetupRules() []ruleParams {
|
func getSetupRules() []ruleParams {
|
||||||
return []ruleParams{
|
return []ruleParams{
|
||||||
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"},
|
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"},
|
||||||
@ -87,7 +73,7 @@ func getSetupRules() []ruleParams {
|
|||||||
// This table is where a default route or other specific routes received from the management server are configured,
|
// This table is where a default route or other specific routes received from the management server are configured,
|
||||||
// enabling VPN connectivity.
|
// enabling VPN connectivity.
|
||||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) {
|
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) {
|
||||||
if isLegacy() {
|
if !nbnet.AdvancedRouting() {
|
||||||
log.Infof("Using legacy routing setup")
|
log.Infof("Using legacy routing setup")
|
||||||
return r.setupRefCounter(initAddresses, stateManager)
|
return r.setupRefCounter(initAddresses, stateManager)
|
||||||
}
|
}
|
||||||
@ -103,11 +89,6 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
|
|||||||
rules := getSetupRules()
|
rules := getSetupRules()
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if err := addRule(rule); err != nil {
|
if err := addRule(rule); err != nil {
|
||||||
if errors.Is(err, syscall.EOPNOTSUPP) {
|
|
||||||
log.Warnf("Rule operations are not supported, falling back to the legacy routing setup")
|
|
||||||
setIsLegacy(true)
|
|
||||||
return r.setupRefCounter(initAddresses, stateManager)
|
|
||||||
}
|
|
||||||
return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
|
return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -130,7 +111,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
|
|||||||
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
|
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
|
||||||
// The function uses error aggregation to report any errors encountered during the cleanup process.
|
// The function uses error aggregation to report any errors encountered during the cleanup process.
|
||||||
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
|
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
|
||||||
if isLegacy() {
|
if !nbnet.AdvancedRouting() {
|
||||||
return r.cleanupRefCounter(stateManager)
|
return r.cleanupRefCounter(stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -168,7 +149,7 @@ func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||||
if isLegacy() {
|
if !nbnet.AdvancedRouting() {
|
||||||
return r.genericAddVPNRoute(prefix, intf)
|
return r.genericAddVPNRoute(prefix, intf)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -191,7 +172,7 @@ func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||||
if isLegacy() {
|
if !nbnet.AdvancedRouting() {
|
||||||
return r.genericRemoveVPNRoute(prefix, intf)
|
return r.genericRemoveVPNRoute(prefix, intf)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -504,7 +485,7 @@ func getAddressFamily(prefix netip.Prefix) int {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func hasSeparateRouting() ([]netip.Prefix, error) {
|
func hasSeparateRouting() ([]netip.Prefix, error) {
|
||||||
if isLegacy() {
|
if !nbnet.AdvancedRouting() {
|
||||||
return GetRoutesFromTable()
|
return GetRoutesFromTable()
|
||||||
}
|
}
|
||||||
return nil, ErrRoutingIsSeparate
|
return nil, ErrRoutingIsSeparate
|
||||||
|
@ -85,6 +85,7 @@ var testCases = []testCase{
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRouting(t *testing.T) {
|
func TestRouting(t *testing.T) {
|
||||||
|
nbnet.Init()
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
// todo resolve test execution on freebsd
|
// todo resolve test execution on freebsd
|
||||||
if runtime.GOOS == "freebsd" {
|
if runtime.GOOS == "freebsd" {
|
||||||
|
@ -61,6 +61,12 @@ service DaemonService {
|
|||||||
|
|
||||||
// SetNetworkMapPersistence enables or disables network map persistence
|
// SetNetworkMapPersistence enables or disables network map persistence
|
||||||
rpc SetNetworkMapPersistence(SetNetworkMapPersistenceRequest) returns (SetNetworkMapPersistenceResponse) {}
|
rpc SetNetworkMapPersistence(SetNetworkMapPersistenceRequest) returns (SetNetworkMapPersistenceResponse) {}
|
||||||
|
|
||||||
|
rpc TracePacket(TracePacketRequest) returns (TracePacketResponse) {}
|
||||||
|
|
||||||
|
rpc SubscribeEvents(SubscribeRequest) returns (stream SystemEvent) {}
|
||||||
|
|
||||||
|
rpc GetEvents(GetEventsRequest) returns (GetEventsResponse) {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -118,6 +124,8 @@ message LoginRequest {
|
|||||||
optional bool disable_firewall = 23;
|
optional bool disable_firewall = 23;
|
||||||
|
|
||||||
optional bool block_lan_access = 24;
|
optional bool block_lan_access = 24;
|
||||||
|
|
||||||
|
optional bool disable_notifications = 25;
|
||||||
}
|
}
|
||||||
|
|
||||||
message LoginResponse {
|
message LoginResponse {
|
||||||
@ -183,6 +191,8 @@ message GetConfigResponse {
|
|||||||
bool rosenpassEnabled = 11;
|
bool rosenpassEnabled = 11;
|
||||||
|
|
||||||
bool rosenpassPermissive = 12;
|
bool rosenpassPermissive = 12;
|
||||||
|
|
||||||
|
bool disable_notifications = 13;
|
||||||
}
|
}
|
||||||
|
|
||||||
// PeerState contains the latest state of a peer
|
// PeerState contains the latest state of a peer
|
||||||
@ -254,6 +264,8 @@ message FullStatus {
|
|||||||
repeated RelayState relays = 5;
|
repeated RelayState relays = 5;
|
||||||
repeated NSGroupState dns_servers = 6;
|
repeated NSGroupState dns_servers = 6;
|
||||||
int32 NumberOfForwardingRules = 8;
|
int32 NumberOfForwardingRules = 8;
|
||||||
|
|
||||||
|
repeated SystemEvent events = 7;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Networks
|
// Networks
|
||||||
@ -388,3 +400,68 @@ message SetNetworkMapPersistenceRequest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
message SetNetworkMapPersistenceResponse {}
|
message SetNetworkMapPersistenceResponse {}
|
||||||
|
|
||||||
|
message TCPFlags {
|
||||||
|
bool syn = 1;
|
||||||
|
bool ack = 2;
|
||||||
|
bool fin = 3;
|
||||||
|
bool rst = 4;
|
||||||
|
bool psh = 5;
|
||||||
|
bool urg = 6;
|
||||||
|
}
|
||||||
|
|
||||||
|
message TracePacketRequest {
|
||||||
|
string source_ip = 1;
|
||||||
|
string destination_ip = 2;
|
||||||
|
string protocol = 3;
|
||||||
|
uint32 source_port = 4;
|
||||||
|
uint32 destination_port = 5;
|
||||||
|
string direction = 6;
|
||||||
|
optional TCPFlags tcp_flags = 7;
|
||||||
|
optional uint32 icmp_type = 8;
|
||||||
|
optional uint32 icmp_code = 9;
|
||||||
|
}
|
||||||
|
|
||||||
|
message TraceStage {
|
||||||
|
string name = 1;
|
||||||
|
string message = 2;
|
||||||
|
bool allowed = 3;
|
||||||
|
optional string forwarding_details = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message TracePacketResponse {
|
||||||
|
repeated TraceStage stages = 1;
|
||||||
|
bool final_disposition = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message SubscribeRequest{}
|
||||||
|
|
||||||
|
message SystemEvent {
|
||||||
|
enum Severity {
|
||||||
|
INFO = 0;
|
||||||
|
WARNING = 1;
|
||||||
|
ERROR = 2;
|
||||||
|
CRITICAL = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
enum Category {
|
||||||
|
NETWORK = 0;
|
||||||
|
DNS = 1;
|
||||||
|
AUTHENTICATION = 2;
|
||||||
|
CONNECTIVITY = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
string id = 1;
|
||||||
|
Severity severity = 2;
|
||||||
|
Category category = 3;
|
||||||
|
string message = 4;
|
||||||
|
string userMessage = 5;
|
||||||
|
google.protobuf.Timestamp timestamp = 6;
|
||||||
|
map<string, string> metadata = 7;
|
||||||
|
}
|
||||||
|
|
||||||
|
message GetEventsRequest {}
|
||||||
|
|
||||||
|
message GetEventsResponse {
|
||||||
|
repeated SystemEvent events = 1;
|
||||||
|
}
|
||||||
|
@ -52,6 +52,9 @@ type DaemonServiceClient interface {
|
|||||||
DeleteState(ctx context.Context, in *DeleteStateRequest, opts ...grpc.CallOption) (*DeleteStateResponse, error)
|
DeleteState(ctx context.Context, in *DeleteStateRequest, opts ...grpc.CallOption) (*DeleteStateResponse, error)
|
||||||
// SetNetworkMapPersistence enables or disables network map persistence
|
// SetNetworkMapPersistence enables or disables network map persistence
|
||||||
SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error)
|
SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error)
|
||||||
|
TracePacket(ctx context.Context, in *TracePacketRequest, opts ...grpc.CallOption) (*TracePacketResponse, error)
|
||||||
|
SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (DaemonService_SubscribeEventsClient, error)
|
||||||
|
GetEvents(ctx context.Context, in *GetEventsRequest, opts ...grpc.CallOption) (*GetEventsResponse, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type daemonServiceClient struct {
|
type daemonServiceClient struct {
|
||||||
@ -215,6 +218,56 @@ func (c *daemonServiceClient) SetNetworkMapPersistence(ctx context.Context, in *
|
|||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *daemonServiceClient) TracePacket(ctx context.Context, in *TracePacketRequest, opts ...grpc.CallOption) (*TracePacketResponse, error) {
|
||||||
|
out := new(TracePacketResponse)
|
||||||
|
err := c.cc.Invoke(ctx, "/daemon.DaemonService/TracePacket", in, out, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *daemonServiceClient) SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (DaemonService_SubscribeEventsClient, error) {
|
||||||
|
stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[0], "/daemon.DaemonService/SubscribeEvents", opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
x := &daemonServiceSubscribeEventsClient{stream}
|
||||||
|
if err := x.ClientStream.SendMsg(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := x.ClientStream.CloseSend(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return x, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type DaemonService_SubscribeEventsClient interface {
|
||||||
|
Recv() (*SystemEvent, error)
|
||||||
|
grpc.ClientStream
|
||||||
|
}
|
||||||
|
|
||||||
|
type daemonServiceSubscribeEventsClient struct {
|
||||||
|
grpc.ClientStream
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *daemonServiceSubscribeEventsClient) Recv() (*SystemEvent, error) {
|
||||||
|
m := new(SystemEvent)
|
||||||
|
if err := x.ClientStream.RecvMsg(m); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *daemonServiceClient) GetEvents(ctx context.Context, in *GetEventsRequest, opts ...grpc.CallOption) (*GetEventsResponse, error) {
|
||||||
|
out := new(GetEventsResponse)
|
||||||
|
err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetEvents", in, out, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
// DaemonServiceServer is the server API for DaemonService service.
|
// DaemonServiceServer is the server API for DaemonService service.
|
||||||
// All implementations must embed UnimplementedDaemonServiceServer
|
// All implementations must embed UnimplementedDaemonServiceServer
|
||||||
// for forward compatibility
|
// for forward compatibility
|
||||||
@ -253,6 +306,9 @@ type DaemonServiceServer interface {
|
|||||||
DeleteState(context.Context, *DeleteStateRequest) (*DeleteStateResponse, error)
|
DeleteState(context.Context, *DeleteStateRequest) (*DeleteStateResponse, error)
|
||||||
// SetNetworkMapPersistence enables or disables network map persistence
|
// SetNetworkMapPersistence enables or disables network map persistence
|
||||||
SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error)
|
SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error)
|
||||||
|
TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error)
|
||||||
|
SubscribeEvents(*SubscribeRequest, DaemonService_SubscribeEventsServer) error
|
||||||
|
GetEvents(context.Context, *GetEventsRequest) (*GetEventsResponse, error)
|
||||||
mustEmbedUnimplementedDaemonServiceServer()
|
mustEmbedUnimplementedDaemonServiceServer()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -311,6 +367,15 @@ func (UnimplementedDaemonServiceServer) DeleteState(context.Context, *DeleteStat
|
|||||||
func (UnimplementedDaemonServiceServer) SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error) {
|
func (UnimplementedDaemonServiceServer) SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error) {
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method SetNetworkMapPersistence not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method SetNetworkMapPersistence not implemented")
|
||||||
}
|
}
|
||||||
|
func (UnimplementedDaemonServiceServer) TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error) {
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method TracePacket not implemented")
|
||||||
|
}
|
||||||
|
func (UnimplementedDaemonServiceServer) SubscribeEvents(*SubscribeRequest, DaemonService_SubscribeEventsServer) error {
|
||||||
|
return status.Errorf(codes.Unimplemented, "method SubscribeEvents not implemented")
|
||||||
|
}
|
||||||
|
func (UnimplementedDaemonServiceServer) GetEvents(context.Context, *GetEventsRequest) (*GetEventsResponse, error) {
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method GetEvents not implemented")
|
||||||
|
}
|
||||||
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
|
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
|
||||||
|
|
||||||
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
|
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
|
||||||
@ -630,6 +695,63 @@ func _DaemonService_SetNetworkMapPersistence_Handler(srv interface{}, ctx contex
|
|||||||
return interceptor(ctx, in, info, handler)
|
return interceptor(ctx, in, info, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func _DaemonService_TracePacket_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
|
in := new(TracePacketRequest)
|
||||||
|
if err := dec(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if interceptor == nil {
|
||||||
|
return srv.(DaemonServiceServer).TracePacket(ctx, in)
|
||||||
|
}
|
||||||
|
info := &grpc.UnaryServerInfo{
|
||||||
|
Server: srv,
|
||||||
|
FullMethod: "/daemon.DaemonService/TracePacket",
|
||||||
|
}
|
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
return srv.(DaemonServiceServer).TracePacket(ctx, req.(*TracePacketRequest))
|
||||||
|
}
|
||||||
|
return interceptor(ctx, in, info, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func _DaemonService_SubscribeEvents_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||||
|
m := new(SubscribeRequest)
|
||||||
|
if err := stream.RecvMsg(m); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return srv.(DaemonServiceServer).SubscribeEvents(m, &daemonServiceSubscribeEventsServer{stream})
|
||||||
|
}
|
||||||
|
|
||||||
|
type DaemonService_SubscribeEventsServer interface {
|
||||||
|
Send(*SystemEvent) error
|
||||||
|
grpc.ServerStream
|
||||||
|
}
|
||||||
|
|
||||||
|
type daemonServiceSubscribeEventsServer struct {
|
||||||
|
grpc.ServerStream
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *daemonServiceSubscribeEventsServer) Send(m *SystemEvent) error {
|
||||||
|
return x.ServerStream.SendMsg(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
func _DaemonService_GetEvents_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
|
in := new(GetEventsRequest)
|
||||||
|
if err := dec(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if interceptor == nil {
|
||||||
|
return srv.(DaemonServiceServer).GetEvents(ctx, in)
|
||||||
|
}
|
||||||
|
info := &grpc.UnaryServerInfo{
|
||||||
|
Server: srv,
|
||||||
|
FullMethod: "/daemon.DaemonService/GetEvents",
|
||||||
|
}
|
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
return srv.(DaemonServiceServer).GetEvents(ctx, req.(*GetEventsRequest))
|
||||||
|
}
|
||||||
|
return interceptor(ctx, in, info, handler)
|
||||||
|
}
|
||||||
|
|
||||||
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
|
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
|
||||||
// It's only intended for direct use with grpc.RegisterService,
|
// It's only intended for direct use with grpc.RegisterService,
|
||||||
// and not to be introspected or modified (even as a copy)
|
// and not to be introspected or modified (even as a copy)
|
||||||
@ -705,7 +827,21 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
|
|||||||
MethodName: "SetNetworkMapPersistence",
|
MethodName: "SetNetworkMapPersistence",
|
||||||
Handler: _DaemonService_SetNetworkMapPersistence_Handler,
|
Handler: _DaemonService_SetNetworkMapPersistence_Handler,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
MethodName: "TracePacket",
|
||||||
|
Handler: _DaemonService_TracePacket_Handler,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
MethodName: "GetEvents",
|
||||||
|
Handler: _DaemonService_GetEvents_Handler,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Streams: []grpc.StreamDesc{
|
||||||
|
{
|
||||||
|
StreamName: "SubscribeEvents",
|
||||||
|
Handler: _DaemonService_SubscribeEvents_Handler,
|
||||||
|
ServerStreams: true,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Streams: []grpc.StreamDesc{},
|
|
||||||
Metadata: "daemon.proto",
|
Metadata: "daemon.proto",
|
||||||
}
|
}
|
||||||
|
@ -538,7 +538,24 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.SetLevel(level)
|
log.SetLevel(level)
|
||||||
|
|
||||||
|
if s.connectClient == nil {
|
||||||
|
return nil, fmt.Errorf("connect client not initialized")
|
||||||
|
}
|
||||||
|
engine := s.connectClient.Engine()
|
||||||
|
if engine == nil {
|
||||||
|
return nil, fmt.Errorf("engine not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
fwManager := engine.GetFirewallManager()
|
||||||
|
if fwManager == nil {
|
||||||
|
return nil, fmt.Errorf("firewall manager not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
fwManager.SetLogLevel(level)
|
||||||
|
|
||||||
log.Infof("Log level set to %s", level.String())
|
log.Infof("Log level set to %s", level.String())
|
||||||
|
|
||||||
return &proto.SetLogLevelResponse{}, nil
|
return &proto.SetLogLevelResponse{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
36
client/server/event.go
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Server) SubscribeEvents(req *proto.SubscribeRequest, stream proto.DaemonService_SubscribeEventsServer) error {
|
||||||
|
subscription := s.statusRecorder.SubscribeToEvents()
|
||||||
|
defer func() {
|
||||||
|
s.statusRecorder.UnsubscribeFromEvents(subscription)
|
||||||
|
log.Debug("client unsubscribed from events")
|
||||||
|
}()
|
||||||
|
|
||||||
|
log.Debug("client subscribed to events")
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case event := <-subscription.Events():
|
||||||
|
if err := stream.Send(event); err != nil {
|
||||||
|
log.Warnf("error sending event to %v: %v", req, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case <-stream.Context().Done():
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) GetEvents(context.Context, *proto.GetEventsRequest) (*proto.GetEventsResponse, error) {
|
||||||
|
events := s.statusRecorder.GetEventHistory()
|
||||||
|
return &proto.GetEventsResponse{Events: events}, nil
|
||||||
|
}
|
@ -404,6 +404,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
|||||||
s.latestConfigInput.BlockLANAccess = msg.BlockLanAccess
|
s.latestConfigInput.BlockLANAccess = msg.BlockLanAccess
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if msg.DisableNotifications != nil {
|
||||||
|
inputConfig.DisableNotifications = msg.DisableNotifications
|
||||||
|
s.latestConfigInput.DisableNotifications = msg.DisableNotifications
|
||||||
|
}
|
||||||
|
|
||||||
s.mutex.Unlock()
|
s.mutex.Unlock()
|
||||||
|
|
||||||
if msg.OptionalPreSharedKey != nil {
|
if msg.OptionalPreSharedKey != nil {
|
||||||
@ -687,6 +692,7 @@ func (s *Server) Status(
|
|||||||
|
|
||||||
fullStatus := s.statusRecorder.GetFullStatus()
|
fullStatus := s.statusRecorder.GetFullStatus()
|
||||||
pbFullStatus := toProtoFullStatus(fullStatus)
|
pbFullStatus := toProtoFullStatus(fullStatus)
|
||||||
|
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
|
||||||
statusResponse.FullStatus = pbFullStatus
|
statusResponse.FullStatus = pbFullStatus
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -747,6 +753,7 @@ func (s *Server) GetConfig(_ context.Context, _ *proto.GetConfigRequest) (*proto
|
|||||||
ServerSSHAllowed: *s.config.ServerSSHAllowed,
|
ServerSSHAllowed: *s.config.ServerSSHAllowed,
|
||||||
RosenpassEnabled: s.config.RosenpassEnabled,
|
RosenpassEnabled: s.config.RosenpassEnabled,
|
||||||
RosenpassPermissive: s.config.RosenpassPermissive,
|
RosenpassPermissive: s.config.RosenpassPermissive,
|
||||||
|
DisableNotifications: s.config.DisableNotifications,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
func (s *Server) onSessionExpire() {
|
func (s *Server) onSessionExpire() {
|
||||||
|
123
client/server/trace.go
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
type packetTracer interface {
|
||||||
|
TracePacketFromBuilder(builder *uspfilter.PacketBuilder) (*uspfilter.PacketTrace, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (*proto.TracePacketResponse, error) {
|
||||||
|
s.mutex.Lock()
|
||||||
|
defer s.mutex.Unlock()
|
||||||
|
|
||||||
|
if s.connectClient == nil {
|
||||||
|
return nil, fmt.Errorf("connect client not initialized")
|
||||||
|
}
|
||||||
|
engine := s.connectClient.Engine()
|
||||||
|
if engine == nil {
|
||||||
|
return nil, fmt.Errorf("engine not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
fwManager := engine.GetFirewallManager()
|
||||||
|
if fwManager == nil {
|
||||||
|
return nil, fmt.Errorf("firewall manager not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
tracer, ok := fwManager.(packetTracer)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("firewall manager does not support packet tracing")
|
||||||
|
}
|
||||||
|
|
||||||
|
srcIP := net.ParseIP(req.GetSourceIp())
|
||||||
|
if req.GetSourceIp() == "self" {
|
||||||
|
srcIP = engine.GetWgAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
dstIP := net.ParseIP(req.GetDestinationIp())
|
||||||
|
if req.GetDestinationIp() == "self" {
|
||||||
|
dstIP = engine.GetWgAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
if srcIP == nil || dstIP == nil {
|
||||||
|
return nil, fmt.Errorf("invalid IP address")
|
||||||
|
}
|
||||||
|
|
||||||
|
var tcpState *uspfilter.TCPState
|
||||||
|
if flags := req.GetTcpFlags(); flags != nil {
|
||||||
|
tcpState = &uspfilter.TCPState{
|
||||||
|
SYN: flags.GetSyn(),
|
||||||
|
ACK: flags.GetAck(),
|
||||||
|
FIN: flags.GetFin(),
|
||||||
|
RST: flags.GetRst(),
|
||||||
|
PSH: flags.GetPsh(),
|
||||||
|
URG: flags.GetUrg(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var dir fw.RuleDirection
|
||||||
|
switch req.GetDirection() {
|
||||||
|
case "in":
|
||||||
|
dir = fw.RuleDirectionIN
|
||||||
|
case "out":
|
||||||
|
dir = fw.RuleDirectionOUT
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid direction")
|
||||||
|
}
|
||||||
|
|
||||||
|
var protocol fw.Protocol
|
||||||
|
switch req.GetProtocol() {
|
||||||
|
case "tcp":
|
||||||
|
protocol = fw.ProtocolTCP
|
||||||
|
case "udp":
|
||||||
|
protocol = fw.ProtocolUDP
|
||||||
|
case "icmp":
|
||||||
|
protocol = fw.ProtocolICMP
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid protocolcol")
|
||||||
|
}
|
||||||
|
|
||||||
|
builder := &uspfilter.PacketBuilder{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
Protocol: protocol,
|
||||||
|
SrcPort: uint16(req.GetSourcePort()),
|
||||||
|
DstPort: uint16(req.GetDestinationPort()),
|
||||||
|
Direction: dir,
|
||||||
|
TCPState: tcpState,
|
||||||
|
ICMPType: uint8(req.GetIcmpType()),
|
||||||
|
ICMPCode: uint8(req.GetIcmpCode()),
|
||||||
|
}
|
||||||
|
trace, err := tracer.TracePacketFromBuilder(builder)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("trace packet: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := &proto.TracePacketResponse{}
|
||||||
|
|
||||||
|
for _, result := range trace.Results {
|
||||||
|
stage := &proto.TraceStage{
|
||||||
|
Name: result.Stage.String(),
|
||||||
|
Message: result.Message,
|
||||||
|
Allowed: result.Allowed,
|
||||||
|
}
|
||||||
|
if result.ForwarderAction != nil {
|
||||||
|
details := fmt.Sprintf("%s to %s", result.ForwarderAction.Action, result.ForwarderAction.RemoteAddr)
|
||||||
|
stage.ForwardingDetails = &details
|
||||||
|
}
|
||||||
|
resp.Stages = append(resp.Stages, stage)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(trace.Results) > 0 {
|
||||||
|
resp.FinalDisposition = trace.Results[len(trace.Results)-1].Allowed
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
@ -21,6 +21,7 @@ import (
|
|||||||
"fyne.io/fyne/v2"
|
"fyne.io/fyne/v2"
|
||||||
"fyne.io/fyne/v2/app"
|
"fyne.io/fyne/v2/app"
|
||||||
"fyne.io/fyne/v2/dialog"
|
"fyne.io/fyne/v2/dialog"
|
||||||
|
"fyne.io/fyne/v2/theme"
|
||||||
"fyne.io/fyne/v2/widget"
|
"fyne.io/fyne/v2/widget"
|
||||||
"fyne.io/systray"
|
"fyne.io/systray"
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
@ -33,6 +34,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
|
"github.com/netbirdio/netbird/client/ui/event"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
@ -82,7 +84,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
a := app.NewWithID("NetBird")
|
a := app.NewWithID("NetBird")
|
||||||
a.SetIcon(fyne.NewStaticResource("netbird", iconDisconnectedPNG))
|
a.SetIcon(fyne.NewStaticResource("netbird", iconDisconnected))
|
||||||
|
|
||||||
if errorMSG != "" {
|
if errorMSG != "" {
|
||||||
showErrorMSG(errorMSG)
|
showErrorMSG(errorMSG)
|
||||||
@ -90,6 +92,14 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
client := newServiceClient(daemonAddr, a, showSettings, showRoutes)
|
client := newServiceClient(daemonAddr, a, showSettings, showRoutes)
|
||||||
|
settingsChangeChan := make(chan fyne.Settings)
|
||||||
|
a.Settings().AddChangeListener(settingsChangeChan)
|
||||||
|
go func() {
|
||||||
|
for range settingsChangeChan {
|
||||||
|
client.updateIcon()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
if showSettings || showRoutes {
|
if showSettings || showRoutes {
|
||||||
a.Run()
|
a.Run()
|
||||||
} else {
|
} else {
|
||||||
@ -106,46 +116,36 @@ func main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//go:embed netbird-systemtray-connected.ico
|
//go:embed netbird-systemtray-connected-macos.png
|
||||||
var iconConnectedICO []byte
|
var iconConnectedMacOS []byte
|
||||||
|
|
||||||
//go:embed netbird-systemtray-connected.png
|
//go:embed netbird-systemtray-disconnected-macos.png
|
||||||
var iconConnectedPNG []byte
|
var iconDisconnectedMacOS []byte
|
||||||
|
|
||||||
//go:embed netbird-systemtray-disconnected.ico
|
//go:embed netbird-systemtray-update-disconnected-macos.png
|
||||||
var iconDisconnectedICO []byte
|
var iconUpdateDisconnectedMacOS []byte
|
||||||
|
|
||||||
//go:embed netbird-systemtray-disconnected.png
|
//go:embed netbird-systemtray-update-connected-macos.png
|
||||||
var iconDisconnectedPNG []byte
|
var iconUpdateConnectedMacOS []byte
|
||||||
|
|
||||||
//go:embed netbird-systemtray-update-disconnected.ico
|
//go:embed netbird-systemtray-connecting-macos.png
|
||||||
var iconUpdateDisconnectedICO []byte
|
var iconConnectingMacOS []byte
|
||||||
|
|
||||||
//go:embed netbird-systemtray-update-disconnected.png
|
//go:embed netbird-systemtray-error-macos.png
|
||||||
var iconUpdateDisconnectedPNG []byte
|
var iconErrorMacOS []byte
|
||||||
|
|
||||||
//go:embed netbird-systemtray-update-connected.ico
|
|
||||||
var iconUpdateConnectedICO []byte
|
|
||||||
|
|
||||||
//go:embed netbird-systemtray-update-connected.png
|
|
||||||
var iconUpdateConnectedPNG []byte
|
|
||||||
|
|
||||||
//go:embed netbird-systemtray-update-cloud.ico
|
|
||||||
var iconUpdateCloudICO []byte
|
|
||||||
|
|
||||||
//go:embed netbird-systemtray-update-cloud.png
|
|
||||||
var iconUpdateCloudPNG []byte
|
|
||||||
|
|
||||||
type serviceClient struct {
|
type serviceClient struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
addr string
|
addr string
|
||||||
conn proto.DaemonServiceClient
|
conn proto.DaemonServiceClient
|
||||||
|
|
||||||
|
icAbout []byte
|
||||||
icConnected []byte
|
icConnected []byte
|
||||||
icDisconnected []byte
|
icDisconnected []byte
|
||||||
icUpdateConnected []byte
|
icUpdateConnected []byte
|
||||||
icUpdateDisconnected []byte
|
icUpdateDisconnected []byte
|
||||||
icUpdateCloud []byte
|
icConnecting []byte
|
||||||
|
icError []byte
|
||||||
|
|
||||||
// systray menu items
|
// systray menu items
|
||||||
mStatus *systray.MenuItem
|
mStatus *systray.MenuItem
|
||||||
@ -162,6 +162,7 @@ type serviceClient struct {
|
|||||||
mAllowSSH *systray.MenuItem
|
mAllowSSH *systray.MenuItem
|
||||||
mAutoConnect *systray.MenuItem
|
mAutoConnect *systray.MenuItem
|
||||||
mEnableRosenpass *systray.MenuItem
|
mEnableRosenpass *systray.MenuItem
|
||||||
|
mNotifications *systray.MenuItem
|
||||||
mAdvancedSettings *systray.MenuItem
|
mAdvancedSettings *systray.MenuItem
|
||||||
|
|
||||||
// application with main windows.
|
// application with main windows.
|
||||||
@ -197,6 +198,8 @@ type serviceClient struct {
|
|||||||
isUpdateIconActive bool
|
isUpdateIconActive bool
|
||||||
showRoutes bool
|
showRoutes bool
|
||||||
wRoutes fyne.Window
|
wRoutes fyne.Window
|
||||||
|
|
||||||
|
eventManager *event.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
// newServiceClient instance constructor
|
// newServiceClient instance constructor
|
||||||
@ -214,20 +217,7 @@ func newServiceClient(addr string, a fyne.App, showSettings bool, showRoutes boo
|
|||||||
update: version.NewUpdate(),
|
update: version.NewUpdate(),
|
||||||
}
|
}
|
||||||
|
|
||||||
if runtime.GOOS == "windows" {
|
s.setNewIcons()
|
||||||
s.icConnected = iconConnectedICO
|
|
||||||
s.icDisconnected = iconDisconnectedICO
|
|
||||||
s.icUpdateConnected = iconUpdateConnectedICO
|
|
||||||
s.icUpdateDisconnected = iconUpdateDisconnectedICO
|
|
||||||
s.icUpdateCloud = iconUpdateCloudICO
|
|
||||||
|
|
||||||
} else {
|
|
||||||
s.icConnected = iconConnectedPNG
|
|
||||||
s.icDisconnected = iconDisconnectedPNG
|
|
||||||
s.icUpdateConnected = iconUpdateConnectedPNG
|
|
||||||
s.icUpdateDisconnected = iconUpdateDisconnectedPNG
|
|
||||||
s.icUpdateCloud = iconUpdateCloudPNG
|
|
||||||
}
|
|
||||||
|
|
||||||
if showSettings {
|
if showSettings {
|
||||||
s.showSettingsUI()
|
s.showSettingsUI()
|
||||||
@ -239,6 +229,44 @@ func newServiceClient(addr string, a fyne.App, showSettings bool, showRoutes boo
|
|||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *serviceClient) setNewIcons() {
|
||||||
|
s.icAbout = iconAbout
|
||||||
|
if s.app.Settings().ThemeVariant() == theme.VariantDark {
|
||||||
|
s.icConnected = iconConnectedDark
|
||||||
|
s.icDisconnected = iconDisconnected
|
||||||
|
s.icUpdateConnected = iconUpdateConnectedDark
|
||||||
|
s.icUpdateDisconnected = iconUpdateDisconnectedDark
|
||||||
|
s.icConnecting = iconConnectingDark
|
||||||
|
s.icError = iconErrorDark
|
||||||
|
} else {
|
||||||
|
s.icConnected = iconConnected
|
||||||
|
s.icDisconnected = iconDisconnected
|
||||||
|
s.icUpdateConnected = iconUpdateConnected
|
||||||
|
s.icUpdateDisconnected = iconUpdateDisconnected
|
||||||
|
s.icConnecting = iconConnecting
|
||||||
|
s.icError = iconError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceClient) updateIcon() {
|
||||||
|
s.setNewIcons()
|
||||||
|
s.updateIndicationLock.Lock()
|
||||||
|
if s.connected {
|
||||||
|
if s.isUpdateIconActive {
|
||||||
|
systray.SetTemplateIcon(iconUpdateConnectedMacOS, s.icUpdateConnected)
|
||||||
|
} else {
|
||||||
|
systray.SetTemplateIcon(iconConnectedMacOS, s.icConnected)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if s.isUpdateIconActive {
|
||||||
|
systray.SetTemplateIcon(iconUpdateDisconnectedMacOS, s.icUpdateDisconnected)
|
||||||
|
} else {
|
||||||
|
systray.SetTemplateIcon(iconDisconnectedMacOS, s.icDisconnected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.updateIndicationLock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
func (s *serviceClient) showSettingsUI() {
|
func (s *serviceClient) showSettingsUI() {
|
||||||
// add settings window UI elements.
|
// add settings window UI elements.
|
||||||
s.wSettings = s.app.NewWindow("NetBird Settings")
|
s.wSettings = s.app.NewWindow("NetBird Settings")
|
||||||
@ -376,8 +404,10 @@ func (s *serviceClient) login() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceClient) menuUpClick() error {
|
func (s *serviceClient) menuUpClick() error {
|
||||||
|
systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting)
|
||||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
systray.SetTemplateIcon(iconErrorMacOS, s.icError)
|
||||||
log.Errorf("get client: %v", err)
|
log.Errorf("get client: %v", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -403,10 +433,12 @@ func (s *serviceClient) menuUpClick() error {
|
|||||||
log.Errorf("up service: %v", err)
|
log.Errorf("up service: %v", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceClient) menuDownClick() error {
|
func (s *serviceClient) menuDownClick() error {
|
||||||
|
systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting)
|
||||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("get client: %v", err)
|
log.Errorf("get client: %v", err)
|
||||||
@ -458,9 +490,9 @@ func (s *serviceClient) updateStatus() error {
|
|||||||
s.connected = true
|
s.connected = true
|
||||||
s.sendNotification = true
|
s.sendNotification = true
|
||||||
if s.isUpdateIconActive {
|
if s.isUpdateIconActive {
|
||||||
systray.SetIcon(s.icUpdateConnected)
|
systray.SetTemplateIcon(iconUpdateConnectedMacOS, s.icUpdateConnected)
|
||||||
} else {
|
} else {
|
||||||
systray.SetIcon(s.icConnected)
|
systray.SetTemplateIcon(iconConnectedMacOS, s.icConnected)
|
||||||
}
|
}
|
||||||
systray.SetTooltip("NetBird (Connected)")
|
systray.SetTooltip("NetBird (Connected)")
|
||||||
s.mStatus.SetTitle("Connected")
|
s.mStatus.SetTitle("Connected")
|
||||||
@ -482,11 +514,9 @@ func (s *serviceClient) updateStatus() error {
|
|||||||
s.isUpdateIconActive = s.update.SetDaemonVersion(status.DaemonVersion)
|
s.isUpdateIconActive = s.update.SetDaemonVersion(status.DaemonVersion)
|
||||||
if !s.isUpdateIconActive {
|
if !s.isUpdateIconActive {
|
||||||
if systrayIconState {
|
if systrayIconState {
|
||||||
systray.SetIcon(s.icConnected)
|
systray.SetTemplateIcon(iconConnectedMacOS, s.icConnected)
|
||||||
s.mAbout.SetIcon(s.icConnected)
|
|
||||||
} else {
|
} else {
|
||||||
systray.SetIcon(s.icDisconnected)
|
systray.SetTemplateIcon(iconDisconnectedMacOS, s.icDisconnected)
|
||||||
s.mAbout.SetIcon(s.icDisconnected)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -506,7 +536,6 @@ func (s *serviceClient) updateStatus() error {
|
|||||||
Stop: backoff.Stop,
|
Stop: backoff.Stop,
|
||||||
Clock: backoff.SystemClock,
|
Clock: backoff.SystemClock,
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -517,9 +546,9 @@ func (s *serviceClient) updateStatus() error {
|
|||||||
func (s *serviceClient) setDisconnectedStatus() {
|
func (s *serviceClient) setDisconnectedStatus() {
|
||||||
s.connected = false
|
s.connected = false
|
||||||
if s.isUpdateIconActive {
|
if s.isUpdateIconActive {
|
||||||
systray.SetIcon(s.icUpdateDisconnected)
|
systray.SetTemplateIcon(iconUpdateDisconnectedMacOS, s.icUpdateDisconnected)
|
||||||
} else {
|
} else {
|
||||||
systray.SetIcon(s.icDisconnected)
|
systray.SetTemplateIcon(iconDisconnectedMacOS, s.icDisconnected)
|
||||||
}
|
}
|
||||||
systray.SetTooltip("NetBird (Disconnected)")
|
systray.SetTooltip("NetBird (Disconnected)")
|
||||||
s.mStatus.SetTitle("Disconnected")
|
s.mStatus.SetTitle("Disconnected")
|
||||||
@ -529,7 +558,7 @@ func (s *serviceClient) setDisconnectedStatus() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceClient) onTrayReady() {
|
func (s *serviceClient) onTrayReady() {
|
||||||
systray.SetIcon(s.icDisconnected)
|
systray.SetTemplateIcon(iconDisconnectedMacOS, s.icDisconnected)
|
||||||
systray.SetTooltip("NetBird")
|
systray.SetTooltip("NetBird")
|
||||||
|
|
||||||
// setup systray menu items
|
// setup systray menu items
|
||||||
@ -546,6 +575,7 @@ func (s *serviceClient) onTrayReady() {
|
|||||||
s.mAllowSSH = s.mSettings.AddSubMenuItemCheckbox("Allow SSH", "Allow SSH connections", false)
|
s.mAllowSSH = s.mSettings.AddSubMenuItemCheckbox("Allow SSH", "Allow SSH connections", false)
|
||||||
s.mAutoConnect = s.mSettings.AddSubMenuItemCheckbox("Connect on Startup", "Connect automatically when the service starts", false)
|
s.mAutoConnect = s.mSettings.AddSubMenuItemCheckbox("Connect on Startup", "Connect automatically when the service starts", false)
|
||||||
s.mEnableRosenpass = s.mSettings.AddSubMenuItemCheckbox("Enable Quantum-Resistance", "Enable post-quantum security via Rosenpass", false)
|
s.mEnableRosenpass = s.mSettings.AddSubMenuItemCheckbox("Enable Quantum-Resistance", "Enable post-quantum security via Rosenpass", false)
|
||||||
|
s.mNotifications = s.mSettings.AddSubMenuItemCheckbox("Notifications", "Enable notifications", true)
|
||||||
s.mAdvancedSettings = s.mSettings.AddSubMenuItem("Advanced Settings", "Advanced settings of the application")
|
s.mAdvancedSettings = s.mSettings.AddSubMenuItem("Advanced Settings", "Advanced settings of the application")
|
||||||
s.loadSettings()
|
s.loadSettings()
|
||||||
|
|
||||||
@ -554,7 +584,7 @@ func (s *serviceClient) onTrayReady() {
|
|||||||
systray.AddSeparator()
|
systray.AddSeparator()
|
||||||
|
|
||||||
s.mAbout = systray.AddMenuItem("About", "About")
|
s.mAbout = systray.AddMenuItem("About", "About")
|
||||||
s.mAbout.SetIcon(s.icDisconnected)
|
s.mAbout.SetIcon(s.icAbout)
|
||||||
versionString := normalizedVersion(version.NetbirdVersion())
|
versionString := normalizedVersion(version.NetbirdVersion())
|
||||||
s.mVersionUI = s.mAbout.AddSubMenuItem(fmt.Sprintf("GUI: %s", versionString), fmt.Sprintf("GUI Version: %s", versionString))
|
s.mVersionUI = s.mAbout.AddSubMenuItem(fmt.Sprintf("GUI: %s", versionString), fmt.Sprintf("GUI Version: %s", versionString))
|
||||||
s.mVersionUI.Disable()
|
s.mVersionUI.Disable()
|
||||||
@ -582,6 +612,10 @@ func (s *serviceClient) onTrayReady() {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
s.eventManager = event.NewManager(s.app, s.addr)
|
||||||
|
s.eventManager.SetNotificationsEnabled(s.mNotifications.Checked())
|
||||||
|
go s.eventManager.Start(s.ctx)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
var err error
|
var err error
|
||||||
for {
|
for {
|
||||||
@ -616,7 +650,6 @@ func (s *serviceClient) onTrayReady() {
|
|||||||
}
|
}
|
||||||
if err := s.updateConfig(); err != nil {
|
if err := s.updateConfig(); err != nil {
|
||||||
log.Errorf("failed to update config: %v", err)
|
log.Errorf("failed to update config: %v", err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
case <-s.mAutoConnect.ClickedCh:
|
case <-s.mAutoConnect.ClickedCh:
|
||||||
if s.mAutoConnect.Checked() {
|
if s.mAutoConnect.Checked() {
|
||||||
@ -626,7 +659,6 @@ func (s *serviceClient) onTrayReady() {
|
|||||||
}
|
}
|
||||||
if err := s.updateConfig(); err != nil {
|
if err := s.updateConfig(); err != nil {
|
||||||
log.Errorf("failed to update config: %v", err)
|
log.Errorf("failed to update config: %v", err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
case <-s.mEnableRosenpass.ClickedCh:
|
case <-s.mEnableRosenpass.ClickedCh:
|
||||||
if s.mEnableRosenpass.Checked() {
|
if s.mEnableRosenpass.Checked() {
|
||||||
@ -636,7 +668,6 @@ func (s *serviceClient) onTrayReady() {
|
|||||||
}
|
}
|
||||||
if err := s.updateConfig(); err != nil {
|
if err := s.updateConfig(); err != nil {
|
||||||
log.Errorf("failed to update config: %v", err)
|
log.Errorf("failed to update config: %v", err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
case <-s.mAdvancedSettings.ClickedCh:
|
case <-s.mAdvancedSettings.ClickedCh:
|
||||||
s.mAdvancedSettings.Disable()
|
s.mAdvancedSettings.Disable()
|
||||||
@ -659,7 +690,20 @@ func (s *serviceClient) onTrayReady() {
|
|||||||
defer s.mRoutes.Enable()
|
defer s.mRoutes.Enable()
|
||||||
s.runSelfCommand("networks", "true")
|
s.runSelfCommand("networks", "true")
|
||||||
}()
|
}()
|
||||||
|
case <-s.mNotifications.ClickedCh:
|
||||||
|
if s.mNotifications.Checked() {
|
||||||
|
s.mNotifications.Uncheck()
|
||||||
|
} else {
|
||||||
|
s.mNotifications.Check()
|
||||||
}
|
}
|
||||||
|
if s.eventManager != nil {
|
||||||
|
s.eventManager.SetNotificationsEnabled(s.mNotifications.Checked())
|
||||||
|
}
|
||||||
|
if err := s.updateConfig(); err != nil {
|
||||||
|
log.Errorf("failed to update config: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("process connection: %v", err)
|
log.Errorf("process connection: %v", err)
|
||||||
}
|
}
|
||||||
@ -759,8 +803,20 @@ func (s *serviceClient) getSrvConfig() {
|
|||||||
if !cfg.RosenpassEnabled {
|
if !cfg.RosenpassEnabled {
|
||||||
s.sRosenpassPermissive.Disable()
|
s.sRosenpassPermissive.Disable()
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.mNotifications == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if cfg.DisableNotifications {
|
||||||
|
s.mNotifications.Uncheck()
|
||||||
|
} else {
|
||||||
|
s.mNotifications.Check()
|
||||||
|
}
|
||||||
|
if s.eventManager != nil {
|
||||||
|
s.eventManager.SetNotificationsEnabled(s.mNotifications.Checked())
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceClient) onUpdateAvailable() {
|
func (s *serviceClient) onUpdateAvailable() {
|
||||||
@ -771,9 +827,9 @@ func (s *serviceClient) onUpdateAvailable() {
|
|||||||
s.isUpdateIconActive = true
|
s.isUpdateIconActive = true
|
||||||
|
|
||||||
if s.connected {
|
if s.connected {
|
||||||
systray.SetIcon(s.icUpdateConnected)
|
systray.SetTemplateIcon(iconUpdateConnectedMacOS, s.icUpdateConnected)
|
||||||
} else {
|
} else {
|
||||||
systray.SetIcon(s.icUpdateDisconnected)
|
systray.SetTemplateIcon(iconUpdateDisconnectedMacOS, s.icUpdateDisconnected)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -825,6 +881,15 @@ func (s *serviceClient) loadSettings() {
|
|||||||
} else {
|
} else {
|
||||||
s.mEnableRosenpass.Uncheck()
|
s.mEnableRosenpass.Uncheck()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cfg.DisableNotifications {
|
||||||
|
s.mNotifications.Uncheck()
|
||||||
|
} else {
|
||||||
|
s.mNotifications.Check()
|
||||||
|
}
|
||||||
|
if s.eventManager != nil {
|
||||||
|
s.eventManager.SetNotificationsEnabled(s.mNotifications.Checked())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateConfig updates the configuration parameters
|
// updateConfig updates the configuration parameters
|
||||||
@ -833,12 +898,14 @@ func (s *serviceClient) updateConfig() error {
|
|||||||
disableAutoStart := !s.mAutoConnect.Checked()
|
disableAutoStart := !s.mAutoConnect.Checked()
|
||||||
sshAllowed := s.mAllowSSH.Checked()
|
sshAllowed := s.mAllowSSH.Checked()
|
||||||
rosenpassEnabled := s.mEnableRosenpass.Checked()
|
rosenpassEnabled := s.mEnableRosenpass.Checked()
|
||||||
|
notificationsDisabled := !s.mNotifications.Checked()
|
||||||
|
|
||||||
loginRequest := proto.LoginRequest{
|
loginRequest := proto.LoginRequest{
|
||||||
IsLinuxDesktopClient: runtime.GOOS == "linux",
|
IsLinuxDesktopClient: runtime.GOOS == "linux",
|
||||||
ServerSSHAllowed: &sshAllowed,
|
ServerSSHAllowed: &sshAllowed,
|
||||||
RosenpassEnabled: &rosenpassEnabled,
|
RosenpassEnabled: &rosenpassEnabled,
|
||||||
DisableAutoConnect: &disableAutoStart,
|
DisableAutoConnect: &disableAutoStart,
|
||||||
|
DisableNotifications: ¬ificationsDisabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.restartClient(&loginRequest); err != nil {
|
if err := s.restartClient(&loginRequest); err != nil {
|
||||||
@ -851,17 +918,20 @@ func (s *serviceClient) updateConfig() error {
|
|||||||
|
|
||||||
// restartClient restarts the client connection.
|
// restartClient restarts the client connection.
|
||||||
func (s *serviceClient) restartClient(loginRequest *proto.LoginRequest) error {
|
func (s *serviceClient) restartClient(loginRequest *proto.LoginRequest) error {
|
||||||
|
ctx, cancel := context.WithTimeout(s.ctx, defaultFailTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
client, err := s.getSrvClient(failFastTimeout)
|
client, err := s.getSrvClient(failFastTimeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = client.Login(s.ctx, loginRequest)
|
_, err = client.Login(ctx, loginRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = client.Up(s.ctx, &proto.UpRequest{})
|
_, err = client.Up(ctx, &proto.UpRequest{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
151
client/ui/event/event.go
Normal file
@ -0,0 +1,151 @@
|
|||||||
|
package event
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"fyne.io/fyne/v2"
|
||||||
|
"github.com/cenkalti/backoff/v4"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
"github.com/netbirdio/netbird/client/system"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Manager struct {
|
||||||
|
app fyne.App
|
||||||
|
addr string
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
enabled bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManager(app fyne.App, addr string) *Manager {
|
||||||
|
return &Manager{
|
||||||
|
app: app,
|
||||||
|
addr: addr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Manager) Start(ctx context.Context) {
|
||||||
|
e.mu.Lock()
|
||||||
|
e.ctx, e.cancel = context.WithCancel(ctx)
|
||||||
|
e.mu.Unlock()
|
||||||
|
|
||||||
|
expBackOff := backoff.WithContext(&backoff.ExponentialBackOff{
|
||||||
|
InitialInterval: time.Second,
|
||||||
|
RandomizationFactor: backoff.DefaultRandomizationFactor,
|
||||||
|
Multiplier: backoff.DefaultMultiplier,
|
||||||
|
MaxInterval: 10 * time.Second,
|
||||||
|
MaxElapsedTime: 0,
|
||||||
|
Stop: backoff.Stop,
|
||||||
|
Clock: backoff.SystemClock,
|
||||||
|
}, ctx)
|
||||||
|
|
||||||
|
if err := backoff.Retry(e.streamEvents, expBackOff); err != nil {
|
||||||
|
log.Errorf("event stream ended: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Manager) streamEvents() error {
|
||||||
|
e.mu.Lock()
|
||||||
|
ctx := e.ctx
|
||||||
|
e.mu.Unlock()
|
||||||
|
|
||||||
|
client, err := getClient(e.addr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create client: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stream, err := client.SubscribeEvents(ctx, &proto.SubscribeRequest{})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to subscribe to events: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("subscribed to daemon events")
|
||||||
|
defer func() {
|
||||||
|
log.Info("unsubscribed from daemon events")
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
event, err := stream.Recv()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error receiving event: %w", err)
|
||||||
|
}
|
||||||
|
e.handleEvent(event)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Manager) Stop() {
|
||||||
|
e.mu.Lock()
|
||||||
|
defer e.mu.Unlock()
|
||||||
|
if e.cancel != nil {
|
||||||
|
e.cancel()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Manager) SetNotificationsEnabled(enabled bool) {
|
||||||
|
e.mu.Lock()
|
||||||
|
defer e.mu.Unlock()
|
||||||
|
e.enabled = enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Manager) handleEvent(event *proto.SystemEvent) {
|
||||||
|
e.mu.Lock()
|
||||||
|
enabled := e.enabled
|
||||||
|
e.mu.Unlock()
|
||||||
|
|
||||||
|
if !enabled {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
title := e.getEventTitle(event)
|
||||||
|
e.app.SendNotification(fyne.NewNotification(title, event.UserMessage))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Manager) getEventTitle(event *proto.SystemEvent) string {
|
||||||
|
var prefix string
|
||||||
|
switch event.Severity {
|
||||||
|
case proto.SystemEvent_ERROR, proto.SystemEvent_CRITICAL:
|
||||||
|
prefix = "Error"
|
||||||
|
case proto.SystemEvent_WARNING:
|
||||||
|
prefix = "Warning"
|
||||||
|
default:
|
||||||
|
prefix = "Info"
|
||||||
|
}
|
||||||
|
|
||||||
|
var category string
|
||||||
|
switch event.Category {
|
||||||
|
case proto.SystemEvent_DNS:
|
||||||
|
category = "DNS"
|
||||||
|
case proto.SystemEvent_NETWORK:
|
||||||
|
category = "Network"
|
||||||
|
case proto.SystemEvent_AUTHENTICATION:
|
||||||
|
category = "Authentication"
|
||||||
|
case proto.SystemEvent_CONNECTIVITY:
|
||||||
|
category = "Connectivity"
|
||||||
|
default:
|
||||||
|
category = "System"
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s: %s", prefix, category)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getClient(addr string) (proto.DaemonServiceClient, error) {
|
||||||
|
conn, err := grpc.NewClient(
|
||||||
|
strings.TrimPrefix(addr, "tcp://"),
|
||||||
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||||
|
grpc.WithUserAgent(system.GetDesktopUIUserAgent()),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return proto.NewDaemonServiceClient(conn), nil
|
||||||
|
}
|
43
client/ui/icons.go
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
//go:build !(linux && 386) && !windows
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
_ "embed"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed netbird.png
|
||||||
|
var iconAbout []byte
|
||||||
|
|
||||||
|
//go:embed netbird-systemtray-connected.png
|
||||||
|
var iconConnected []byte
|
||||||
|
|
||||||
|
//go:embed netbird-systemtray-connected-dark.png
|
||||||
|
var iconConnectedDark []byte
|
||||||
|
|
||||||
|
//go:embed netbird-systemtray-disconnected.png
|
||||||
|
var iconDisconnected []byte
|
||||||
|
|
||||||
|
//go:embed netbird-systemtray-update-disconnected.png
|
||||||
|
var iconUpdateDisconnected []byte
|
||||||
|
|
||||||
|
//go:embed netbird-systemtray-update-disconnected-dark.png
|
||||||
|
var iconUpdateDisconnectedDark []byte
|
||||||
|
|
||||||
|
//go:embed netbird-systemtray-update-connected.png
|
||||||
|
var iconUpdateConnected []byte
|
||||||
|
|
||||||
|
//go:embed netbird-systemtray-update-connected-dark.png
|
||||||
|
var iconUpdateConnectedDark []byte
|
||||||
|
|
||||||
|
//go:embed netbird-systemtray-connecting.png
|
||||||
|
var iconConnecting []byte
|
||||||
|
|
||||||
|
//go:embed netbird-systemtray-connecting-dark.png
|
||||||
|
var iconConnectingDark []byte
|
||||||
|
|
||||||
|
//go:embed netbird-systemtray-error.png
|
||||||
|
var iconError []byte
|
||||||
|
|
||||||
|
//go:embed netbird-systemtray-error-dark.png
|
||||||
|
var iconErrorDark []byte
|
41
client/ui/icons_windows.go
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
_ "embed"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed netbird.ico
|
||||||
|
var iconAbout []byte
|
||||||
|
|
||||||
|
//go:embed netbird-systemtray-connected.ico
|
||||||
|
var iconConnected []byte
|
||||||
|
|
||||||
|
//go:embed netbird-systemtray-connected-dark.ico
|
||||||
|
var iconConnectedDark []byte
|
||||||
|
|
||||||
|
//go:embed netbird-systemtray-disconnected.ico
|
||||||
|
var iconDisconnected []byte
|
||||||
|
|
||||||
|
//go:embed netbird-systemtray-update-disconnected.ico
|
||||||
|
var iconUpdateDisconnected []byte
|
||||||
|
|
||||||
|
//go:embed netbird-systemtray-update-disconnected-dark.ico
|
||||||
|
var iconUpdateDisconnectedDark []byte
|
||||||
|
|
||||||
|
//go:embed netbird-systemtray-update-connected.ico
|
||||||
|
var iconUpdateConnected []byte
|
||||||
|
|
||||||
|
//go:embed netbird-systemtray-update-connected-dark.ico
|
||||||
|
var iconUpdateConnectedDark []byte
|
||||||
|
|
||||||
|
//go:embed netbird-systemtray-connecting.ico
|
||||||
|
var iconConnecting []byte
|
||||||
|
|
||||||
|
//go:embed netbird-systemtray-connecting-dark.ico
|
||||||
|
var iconConnectingDark []byte
|
||||||
|
|
||||||
|
//go:embed netbird-systemtray-error.ico
|
||||||
|
var iconError []byte
|
||||||
|
|
||||||
|
//go:embed netbird-systemtray-error-dark.ico
|
||||||
|
var iconErrorDark []byte
|
BIN
client/ui/netbird-systemtray-connected-dark.ico
Normal file
After Width: | Height: | Size: 103 KiB |
BIN
client/ui/netbird-systemtray-connected-dark.png
Normal file
After Width: | Height: | Size: 5.1 KiB |
BIN
client/ui/netbird-systemtray-connected-macos.png
Normal file
After Width: | Height: | Size: 3.8 KiB |
Before Width: | Height: | Size: 5.0 KiB After Width: | Height: | Size: 103 KiB |
Before Width: | Height: | Size: 8.9 KiB After Width: | Height: | Size: 5.2 KiB |