mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-22 05:49:12 +01:00
Merge branch 'main' into refactor-get-account-by-token
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
commit
901d283114
9
.github/workflows/golang-test-linux.yml
vendored
9
.github/workflows/golang-test-linux.yml
vendored
@ -49,7 +49,7 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./...
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./...
|
||||
|
||||
test_client_on_docker:
|
||||
runs-on: ubuntu-20.04
|
||||
@ -79,9 +79,6 @@ jobs:
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Generate Iface Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./client/iface/
|
||||
|
||||
- name: Generate Shared Sock Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
|
||||
|
||||
@ -98,7 +95,7 @@ jobs:
|
||||
run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal
|
||||
|
||||
- name: Generate Peer Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/...
|
||||
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/
|
||||
|
||||
- run: chmod +x *testing.bin
|
||||
|
||||
@ -106,7 +103,7 @@ jobs:
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run Iface tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/iface --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/iface-testing.bin -test.timeout 5m -test.parallel 1
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/netbird -v /tmp/cache:/tmp/cache -v /tmp/modcache:/tmp/modcache -w /netbird -e GOCACHE=/tmp/cache -e GOMODCACHE=/tmp/modcache -e CGO_ENABLED=0 golang:1.23-alpine go test -test.timeout 5m -test.parallel 1 ./client/iface/...
|
||||
|
||||
- name: Run RouteManager tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@ -19,7 +19,7 @@ jobs:
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd
|
||||
skip: go.mod,go.sum
|
||||
only_warn: 1
|
||||
golangci:
|
||||
|
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@ -9,7 +9,7 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.15"
|
||||
SIGN_PIPE_VER: "v0.0.16"
|
||||
GORELEASER_VER: "v2.3.2"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
|
||||
|
@ -680,7 +680,7 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
|
||||
func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
|
||||
statusEval := false
|
||||
ipEval := false
|
||||
nameEval := false
|
||||
nameEval := true
|
||||
|
||||
if statusFilter != "" {
|
||||
lowerStatusFilter := strings.ToLower(statusFilter)
|
||||
@ -700,11 +700,13 @@ func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
|
||||
|
||||
if len(prefixNamesFilter) > 0 {
|
||||
for prefixNameFilter := range prefixNamesFilterMap {
|
||||
if !strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
|
||||
nameEval = true
|
||||
if strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
|
||||
nameEval = false
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
nameEval = false
|
||||
}
|
||||
|
||||
return statusEval || ipEval || nameEval
|
||||
|
@ -3,7 +3,6 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
|
||||
@ -11,10 +10,11 @@ import (
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
// NewFirewall creates a firewall manager instance
|
||||
func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) {
|
||||
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) {
|
||||
if !iface.IsUserspaceBind() {
|
||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||
}
|
||||
|
@ -3,7 +3,7 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
@ -15,6 +15,7 @@ import (
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -32,54 +33,65 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
||||
// FWType is the type for the firewall type
|
||||
type FWType int
|
||||
|
||||
func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) {
|
||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
|
||||
// on the linux system we try to user nftables or iptables
|
||||
// in any case, because we need to allow netbird interface traffic
|
||||
// so we use AllowNetbird traffic from these firewall managers
|
||||
// for the userspace packet filtering firewall
|
||||
var fm firewall.Manager
|
||||
var errFw error
|
||||
fm, err := createNativeFirewall(iface, stateManager)
|
||||
|
||||
if !iface.IsUserspaceBind() {
|
||||
return fm, err
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||
}
|
||||
return createUserspaceFirewall(iface, fm)
|
||||
}
|
||||
|
||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
|
||||
fm, err := createFW(iface)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create firewall: %s", err)
|
||||
}
|
||||
|
||||
if err = fm.Init(stateManager); err != nil {
|
||||
return nil, fmt.Errorf("init firewall: %s", err)
|
||||
}
|
||||
|
||||
return fm, nil
|
||||
}
|
||||
|
||||
func createFW(iface IFaceMapper) (firewall.Manager, error) {
|
||||
switch check() {
|
||||
case IPTABLES:
|
||||
log.Info("creating an iptables firewall manager")
|
||||
fm, errFw = nbiptables.Create(context, iface)
|
||||
if errFw != nil {
|
||||
log.Errorf("failed to create iptables manager: %s", errFw)
|
||||
}
|
||||
return nbiptables.Create(iface)
|
||||
case NFTABLES:
|
||||
log.Info("creating an nftables firewall manager")
|
||||
fm, errFw = nbnftables.Create(context, iface)
|
||||
if errFw != nil {
|
||||
log.Errorf("failed to create nftables manager: %s", errFw)
|
||||
}
|
||||
return nbnftables.Create(iface)
|
||||
default:
|
||||
errFw = fmt.Errorf("no firewall manager found")
|
||||
log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
|
||||
return nil, errors.New("no firewall manager found")
|
||||
}
|
||||
}
|
||||
|
||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) {
|
||||
var errUsp error
|
||||
if fm != nil {
|
||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
|
||||
} else {
|
||||
fm, errUsp = uspfilter.Create(iface)
|
||||
}
|
||||
|
||||
if iface.IsUserspaceBind() {
|
||||
var errUsp error
|
||||
if errFw == nil {
|
||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
|
||||
} else {
|
||||
fm, errUsp = uspfilter.Create(iface)
|
||||
}
|
||||
if errUsp != nil {
|
||||
log.Debugf("failed to create userspace filtering firewall: %s", errUsp)
|
||||
return nil, errUsp
|
||||
}
|
||||
|
||||
if err := fm.AllowNetbird(); err != nil {
|
||||
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
||||
}
|
||||
return fm, nil
|
||||
if errUsp != nil {
|
||||
return nil, fmt.Errorf("create userspace firewall: %s", errUsp)
|
||||
}
|
||||
|
||||
if errFw != nil {
|
||||
return nil, errFw
|
||||
if err := fm.AllowNetbird(); err != nil {
|
||||
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
||||
}
|
||||
|
||||
return fm, nil
|
||||
}
|
||||
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
@ -22,6 +23,8 @@ const (
|
||||
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
|
||||
)
|
||||
|
||||
type aclEntries map[string][][]string
|
||||
|
||||
type entry struct {
|
||||
spec []string
|
||||
position int
|
||||
@ -32,9 +35,11 @@ type aclManager struct {
|
||||
wgIface iFaceMapper
|
||||
routingFwChainName string
|
||||
|
||||
entries map[string][][]string
|
||||
entries aclEntries
|
||||
optionalEntries map[string][]entry
|
||||
ipsetStore *ipsetStore
|
||||
|
||||
stateManager *statemanager.Manager
|
||||
}
|
||||
|
||||
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
|
||||
@ -48,24 +53,30 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi
|
||||
ipsetStore: newIpsetStore(),
|
||||
}
|
||||
|
||||
err := ipset.Init()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to init ipset: %w", err)
|
||||
if err := ipset.Init(); err != nil {
|
||||
return nil, fmt.Errorf("init ipset: %w", err)
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *aclManager) init(stateManager *statemanager.Manager) error {
|
||||
m.stateManager = stateManager
|
||||
|
||||
m.seedInitialEntries()
|
||||
m.seedInitialOptionalEntries()
|
||||
|
||||
err = m.cleanChains()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if err := m.cleanChains(); err != nil {
|
||||
return fmt.Errorf("clean chains: %w", err)
|
||||
}
|
||||
|
||||
err = m.createDefaultChains()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if err := m.createDefaultChains(); err != nil {
|
||||
return fmt.Errorf("create default chains: %w", err)
|
||||
}
|
||||
return m, nil
|
||||
|
||||
m.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) AddPeerFiltering(
|
||||
@ -146,6 +157,8 @@ func (m *aclManager) AddPeerFiltering(
|
||||
chain: chain,
|
||||
}
|
||||
|
||||
m.updateState()
|
||||
|
||||
return []firewall.Rule{rule}, nil
|
||||
}
|
||||
|
||||
@ -180,15 +193,23 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
}
|
||||
}
|
||||
|
||||
err := m.iptablesClient.Delete(tableName, r.chain, r.specs...)
|
||||
if err != nil {
|
||||
log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err)
|
||||
if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
|
||||
return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err)
|
||||
}
|
||||
return err
|
||||
|
||||
m.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) Reset() error {
|
||||
return m.cleanChains()
|
||||
if err := m.cleanChains(); err != nil {
|
||||
return fmt.Errorf("clean chains: %w", err)
|
||||
}
|
||||
|
||||
m.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// todo write less destructive cleanup mechanism
|
||||
@ -348,6 +369,32 @@ func (m *aclManager) appendToEntries(chainName string, spec []string) {
|
||||
m.entries[chainName] = append(m.entries[chainName], spec)
|
||||
}
|
||||
|
||||
func (m *aclManager) updateState() {
|
||||
if m.stateManager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var currentState *ShutdownState
|
||||
if existing := m.stateManager.GetState(currentState); existing != nil {
|
||||
if existingState, ok := existing.(*ShutdownState); ok {
|
||||
currentState = existingState
|
||||
}
|
||||
}
|
||||
if currentState == nil {
|
||||
currentState = &ShutdownState{}
|
||||
}
|
||||
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
currentState.ACLEntries = m.entries
|
||||
currentState.ACLIPsetStore = m.ipsetStore
|
||||
|
||||
if err := m.stateManager.UpdateState(currentState); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// filterRuleSpecs returns the specs of a filtering rule
|
||||
func filterRuleSpecs(
|
||||
ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string,
|
||||
|
@ -8,10 +8,13 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
// Manager of iptables firewall
|
||||
@ -33,10 +36,10 @@ type iFaceMapper interface {
|
||||
}
|
||||
|
||||
// Create iptables firewall manager
|
||||
func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
||||
func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("iptables is not installed in the system or not supported")
|
||||
return nil, fmt.Errorf("init iptables: %w", err)
|
||||
}
|
||||
|
||||
m := &Manager{
|
||||
@ -44,20 +47,49 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
||||
ipv4Client: iptablesClient,
|
||||
}
|
||||
|
||||
m.router, err = newRouter(context, iptablesClient, wgIface)
|
||||
m.router, err = newRouter(iptablesClient, wgIface)
|
||||
if err != nil {
|
||||
log.Debugf("failed to initialize route related chains: %s", err)
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("create router: %w", err)
|
||||
}
|
||||
|
||||
m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD)
|
||||
if err != nil {
|
||||
log.Debugf("failed to initialize ACL manager: %s", err)
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("create acl manager: %w", err)
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
state := &ShutdownState{
|
||||
InterfaceState: &InterfaceState{
|
||||
NameStr: m.wgIface.Name(),
|
||||
WGAddress: m.wgIface.Address(),
|
||||
UserspaceBind: m.wgIface.IsUserspaceBind(),
|
||||
},
|
||||
}
|
||||
stateManager.RegisterState(state)
|
||||
if err := stateManager.UpdateState(state); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
}
|
||||
|
||||
if err := m.router.init(stateManager); err != nil {
|
||||
return fmt.Errorf("router init: %w", err)
|
||||
}
|
||||
|
||||
if err := m.aclMgr.init(stateManager); err != nil {
|
||||
// TODO: cleanup router
|
||||
return fmt.Errorf("acl manager init: %w", err)
|
||||
}
|
||||
|
||||
// persist early to ensure cleanup of chains
|
||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddPeerFiltering adds a rule to the firewall
|
||||
//
|
||||
// Comment will be ignored because some system this feature is not supported
|
||||
@ -133,20 +165,27 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset() error {
|
||||
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
errAcl := m.aclMgr.Reset()
|
||||
if errAcl != nil {
|
||||
log.Errorf("failed to clean up ACL rules from firewall: %s", errAcl)
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := m.aclMgr.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
|
||||
}
|
||||
errMgr := m.router.Reset()
|
||||
if errMgr != nil {
|
||||
log.Errorf("failed to clean up router rules from firewall: %s", errMgr)
|
||||
return errMgr
|
||||
if err := m.router.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
|
||||
}
|
||||
return errAcl
|
||||
|
||||
// attempt to delete state only if all other operations succeeded
|
||||
if merr == nil {
|
||||
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete state: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
|
@ -1,7 +1,6 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
@ -56,13 +55,14 @@ func TestIptablesManager(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(context.Background(), ifaceMock)
|
||||
manager, err := Create(ifaceMock)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
time.Sleep(time.Second)
|
||||
|
||||
defer func() {
|
||||
err := manager.Reset()
|
||||
err := manager.Reset(nil)
|
||||
require.NoError(t, err, "clear the manager state")
|
||||
|
||||
time.Sleep(time.Second)
|
||||
@ -122,7 +122,7 @@ func TestIptablesManager(t *testing.T) {
|
||||
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Reset()
|
||||
err = manager.Reset(nil)
|
||||
require.NoError(t, err, "failed to reset")
|
||||
|
||||
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
|
||||
@ -154,13 +154,14 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
}
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(context.Background(), mock)
|
||||
manager, err := Create(mock)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
time.Sleep(time.Second)
|
||||
|
||||
defer func() {
|
||||
err := manager.Reset()
|
||||
err := manager.Reset(nil)
|
||||
require.NoError(t, err, "clear the manager state")
|
||||
|
||||
time.Sleep(time.Second)
|
||||
@ -219,7 +220,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("reset check", func(t *testing.T) {
|
||||
err = manager.Reset()
|
||||
err = manager.Reset(nil)
|
||||
require.NoError(t, err, "failed to reset")
|
||||
})
|
||||
}
|
||||
@ -251,12 +252,13 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||
// just check on the local interface
|
||||
manager, err := Create(context.Background(), mock)
|
||||
manager, err := Create(mock)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
time.Sleep(time.Second)
|
||||
|
||||
defer func() {
|
||||
err := manager.Reset()
|
||||
err := manager.Reset(nil)
|
||||
require.NoError(t, err, "clear the manager state")
|
||||
|
||||
time.Sleep(time.Second)
|
||||
|
@ -3,7 +3,6 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
@ -18,6 +17,7 @@ import (
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -48,28 +48,31 @@ type routeFilteringRuleParams struct {
|
||||
SetName string
|
||||
}
|
||||
|
||||
type routeRules map[string][]string
|
||||
|
||||
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
|
||||
|
||||
type router struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
iptablesClient *iptables.IPTables
|
||||
rules map[string][]string
|
||||
ipsetCounter *refcounter.Counter[string, []netip.Prefix, struct{}]
|
||||
rules routeRules
|
||||
ipsetCounter *ipsetCounter
|
||||
wgIface iFaceMapper
|
||||
legacyManagement bool
|
||||
|
||||
stateManager *statemanager.Manager
|
||||
}
|
||||
|
||||
func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
|
||||
r := &router{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
iptablesClient: iptablesClient,
|
||||
rules: make(map[string][]string),
|
||||
wgIface: wgIface,
|
||||
}
|
||||
|
||||
r.ipsetCounter = refcounter.New(
|
||||
r.createIpSet,
|
||||
func(name string, sources []netip.Prefix) (struct{}, error) {
|
||||
return struct{}{}, r.createIpSet(name, sources)
|
||||
},
|
||||
func(name string, _ struct{}) error {
|
||||
return r.deleteIpSet(name)
|
||||
},
|
||||
@ -79,16 +82,23 @@ func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgI
|
||||
return nil, fmt.Errorf("init ipset: %w", err)
|
||||
}
|
||||
|
||||
err := r.cleanUpDefaultForwardRules()
|
||||
if err != nil {
|
||||
log.Errorf("cleanup routing rules: %s", err)
|
||||
return nil, err
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (r *router) init(stateManager *statemanager.Manager) error {
|
||||
r.stateManager = stateManager
|
||||
|
||||
if err := r.cleanUpDefaultForwardRules(); err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
err = r.createContainers()
|
||||
if err != nil {
|
||||
log.Errorf("create containers for route: %s", err)
|
||||
|
||||
if err := r.createContainers(); err != nil {
|
||||
return fmt.Errorf("create containers: %w", err)
|
||||
}
|
||||
return r, err
|
||||
|
||||
r.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) AddRouteFiltering(
|
||||
@ -129,6 +139,8 @@ func (r *router) AddRouteFiltering(
|
||||
|
||||
r.rules[string(ruleKey)] = rule
|
||||
|
||||
r.updateState()
|
||||
|
||||
return ruleKey, nil
|
||||
}
|
||||
|
||||
@ -152,6 +164,8 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
log.Debugf("route rule %s not found", ruleKey)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -164,18 +178,18 @@ func (r *router) findSetNameInRule(rule []string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (r *router) createIpSet(setName string, sources []netip.Prefix) (struct{}, error) {
|
||||
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
|
||||
if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
|
||||
return struct{}{}, fmt.Errorf("create set %s: %w", setName, err)
|
||||
return fmt.Errorf("create set %s: %w", setName, err)
|
||||
}
|
||||
|
||||
for _, prefix := range sources {
|
||||
if err := ipset.AddPrefix(setName, prefix); err != nil {
|
||||
return struct{}{}, fmt.Errorf("add element to set %s: %w", setName, err)
|
||||
return fmt.Errorf("add element to set %s: %w", setName, err)
|
||||
}
|
||||
}
|
||||
|
||||
return struct{}{}, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) deleteIpSet(setName string) error {
|
||||
@ -206,6 +220,8 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
return fmt.Errorf("add inverse nat rule: %w", err)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -223,6 +239,8 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
return fmt.Errorf("remove legacy routing rule: %w", err)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -278,8 +296,13 @@ func (r *router) RemoveAllLegacyRouteRules() error {
|
||||
}
|
||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
|
||||
} else {
|
||||
delete(r.rules, k)
|
||||
}
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
@ -294,6 +317,8 @@ func (r *router) Reset() error {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
@ -431,6 +456,32 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) updateState() {
|
||||
if r.stateManager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var currentState *ShutdownState
|
||||
if existing := r.stateManager.GetState(currentState); existing != nil {
|
||||
if existingState, ok := existing.(*ShutdownState); ok {
|
||||
currentState = existingState
|
||||
}
|
||||
}
|
||||
if currentState == nil {
|
||||
currentState = &ShutdownState{}
|
||||
}
|
||||
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
currentState.RouteRules = r.rules
|
||||
currentState.RouteIPsetCounter = r.ipsetCounter
|
||||
|
||||
if err := r.stateManager.UpdateState(currentState); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string {
|
||||
intdir := "-i"
|
||||
lointdir := "-o"
|
||||
|
@ -3,7 +3,6 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"testing"
|
||||
@ -30,8 +29,9 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "failed to init iptables client")
|
||||
|
||||
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
|
||||
manager, err := newRouter(iptablesClient, ifaceMock)
|
||||
require.NoError(t, err, "should return a valid iptables manager")
|
||||
require.NoError(t, manager.init(nil))
|
||||
|
||||
defer func() {
|
||||
_ = manager.Reset()
|
||||
@ -74,8 +74,9 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "failed to init iptables client")
|
||||
|
||||
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
|
||||
manager, err := newRouter(iptablesClient, ifaceMock)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
require.NoError(t, manager.init(nil))
|
||||
|
||||
defer func() {
|
||||
err := manager.Reset()
|
||||
@ -132,8 +133,9 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
|
||||
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
|
||||
manager, err := newRouter(iptablesClient, ifaceMock)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
require.NoError(t, manager.init(nil))
|
||||
defer func() {
|
||||
_ = manager.Reset()
|
||||
}()
|
||||
@ -183,8 +185,9 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "Failed to create iptables client")
|
||||
|
||||
r, err := newRouter(context.Background(), iptablesClient, ifaceMock)
|
||||
r, err := newRouter(iptablesClient, ifaceMock)
|
||||
require.NoError(t, err, "Failed to create router manager")
|
||||
require.NoError(t, r.init(nil))
|
||||
|
||||
defer func() {
|
||||
err := r.Reset()
|
||||
|
@ -1,14 +1,16 @@
|
||||
package iptables
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type ipList struct {
|
||||
ips map[string]struct{}
|
||||
}
|
||||
|
||||
func newIpList(ip string) ipList {
|
||||
func newIpList(ip string) *ipList {
|
||||
ips := make(map[string]struct{})
|
||||
ips[ip] = struct{}{}
|
||||
|
||||
return ipList{
|
||||
return &ipList{
|
||||
ips: ips,
|
||||
}
|
||||
}
|
||||
@ -17,27 +19,47 @@ func (s *ipList) addIP(ip string) {
|
||||
s.ips[ip] = struct{}{}
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler
|
||||
func (s *ipList) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
IPs map[string]struct{} `json:"ips"`
|
||||
}{
|
||||
IPs: s.ips,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler
|
||||
func (s *ipList) UnmarshalJSON(data []byte) error {
|
||||
temp := struct {
|
||||
IPs map[string]struct{} `json:"ips"`
|
||||
}{}
|
||||
if err := json.Unmarshal(data, &temp); err != nil {
|
||||
return err
|
||||
}
|
||||
s.ips = temp.IPs
|
||||
return nil
|
||||
}
|
||||
|
||||
type ipsetStore struct {
|
||||
ipsets map[string]ipList // ipsetName -> ruleset
|
||||
ipsets map[string]*ipList
|
||||
}
|
||||
|
||||
func newIpsetStore() *ipsetStore {
|
||||
return &ipsetStore{
|
||||
ipsets: make(map[string]ipList),
|
||||
ipsets: make(map[string]*ipList),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ipsetStore) ipset(ipsetName string) (ipList, bool) {
|
||||
func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) {
|
||||
r, ok := s.ipsets[ipsetName]
|
||||
return r, ok
|
||||
}
|
||||
|
||||
func (s *ipsetStore) addIpList(ipsetName string, list ipList) {
|
||||
func (s *ipsetStore) addIpList(ipsetName string, list *ipList) {
|
||||
s.ipsets[ipsetName] = list
|
||||
}
|
||||
|
||||
func (s *ipsetStore) deleteIpset(ipsetName string) {
|
||||
s.ipsets[ipsetName] = ipList{}
|
||||
delete(s.ipsets, ipsetName)
|
||||
}
|
||||
|
||||
@ -48,3 +70,24 @@ func (s *ipsetStore) ipsetNames() []string {
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler
|
||||
func (s *ipsetStore) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
IPSets map[string]*ipList `json:"ipsets"`
|
||||
}{
|
||||
IPSets: s.ipsets,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler
|
||||
func (s *ipsetStore) UnmarshalJSON(data []byte) error {
|
||||
temp := struct {
|
||||
IPSets map[string]*ipList `json:"ipsets"`
|
||||
}{}
|
||||
if err := json.Unmarshal(data, &temp); err != nil {
|
||||
return err
|
||||
}
|
||||
s.ipsets = temp.IPSets
|
||||
return nil
|
||||
}
|
||||
|
70
client/firewall/iptables/state_linux.go
Normal file
70
client/firewall/iptables/state_linux.go
Normal file
@ -0,0 +1,70 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
type InterfaceState struct {
|
||||
NameStr string `json:"name"`
|
||||
WGAddress iface.WGAddress `json:"wg_address"`
|
||||
UserspaceBind bool `json:"userspace_bind"`
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Name() string {
|
||||
return i.NameStr
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Address() device.WGAddress {
|
||||
return i.WGAddress
|
||||
}
|
||||
|
||||
func (i *InterfaceState) IsUserspaceBind() bool {
|
||||
return i.UserspaceBind
|
||||
}
|
||||
|
||||
type ShutdownState struct {
|
||||
sync.Mutex
|
||||
|
||||
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
|
||||
|
||||
RouteRules routeRules `json:"route_rules,omitempty"`
|
||||
RouteIPsetCounter *ipsetCounter `json:"route_ipset_counter,omitempty"`
|
||||
|
||||
ACLEntries aclEntries `json:"acl_entries,omitempty"`
|
||||
ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"`
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Name() string {
|
||||
return "iptables_state"
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
ipt, err := Create(s.InterfaceState)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create iptables manager: %w", err)
|
||||
}
|
||||
|
||||
if s.RouteRules != nil {
|
||||
ipt.router.rules = s.RouteRules
|
||||
}
|
||||
if s.RouteIPsetCounter != nil {
|
||||
ipt.router.ipsetCounter.LoadData(s.RouteIPsetCounter)
|
||||
}
|
||||
|
||||
if s.ACLEntries != nil {
|
||||
ipt.aclMgr.entries = s.ACLEntries
|
||||
}
|
||||
if s.ACLIPsetStore != nil {
|
||||
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
|
||||
}
|
||||
|
||||
if err := ipt.Reset(nil); err != nil {
|
||||
return fmt.Errorf("reset iptables manager: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -10,6 +10,8 @@ import (
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -52,6 +54,8 @@ const (
|
||||
// It declares methods which handle actions required by the
|
||||
// Netbird client for ACL and routing functionality
|
||||
type Manager interface {
|
||||
Init(stateManager *statemanager.Manager) error
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
AllowNetbird() error
|
||||
|
||||
@ -91,7 +95,7 @@ type Manager interface {
|
||||
SetLegacyManagement(legacy bool) error
|
||||
|
||||
// Reset firewall to the default state
|
||||
Reset() error
|
||||
Reset(stateManager *statemanager.Manager) error
|
||||
|
||||
// Flush the changes to firewall controller
|
||||
Flush() error
|
||||
|
@ -17,7 +17,6 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
@ -56,13 +55,6 @@ type AclManager struct {
|
||||
rules map[string]*Rule
|
||||
}
|
||||
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMapper interface {
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
IsUserspaceBind() bool
|
||||
}
|
||||
|
||||
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) {
|
||||
// sConn is used for creating sets and adding/removing elements from them
|
||||
// it's differ then rConn (which does create new conn for each flush operation)
|
||||
@ -70,10 +62,10 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam
|
||||
// overloads netlink with high amount of rules ( > 10000)
|
||||
sConn, err := nftables.New(nftables.AsLasting())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("create nf conn: %w", err)
|
||||
}
|
||||
|
||||
m := &AclManager{
|
||||
return &AclManager{
|
||||
rConn: &nftables.Conn{},
|
||||
sConn: sConn,
|
||||
wgIface: wgIface,
|
||||
@ -82,14 +74,12 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam
|
||||
|
||||
ipsetStore: newIpsetStore(),
|
||||
rules: make(map[string]*Rule),
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
err = m.createDefaultChains()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m, nil
|
||||
func (m *AclManager) init(workTable *nftables.Table) error {
|
||||
m.workTable = workTable
|
||||
return m.createDefaultChains()
|
||||
}
|
||||
|
||||
// AddPeerFiltering rule to the firewall
|
||||
|
@ -14,6 +14,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -24,6 +26,13 @@ const (
|
||||
chainNameInput = "INPUT"
|
||||
)
|
||||
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMapper interface {
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
IsUserspaceBind() bool
|
||||
}
|
||||
|
||||
// Manager of iptables firewall
|
||||
type Manager struct {
|
||||
mutex sync.Mutex
|
||||
@ -35,30 +44,68 @@ type Manager struct {
|
||||
}
|
||||
|
||||
// Create nftables firewall manager
|
||||
func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
||||
func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
m := &Manager{
|
||||
rConn: &nftables.Conn{},
|
||||
wgIface: wgIface,
|
||||
}
|
||||
|
||||
workTable, err := m.createWorkTable()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}
|
||||
|
||||
m.router, err = newRouter(context, workTable, wgIface)
|
||||
var err error
|
||||
m.router, err = newRouter(workTable, wgIface)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("create router: %w", err)
|
||||
}
|
||||
|
||||
m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("create acl manager: %w", err)
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Init nftables firewall manager
|
||||
func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
workTable, err := m.createWorkTable()
|
||||
if err != nil {
|
||||
return fmt.Errorf("create work table: %w", err)
|
||||
}
|
||||
|
||||
if err := m.router.init(workTable); err != nil {
|
||||
return fmt.Errorf("router init: %w", err)
|
||||
}
|
||||
|
||||
if err := m.aclManager.init(workTable); err != nil {
|
||||
// TODO: cleanup router
|
||||
return fmt.Errorf("acl manager init: %w", err)
|
||||
}
|
||||
|
||||
stateManager.RegisterState(&ShutdownState{})
|
||||
|
||||
// We only need to record minimal interface state for potential recreation.
|
||||
// Unlike iptables, which requires tracking individual rules, nftables maintains
|
||||
// a known state (our netbird table plus a few static rules). This allows for easy
|
||||
// cleanup using Reset() without needing to store specific rules.
|
||||
if err := stateManager.UpdateState(&ShutdownState{
|
||||
InterfaceState: &InterfaceState{
|
||||
NameStr: m.wgIface.Name(),
|
||||
WGAddress: m.wgIface.Address(),
|
||||
UserspaceBind: m.wgIface.IsUserspaceBind(),
|
||||
},
|
||||
}); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
}
|
||||
|
||||
// persist early
|
||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddPeerFiltering rule to the firewall
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
@ -183,68 +230,84 @@ func (m *Manager) AllowNetbird() error {
|
||||
|
||||
// SetLegacyManagement sets the route manager to use legacy management
|
||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
oldLegacy := m.router.legacyManagement
|
||||
return firewall.SetLegacyManagement(m.router, isLegacy)
|
||||
}
|
||||
|
||||
if oldLegacy != isLegacy {
|
||||
m.router.legacyManagement = isLegacy
|
||||
log.Debugf("Set legacy management to %v", isLegacy)
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if err := m.resetNetbirdInputRules(); err != nil {
|
||||
return fmt.Errorf("reset netbird input rules: %v", err)
|
||||
}
|
||||
|
||||
// client reconnected to a newer mgmt, we need to cleanup the legacy rules
|
||||
if !isLegacy && oldLegacy {
|
||||
if err := m.router.RemoveAllLegacyRouteRules(); err != nil {
|
||||
return fmt.Errorf("remove legacy routing rules: %v", err)
|
||||
}
|
||||
if err := m.router.Reset(); err != nil {
|
||||
return fmt.Errorf("reset router: %v", err)
|
||||
}
|
||||
|
||||
log.Debugf("Legacy routing rules removed")
|
||||
if err := m.cleanupNetbirdTables(); err != nil {
|
||||
return fmt.Errorf("cleanup netbird tables: %v", err)
|
||||
}
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
||||
return fmt.Errorf("delete state: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
func (m *Manager) resetNetbirdInputRules() error {
|
||||
chains, err := m.rConn.ListChains()
|
||||
if err != nil {
|
||||
return fmt.Errorf("list of chains: %w", err)
|
||||
return fmt.Errorf("list chains: %w", err)
|
||||
}
|
||||
|
||||
m.deleteNetbirdInputRules(chains)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
|
||||
for _, c := range chains {
|
||||
// delete Netbird allow input traffic rule if it exists
|
||||
if c.Table.Name == "filter" && c.Name == "INPUT" {
|
||||
rules, err := m.rConn.GetRules(c.Table, c)
|
||||
if err != nil {
|
||||
log.Errorf("get rules for chain %q: %v", c.Name, err)
|
||||
continue
|
||||
}
|
||||
for _, r := range rules {
|
||||
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
|
||||
if err := m.rConn.DelRule(r); err != nil {
|
||||
log.Errorf("delete rule: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
m.deleteMatchingRules(rules)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) deleteMatchingRules(rules []*nftables.Rule) {
|
||||
for _, r := range rules {
|
||||
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
|
||||
if err := m.rConn.DelRule(r); err != nil {
|
||||
log.Errorf("delete rule: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.router.Reset(); err != nil {
|
||||
return fmt.Errorf("reset forward rules: %v", err)
|
||||
}
|
||||
|
||||
func (m *Manager) cleanupNetbirdTables() error {
|
||||
tables, err := m.rConn.ListTables()
|
||||
if err != nil {
|
||||
return fmt.Errorf("list of tables: %w", err)
|
||||
return fmt.Errorf("list tables: %w", err)
|
||||
}
|
||||
|
||||
for _, t := range tables {
|
||||
if t.Name == tableNameNetbird {
|
||||
m.rConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
|
||||
return m.rConn.Flush()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush rule/chain/set operations from the buffer
|
||||
|
@ -1,7 +1,6 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@ -58,12 +57,13 @@ func (i *iFaceMock) IsUserspaceBind() bool { return false }
|
||||
func TestNftablesManager(t *testing.T) {
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(context.Background(), ifaceMock)
|
||||
manager, err := Create(ifaceMock)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
defer func() {
|
||||
err = manager.Reset()
|
||||
err = manager.Reset(nil)
|
||||
require.NoError(t, err, "failed to reset")
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
@ -169,7 +169,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
// established rule remains
|
||||
require.Len(t, rules, 1, "expected 1 rules after deletion")
|
||||
|
||||
err = manager.Reset()
|
||||
err = manager.Reset(nil)
|
||||
require.NoError(t, err, "failed to reset")
|
||||
}
|
||||
|
||||
@ -192,12 +192,13 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||
// just check on the local interface
|
||||
manager, err := Create(context.Background(), mock)
|
||||
manager, err := Create(mock)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
defer func() {
|
||||
if err := manager.Reset(); err != nil {
|
||||
if err := manager.Reset(nil); err != nil {
|
||||
t.Errorf("clear the manager state: %v", err)
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
|
@ -2,7 +2,6 @@ package nftables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
@ -40,8 +39,6 @@ var (
|
||||
)
|
||||
|
||||
type router struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
conn *nftables.Conn
|
||||
workTable *nftables.Table
|
||||
filterTable *nftables.Table
|
||||
@ -54,12 +51,8 @@ type router struct {
|
||||
legacyManagement bool
|
||||
}
|
||||
|
||||
func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
|
||||
func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
|
||||
r := &router{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
conn: &nftables.Conn{},
|
||||
workTable: workTable,
|
||||
chains: make(map[string]*nftables.Chain),
|
||||
@ -78,20 +71,25 @@ func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFa
|
||||
if errors.Is(err, errFilterTableNotFound) {
|
||||
log.Warnf("table 'filter' not found for forward rules")
|
||||
} else {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("load filter table: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = r.removeAcceptForwardRules()
|
||||
if err != nil {
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (r *router) init(workTable *nftables.Table) error {
|
||||
r.workTable = workTable
|
||||
|
||||
if err := r.removeAcceptForwardRules(); err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
|
||||
err = r.createContainers()
|
||||
if err != nil {
|
||||
log.Errorf("failed to create containers for route: %s", err)
|
||||
if err := r.createContainers(); err != nil {
|
||||
return fmt.Errorf("create containers: %w", err)
|
||||
}
|
||||
return r, err
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset cleans existing nftables default forward rules from the system
|
||||
@ -553,7 +551,10 @@ func (r *router) RemoveAllLegacyRouteRules() error {
|
||||
}
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
|
||||
} else {
|
||||
delete(r.rules, k)
|
||||
}
|
||||
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
@ -3,7 +3,6 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
@ -40,8 +39,9 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
|
||||
for _, testCase := range test.InsertRuleTestCases {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
manager, err := newRouter(context.TODO(), table, ifaceMock)
|
||||
manager, err := newRouter(table, ifaceMock)
|
||||
require.NoError(t, err, "failed to create router")
|
||||
require.NoError(t, manager.init(table))
|
||||
|
||||
nftablesTestingClient := &nftables.Conn{}
|
||||
|
||||
@ -142,8 +142,9 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
|
||||
for _, testCase := range test.RemoveRuleTestCases {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
manager, err := newRouter(context.TODO(), table, ifaceMock)
|
||||
manager, err := newRouter(table, ifaceMock)
|
||||
require.NoError(t, err, "failed to create router")
|
||||
require.NoError(t, manager.init(table))
|
||||
|
||||
nftablesTestingClient := &nftables.Conn{}
|
||||
|
||||
@ -210,8 +211,9 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newRouter(context.Background(), workTable, ifaceMock)
|
||||
r, err := newRouter(workTable, ifaceMock)
|
||||
require.NoError(t, err, "Failed to create router")
|
||||
require.NoError(t, r.init(workTable))
|
||||
|
||||
defer func(r *router) {
|
||||
require.NoError(t, r.Reset(), "Failed to reset rules")
|
||||
@ -376,8 +378,9 @@ func TestNftablesCreateIpSet(t *testing.T) {
|
||||
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newRouter(context.Background(), workTable, ifaceMock)
|
||||
r, err := newRouter(workTable, ifaceMock)
|
||||
require.NoError(t, err, "Failed to create router")
|
||||
require.NoError(t, r.init(workTable))
|
||||
|
||||
defer func() {
|
||||
require.NoError(t, r.Reset(), "Failed to reset router")
|
||||
|
1
client/firewall/nftables/state.go
Normal file
1
client/firewall/nftables/state.go
Normal file
@ -0,0 +1 @@
|
||||
package nftables
|
47
client/firewall/nftables/state_linux.go
Normal file
47
client/firewall/nftables/state_linux.go
Normal file
@ -0,0 +1,47 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
type InterfaceState struct {
|
||||
NameStr string `json:"name"`
|
||||
WGAddress iface.WGAddress `json:"wg_address"`
|
||||
UserspaceBind bool `json:"userspace_bind"`
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Name() string {
|
||||
return i.NameStr
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Address() device.WGAddress {
|
||||
return i.WGAddress
|
||||
}
|
||||
|
||||
func (i *InterfaceState) IsUserspaceBind() bool {
|
||||
return i.UserspaceBind
|
||||
}
|
||||
|
||||
type ShutdownState struct {
|
||||
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Name() string {
|
||||
return "nftables_state"
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
nft, err := Create(s.InterfaceState)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create nftables manager: %w", err)
|
||||
}
|
||||
|
||||
if err := nft.Reset(nil); err != nil {
|
||||
return fmt.Errorf("reset nftables manager: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -2,8 +2,10 @@
|
||||
|
||||
package uspfilter
|
||||
|
||||
import "github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset() error {
|
||||
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@ -11,7 +13,7 @@ func (m *Manager) Reset() error {
|
||||
m.incomingRules = make(map[string]RuleSet)
|
||||
|
||||
if m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.Reset()
|
||||
return m.nativeFirewall.Reset(stateManager)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -6,6 +6,8 @@ import (
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
type action string
|
||||
@ -17,7 +19,7 @@ const (
|
||||
)
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset() error {
|
||||
func (m *Manager) Reset(*statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
|
@ -14,6 +14,7 @@ import (
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
const layerTypeAll = 0
|
||||
@ -97,6 +98,10 @@ func create(iface IFaceMapper) (*Manager, error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Manager) Init(*statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) IsServerRouteSupported() bool {
|
||||
if m.nativeFirewall == nil {
|
||||
return false
|
||||
@ -190,7 +195,7 @@ func (m *Manager) AddPeerFiltering(
|
||||
return []firewall.Rule{&r}, nil
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(sources [] netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action ) (firewall.Rule, error) {
|
||||
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil, errRouteNotSupported
|
||||
}
|
||||
@ -232,8 +237,11 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
}
|
||||
|
||||
// SetLegacyManagement doesn't need to be implemented for this manager
|
||||
func (m *Manager) SetLegacyManagement(_ bool) error {
|
||||
return nil
|
||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errRouteNotSupported
|
||||
}
|
||||
return m.nativeFirewall.SetLegacyManagement(isLegacy)
|
||||
}
|
||||
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
|
@ -259,7 +259,7 @@ func TestManagerReset(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
err = m.Reset()
|
||||
err = m.Reset(nil)
|
||||
if err != nil {
|
||||
t.Errorf("failed to reset Manager: %v", err)
|
||||
return
|
||||
@ -330,7 +330,7 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if err = m.Reset(); err != nil {
|
||||
if err = m.Reset(nil); err != nil {
|
||||
t.Errorf("failed to reset Manager: %v", err)
|
||||
return
|
||||
}
|
||||
@ -396,7 +396,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
|
||||
defer func() {
|
||||
if err := manager.Reset(); err != nil {
|
||||
if err := manager.Reset(nil); err != nil {
|
||||
t.Errorf("clear the manager state: %v", err)
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
|
@ -1,142 +0,0 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"github.com/pion/stun/v2"
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/ipv4"
|
||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
|
||||
type receiverCreator struct {
|
||||
iceBind *ICEBind
|
||||
}
|
||||
|
||||
func (rc receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
|
||||
return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn)
|
||||
}
|
||||
|
||||
type ICEBind struct {
|
||||
*wgConn.StdNetBind
|
||||
|
||||
muUDPMux sync.Mutex
|
||||
|
||||
transportNet transport.Net
|
||||
udpMux *UniversalUDPMuxDefault
|
||||
|
||||
filterFn FilterFn
|
||||
}
|
||||
|
||||
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
|
||||
ib := &ICEBind{
|
||||
transportNet: transportNet,
|
||||
filterFn: filterFn,
|
||||
}
|
||||
|
||||
rc := receiverCreator{
|
||||
ib,
|
||||
}
|
||||
ib.StdNetBind = wgConn.NewStdNetBindWithReceiverCreator(rc)
|
||||
return ib
|
||||
}
|
||||
|
||||
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
|
||||
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
|
||||
s.muUDPMux.Lock()
|
||||
defer s.muUDPMux.Unlock()
|
||||
if s.udpMux == nil {
|
||||
return nil, fmt.Errorf("ICEBind has not been initialized yet")
|
||||
}
|
||||
|
||||
return s.udpMux, nil
|
||||
}
|
||||
|
||||
func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
|
||||
s.muUDPMux.Lock()
|
||||
defer s.muUDPMux.Unlock()
|
||||
|
||||
s.udpMux = NewUniversalUDPMuxDefault(
|
||||
UniversalUDPMuxParams{
|
||||
UDPConn: conn,
|
||||
Net: s.transportNet,
|
||||
FilterFn: s.filterFn,
|
||||
},
|
||||
)
|
||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||
msgs := ipv4MsgsPool.Get().(*[]ipv4.Message)
|
||||
defer ipv4MsgsPool.Put(msgs)
|
||||
for i := range bufs {
|
||||
(*msgs)[i].Buffers[0] = bufs[i]
|
||||
}
|
||||
var numMsgs int
|
||||
if runtime.GOOS == "linux" {
|
||||
numMsgs, err = pc.ReadBatch(*msgs, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else {
|
||||
msg := &(*msgs)[0]
|
||||
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
numMsgs = 1
|
||||
}
|
||||
for i := 0; i < numMsgs; i++ {
|
||||
msg := &(*msgs)[i]
|
||||
|
||||
// todo: handle err
|
||||
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
|
||||
if ok {
|
||||
sizes[i] = 0
|
||||
} else {
|
||||
sizes[i] = msg.N
|
||||
}
|
||||
|
||||
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
||||
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
||||
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
|
||||
eps[i] = ep
|
||||
}
|
||||
return numMsgs, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
|
||||
for i := range buffers {
|
||||
if !stun.IsMessage(buffers[i]) {
|
||||
continue
|
||||
}
|
||||
|
||||
msg, err := s.parseSTUNMessage(buffers[i][:n])
|
||||
if err != nil {
|
||||
buffers[i] = []byte{}
|
||||
return true, err
|
||||
}
|
||||
|
||||
muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
|
||||
if muxErr != nil {
|
||||
log.Warnf("failed to handle STUN packet")
|
||||
}
|
||||
|
||||
buffers[i] = []byte{}
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (s *ICEBind) parseSTUNMessage(raw []byte) (*stun.Message, error) {
|
||||
msg := &stun.Message{
|
||||
Raw: raw,
|
||||
}
|
||||
if err := msg.Decode(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
5
client/iface/bind/endpoint.go
Normal file
5
client/iface/bind/endpoint.go
Normal file
@ -0,0 +1,5 @@
|
||||
package bind
|
||||
|
||||
import wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
type Endpoint = wgConn.StdNetEndpoint
|
275
client/iface/bind/ice_bind.go
Normal file
275
client/iface/bind/ice_bind.go
Normal file
@ -0,0 +1,275 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/pion/stun/v2"
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/ipv4"
|
||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
|
||||
type RecvMessage struct {
|
||||
Endpoint *Endpoint
|
||||
Buffer []byte
|
||||
}
|
||||
|
||||
type receiverCreator struct {
|
||||
iceBind *ICEBind
|
||||
}
|
||||
|
||||
func (rc receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
|
||||
return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn)
|
||||
}
|
||||
|
||||
// ICEBind is a bind implementation with two main features:
|
||||
// 1. filter out STUN messages and handle them
|
||||
// 2. forward the received packets to the WireGuard interface from the relayed connection
|
||||
//
|
||||
// ICEBind.endpoints var is a map that stores the connection for each relayed peer. Fake address is just an IP address
|
||||
// without port, in the format of 127.1.x.x where x.x is the last two octets of the peer address. We try to avoid to
|
||||
// use the port because in the Send function the wgConn.Endpoint the port info is not exported.
|
||||
type ICEBind struct {
|
||||
*wgConn.StdNetBind
|
||||
RecvChan chan RecvMessage
|
||||
|
||||
transportNet transport.Net
|
||||
filterFn FilterFn
|
||||
endpoints map[netip.Addr]net.Conn
|
||||
endpointsMu sync.Mutex
|
||||
// every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a
|
||||
// new closed channel. With the closedChanMu we can safely close the channel and create a new one
|
||||
closedChan chan struct{}
|
||||
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
|
||||
closed bool
|
||||
|
||||
muUDPMux sync.Mutex
|
||||
udpMux *UniversalUDPMuxDefault
|
||||
}
|
||||
|
||||
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
|
||||
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
||||
ib := &ICEBind{
|
||||
StdNetBind: b,
|
||||
RecvChan: make(chan RecvMessage, 1),
|
||||
transportNet: transportNet,
|
||||
filterFn: filterFn,
|
||||
endpoints: make(map[netip.Addr]net.Conn),
|
||||
closedChan: make(chan struct{}),
|
||||
closed: true,
|
||||
}
|
||||
|
||||
rc := receiverCreator{
|
||||
ib,
|
||||
}
|
||||
ib.StdNetBind = wgConn.NewStdNetBindWithReceiverCreator(rc)
|
||||
return ib
|
||||
}
|
||||
|
||||
func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
|
||||
s.closed = false
|
||||
s.closedChanMu.Lock()
|
||||
s.closedChan = make(chan struct{})
|
||||
s.closedChanMu.Unlock()
|
||||
fns, port, err := s.StdNetBind.Open(uport)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
fns = append(fns, s.receiveRelayed)
|
||||
return fns, port, nil
|
||||
}
|
||||
|
||||
func (s *ICEBind) Close() error {
|
||||
if s.closed {
|
||||
return nil
|
||||
}
|
||||
s.closed = true
|
||||
|
||||
close(s.closedChan)
|
||||
|
||||
return s.StdNetBind.Close()
|
||||
}
|
||||
|
||||
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
|
||||
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
|
||||
s.muUDPMux.Lock()
|
||||
defer s.muUDPMux.Unlock()
|
||||
if s.udpMux == nil {
|
||||
return nil, fmt.Errorf("ICEBind has not been initialized yet")
|
||||
}
|
||||
|
||||
return s.udpMux, nil
|
||||
}
|
||||
|
||||
func (b *ICEBind) SetEndpoint(peerAddress *net.UDPAddr, conn net.Conn) (*net.UDPAddr, error) {
|
||||
fakeUDPAddr, err := fakeAddress(peerAddress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// force IPv4
|
||||
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to convert IP to netip.Addr")
|
||||
}
|
||||
|
||||
b.endpointsMu.Lock()
|
||||
b.endpoints[fakeAddr] = conn
|
||||
b.endpointsMu.Unlock()
|
||||
|
||||
return fakeUDPAddr, nil
|
||||
}
|
||||
|
||||
func (b *ICEBind) RemoveEndpoint(fakeUDPAddr *net.UDPAddr) {
|
||||
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
|
||||
if !ok {
|
||||
log.Warnf("failed to convert IP to netip.Addr")
|
||||
return
|
||||
}
|
||||
|
||||
b.endpointsMu.Lock()
|
||||
defer b.endpointsMu.Unlock()
|
||||
delete(b.endpoints, fakeAddr)
|
||||
}
|
||||
|
||||
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
||||
b.endpointsMu.Lock()
|
||||
conn, ok := b.endpoints[ep.DstIP()]
|
||||
b.endpointsMu.Unlock()
|
||||
if !ok {
|
||||
return b.StdNetBind.Send(bufs, ep)
|
||||
}
|
||||
|
||||
for _, buf := range bufs {
|
||||
if _, err := conn.Write(buf); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
|
||||
s.muUDPMux.Lock()
|
||||
defer s.muUDPMux.Unlock()
|
||||
|
||||
s.udpMux = NewUniversalUDPMuxDefault(
|
||||
UniversalUDPMuxParams{
|
||||
UDPConn: conn,
|
||||
Net: s.transportNet,
|
||||
FilterFn: s.filterFn,
|
||||
},
|
||||
)
|
||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||
msgs := ipv4MsgsPool.Get().(*[]ipv4.Message)
|
||||
defer ipv4MsgsPool.Put(msgs)
|
||||
for i := range bufs {
|
||||
(*msgs)[i].Buffers[0] = bufs[i]
|
||||
}
|
||||
var numMsgs int
|
||||
if runtime.GOOS == "linux" {
|
||||
numMsgs, err = pc.ReadBatch(*msgs, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else {
|
||||
msg := &(*msgs)[0]
|
||||
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
numMsgs = 1
|
||||
}
|
||||
for i := 0; i < numMsgs; i++ {
|
||||
msg := &(*msgs)[i]
|
||||
|
||||
// todo: handle err
|
||||
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
|
||||
if ok {
|
||||
sizes[i] = 0
|
||||
} else {
|
||||
sizes[i] = msg.N
|
||||
}
|
||||
|
||||
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
||||
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
||||
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
|
||||
eps[i] = ep
|
||||
}
|
||||
return numMsgs, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
|
||||
for i := range buffers {
|
||||
if !stun.IsMessage(buffers[i]) {
|
||||
continue
|
||||
}
|
||||
|
||||
msg, err := s.parseSTUNMessage(buffers[i][:n])
|
||||
if err != nil {
|
||||
buffers[i] = []byte{}
|
||||
return true, err
|
||||
}
|
||||
|
||||
muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
|
||||
if muxErr != nil {
|
||||
log.Warnf("failed to handle STUN packet")
|
||||
}
|
||||
|
||||
buffers[i] = []byte{}
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (s *ICEBind) parseSTUNMessage(raw []byte) (*stun.Message, error) {
|
||||
msg := &stun.Message{
|
||||
Raw: raw,
|
||||
}
|
||||
if err := msg.Decode(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// receiveRelayed is a receive function that is used to receive packets from the relayed connection and forward to the
|
||||
// WireGuard. Critical part is do not block if the Closed() has been called.
|
||||
func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
|
||||
c.closedChanMu.RLock()
|
||||
defer c.closedChanMu.RUnlock()
|
||||
|
||||
select {
|
||||
case <-c.closedChan:
|
||||
return 0, net.ErrClosed
|
||||
case msg, ok := <-c.RecvChan:
|
||||
if !ok {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
copy(buffs[0], msg.Buffer)
|
||||
sizes[0] = len(msg.Buffer)
|
||||
eps[0] = wgConn.Endpoint(msg.Endpoint)
|
||||
return 1, nil
|
||||
}
|
||||
}
|
||||
|
||||
// fakeAddress returns a fake address that is used to as an identifier for the peer.
|
||||
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
|
||||
func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
|
||||
octets := strings.Split(peerAddress.IP.String(), ".")
|
||||
if len(octets) != 4 {
|
||||
return nil, fmt.Errorf("invalid IP format")
|
||||
}
|
||||
|
||||
newAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])),
|
||||
Port: peerAddress.Port,
|
||||
}
|
||||
return newAddr, nil
|
||||
}
|
@ -5,7 +5,6 @@ package device
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
@ -31,13 +30,13 @@ type WGTunDevice struct {
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
func NewTunDevice(address WGAddress, port int, key string, mtu int, transportNet transport.Net, tunAdapter TunAdapter, filterFn bind.FilterFn) *WGTunDevice {
|
||||
func NewTunDevice(address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
|
||||
return &WGTunDevice{
|
||||
address: address,
|
||||
port: port,
|
||||
key: key,
|
||||
mtu: mtu,
|
||||
iceBind: bind.NewICEBind(transportNet, filterFn),
|
||||
iceBind: iceBind,
|
||||
tunAdapter: tunAdapter,
|
||||
}
|
||||
}
|
||||
|
@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
@ -29,14 +28,14 @@ type TunDevice struct {
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *TunDevice {
|
||||
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
|
||||
return &TunDevice{
|
||||
name: name,
|
||||
address: address,
|
||||
port: port,
|
||||
key: key,
|
||||
mtu: mtu,
|
||||
iceBind: bind.NewICEBind(transportNet, filterFn),
|
||||
iceBind: iceBind,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -6,7 +6,6 @@ package device
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
@ -30,13 +29,13 @@ type TunDevice struct {
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
func NewTunDevice(name string, address WGAddress, port int, key string, transportNet transport.Net, tunFd int, filterFn bind.FilterFn) *TunDevice {
|
||||
func NewTunDevice(name string, address WGAddress, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice {
|
||||
return &TunDevice{
|
||||
name: name,
|
||||
address: address,
|
||||
port: port,
|
||||
key: key,
|
||||
iceBind: bind.NewICEBind(transportNet, filterFn),
|
||||
iceBind: iceBind,
|
||||
tunFd: tunFd,
|
||||
}
|
||||
}
|
||||
|
@ -6,7 +6,6 @@ package device
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
|
||||
@ -31,7 +30,7 @@ type TunNetstackDevice struct {
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string, filterFn bind.FilterFn) *TunNetstackDevice {
|
||||
func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
|
||||
return &TunNetstackDevice{
|
||||
name: name,
|
||||
address: address,
|
||||
@ -39,7 +38,7 @@ func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, m
|
||||
key: key,
|
||||
mtu: mtu,
|
||||
listenAddress: listenAddress,
|
||||
iceBind: bind.NewICEBind(transportNet, filterFn),
|
||||
iceBind: iceBind,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -7,7 +7,6 @@ import (
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
@ -30,7 +29,7 @@ type USPDevice struct {
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *USPDevice {
|
||||
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
|
||||
log.Infof("using userspace bind mode")
|
||||
|
||||
checkUser()
|
||||
@ -41,7 +40,8 @@ func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int,
|
||||
port: port,
|
||||
key: key,
|
||||
mtu: mtu,
|
||||
iceBind: bind.NewICEBind(transportNet, filterFn)}
|
||||
iceBind: iceBind,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *USPDevice) Create() (WGConfigurer, error) {
|
||||
|
@ -4,7 +4,6 @@ import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
@ -32,14 +31,14 @@ type TunDevice struct {
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *TunDevice {
|
||||
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
|
||||
return &TunDevice{
|
||||
name: name,
|
||||
address: address,
|
||||
port: port,
|
||||
key: key,
|
||||
mtu: mtu,
|
||||
iceBind: bind.NewICEBind(transportNet, filterFn),
|
||||
iceBind: iceBind,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -6,12 +6,16 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -22,14 +26,35 @@ const (
|
||||
|
||||
type WGAddress = device.WGAddress
|
||||
|
||||
type wgProxyFactory interface {
|
||||
GetProxy() wgproxy.Proxy
|
||||
Free() error
|
||||
}
|
||||
|
||||
type WGIFaceOpts struct {
|
||||
IFaceName string
|
||||
Address string
|
||||
WGPort int
|
||||
WGPrivKey string
|
||||
MTU int
|
||||
MobileArgs *device.MobileIFaceArguments
|
||||
TransportNet transport.Net
|
||||
FilterFn bind.FilterFn
|
||||
}
|
||||
|
||||
// WGIface represents an interface instance
|
||||
type WGIface struct {
|
||||
tun WGTunDevice
|
||||
userspaceBind bool
|
||||
mu sync.Mutex
|
||||
|
||||
configurer device.WGConfigurer
|
||||
filter device.PacketFilter
|
||||
configurer device.WGConfigurer
|
||||
filter device.PacketFilter
|
||||
wgProxyFactory wgProxyFactory
|
||||
}
|
||||
|
||||
func (w *WGIface) GetProxy() wgproxy.Proxy {
|
||||
return w.wgProxyFactory.GetProxy()
|
||||
}
|
||||
|
||||
// IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind
|
||||
@ -124,22 +149,26 @@ func (w *WGIface) Close() error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
err := w.tun.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err)
|
||||
var result *multierror.Error
|
||||
|
||||
if err := w.wgProxyFactory.Free(); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err))
|
||||
}
|
||||
|
||||
err = w.waitUntilRemoved()
|
||||
if err != nil {
|
||||
if err := w.tun.Close(); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
|
||||
}
|
||||
|
||||
if err := w.waitUntilRemoved(); err != nil {
|
||||
log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
|
||||
err = w.Destroy()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove WireGuard interface %s: %w", w.Name(), err)
|
||||
if err := w.Destroy(); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("failed to remove WireGuard interface %s: %w", w.Name(), err))
|
||||
return errors.FormatErrorOrNil(result)
|
||||
}
|
||||
log.Infof("interface %s successfully removed", w.Name())
|
||||
}
|
||||
|
||||
return nil
|
||||
return errors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
// SetFilter sets packet filters for the userspace implementation
|
||||
|
@ -1,43 +0,0 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/pion/transport/v3"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
|
||||
wgAddress, err := device.ParseWGAddress(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wgIFace := &WGIface{
|
||||
tun: device.NewTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter, filterFn),
|
||||
userspaceBind: true,
|
||||
}
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
|
||||
// Will reuse an existing one.
|
||||
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
cfgr, err := w.tun.Create(routes, dns, searchDomains)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.configurer = cfgr
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create this function make sense on mobile only
|
||||
func (w *WGIface) Create() error {
|
||||
return fmt.Errorf("this function has not implemented on this platform")
|
||||
}
|
@ -2,6 +2,8 @@
|
||||
|
||||
package iface
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Create creates a new Wireguard interface, sets a given IP and brings it up.
|
||||
// Will reuse an existing one.
|
||||
// this function is different on Android
|
||||
@ -17,3 +19,8 @@ func (w *WGIface) Create() error {
|
||||
w.configurer = cfgr
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateOnAndroid this function make sense on mobile only
|
||||
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
||||
return fmt.Errorf("this function has not implemented on non mobile")
|
||||
}
|
||||
|
24
client/iface/iface_create_android.go
Normal file
24
client/iface/iface_create_android.go
Normal file
@ -0,0 +1,24 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
|
||||
// Will reuse an existing one.
|
||||
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
cfgr, err := w.tun.Create(routes, dns, searchDomains)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.configurer = cfgr
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create this function make sense on mobile only
|
||||
func (w *WGIface) Create() error {
|
||||
return fmt.Errorf("this function has not implemented on this platform")
|
||||
}
|
@ -7,39 +7,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/pion/transport/v3"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, _ *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
|
||||
wgAddress, err := device.ParseWGAddress(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wgIFace := &WGIface{
|
||||
userspaceBind: true,
|
||||
}
|
||||
|
||||
if netstack.IsEnabled() {
|
||||
wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn)
|
||||
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
// CreateOnAndroid this function make sense on mobile only
|
||||
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
||||
return fmt.Errorf("this function has not implemented on this platform")
|
||||
}
|
||||
|
||||
// Create creates a new Wireguard interface, sets a given IP and brings it up.
|
||||
// Will reuse an existing one.
|
||||
// this function is different on Android
|
||||
@ -65,3 +34,8 @@ func (w *WGIface) Create() error {
|
||||
|
||||
return backoff.Retry(operation, backOff)
|
||||
}
|
||||
|
||||
// CreateOnAndroid this function make sense on mobile only
|
||||
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
||||
return fmt.Errorf("this function has not implemented on this platform")
|
||||
}
|
10
client/iface/iface_guid_windows.go
Normal file
10
client/iface/iface_guid_windows.go
Normal file
@ -0,0 +1,10 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
// GetInterfaceGUIDString returns an interface GUID. This is useful on Windows only
|
||||
func (w *WGIface) GetInterfaceGUIDString() (string, error) {
|
||||
return w.tun.(*device.TunDevice).GetInterfaceGUIDString()
|
||||
}
|
@ -1,31 +0,0 @@
|
||||
//go:build ios
|
||||
|
||||
package iface
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/pion/transport/v3"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
|
||||
wgAddress, err := device.ParseWGAddress(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
wgIFace := &WGIface{
|
||||
tun: device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, transportNet, args.TunFd, filterFn),
|
||||
userspaceBind: true,
|
||||
}
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
|
||||
// Will reuse an existing one.
|
||||
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
||||
return fmt.Errorf("this function has not implemented on this platform")
|
||||
}
|
@ -9,6 +9,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
type MockWGIface struct {
|
||||
@ -30,6 +31,7 @@ type MockWGIface struct {
|
||||
GetDeviceFunc func() *device.FilteredDevice
|
||||
GetStatsFunc func(peerKey string) (configurer.WGStats, error)
|
||||
GetInterfaceGUIDStringFunc func() (string, error)
|
||||
GetProxyFunc func() wgproxy.Proxy
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetInterfaceGUIDString() (string, error) {
|
||||
@ -103,3 +105,8 @@ func (m *MockWGIface) GetDevice() *device.FilteredDevice {
|
||||
func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
||||
return m.GetStatsFunc(peerKey)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetProxy() wgproxy.Proxy {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
24
client/iface/iface_new_android.go
Normal file
24
client/iface/iface_new_android.go
Normal file
@ -0,0 +1,24 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := device.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
||||
|
||||
wgIFace := &WGIface{
|
||||
userspaceBind: true,
|
||||
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter),
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
||||
}
|
||||
return wgIFace, nil
|
||||
}
|
34
client/iface/iface_new_darwin.go
Normal file
34
client/iface/iface_new_darwin.go
Normal file
@ -0,0 +1,34 @@
|
||||
//go:build !ios
|
||||
|
||||
package iface
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := device.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
||||
|
||||
var tun WGTunDevice
|
||||
if netstack.IsEnabled() {
|
||||
tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
||||
} else {
|
||||
tun = device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
||||
}
|
||||
|
||||
wgIFace := &WGIface{
|
||||
userspaceBind: true,
|
||||
tun: tun,
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
||||
}
|
||||
return wgIFace, nil
|
||||
}
|
26
client/iface/iface_new_ios.go
Normal file
26
client/iface/iface_new_ios.go
Normal file
@ -0,0 +1,26 @@
|
||||
//go:build ios
|
||||
|
||||
package iface
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := device.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
||||
|
||||
wgIFace := &WGIface{
|
||||
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd),
|
||||
userspaceBind: true,
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
||||
}
|
||||
return wgIFace, nil
|
||||
}
|
45
client/iface/iface_new_unix.go
Normal file
45
client/iface/iface_new_unix.go
Normal file
@ -0,0 +1,45 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package iface
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := device.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wgIFace := &WGIface{}
|
||||
|
||||
if netstack.IsEnabled() {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
||||
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
||||
wgIFace.userspaceBind = true
|
||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
if device.WireGuardModuleIsLoaded() {
|
||||
wgIFace.tun = device.NewKernelDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, opts.TransportNet)
|
||||
wgIFace.wgProxyFactory = wgproxy.NewKernelFactory(opts.WGPort)
|
||||
return wgIFace, nil
|
||||
}
|
||||
if device.ModuleTunIsLoaded() {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
||||
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
||||
wgIFace.userspaceBind = true
|
||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("couldn't check or load tun module")
|
||||
}
|
32
client/iface/iface_new_windows.go
Normal file
32
client/iface/iface_new_windows.go
Normal file
@ -0,0 +1,32 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := device.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
||||
|
||||
var tun WGTunDevice
|
||||
if netstack.IsEnabled() {
|
||||
tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
||||
} else {
|
||||
tun = device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
||||
}
|
||||
|
||||
wgIFace := &WGIface{
|
||||
userspaceBind: true,
|
||||
tun: tun,
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
||||
}
|
||||
return wgIFace, nil
|
||||
|
||||
}
|
@ -45,7 +45,16 @@ func TestWGIface_UpdateAddr(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
iface, err := NewWGIFace(ifaceName, addr, wgPort, key, DefaultMTU, newNet, nil, nil)
|
||||
opts := WGIFaceOpts{
|
||||
IFaceName: ifaceName,
|
||||
Address: addr,
|
||||
WGPort: wgPort,
|
||||
WGPrivKey: key,
|
||||
MTU: DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
|
||||
iface, err := NewWGIFace(opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -118,7 +127,16 @@ func Test_CreateInterface(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil)
|
||||
opts := WGIFaceOpts{
|
||||
IFaceName: ifaceName,
|
||||
Address: wgIP,
|
||||
WGPort: 33100,
|
||||
WGPrivKey: key,
|
||||
MTU: DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
|
||||
iface, err := NewWGIFace(opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -153,7 +171,16 @@ func Test_Close(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil)
|
||||
opts := WGIFaceOpts{
|
||||
IFaceName: ifaceName,
|
||||
Address: wgIP,
|
||||
WGPort: wgPort,
|
||||
WGPrivKey: key,
|
||||
MTU: DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
|
||||
iface, err := NewWGIFace(opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -189,7 +216,16 @@ func TestRecreation(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil)
|
||||
opts := WGIFaceOpts{
|
||||
IFaceName: ifaceName,
|
||||
Address: wgIP,
|
||||
WGPort: wgPort,
|
||||
WGPrivKey: key,
|
||||
MTU: DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
|
||||
iface, err := NewWGIFace(opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -252,7 +288,15 @@ func Test_ConfigureInterface(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil)
|
||||
opts := WGIFaceOpts{
|
||||
IFaceName: ifaceName,
|
||||
Address: wgIP,
|
||||
WGPort: wgPort,
|
||||
WGPrivKey: key,
|
||||
MTU: DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
iface, err := NewWGIFace(opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -300,7 +344,16 @@ func Test_UpdatePeer(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil)
|
||||
opts := WGIFaceOpts{
|
||||
IFaceName: ifaceName,
|
||||
Address: wgIP,
|
||||
WGPort: 33100,
|
||||
WGPrivKey: key,
|
||||
MTU: DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
|
||||
iface, err := NewWGIFace(opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -361,7 +414,16 @@ func Test_RemovePeer(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil)
|
||||
opts := WGIFaceOpts{
|
||||
IFaceName: ifaceName,
|
||||
Address: wgIP,
|
||||
WGPort: 33100,
|
||||
WGPrivKey: key,
|
||||
MTU: DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
|
||||
iface, err := NewWGIFace(opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -418,7 +480,15 @@ func Test_ConnectPeers(t *testing.T) {
|
||||
guid := fmt.Sprintf("{%s}", uuid.New().String())
|
||||
device.CustomWindowsGUIDString = strings.ToLower(guid)
|
||||
|
||||
iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, peer1wgPort, peer1Key.String(), DefaultMTU, newNet, nil, nil)
|
||||
optsPeer1 := WGIFaceOpts{
|
||||
IFaceName: peer1ifaceName,
|
||||
Address: peer1wgIP,
|
||||
WGPort: peer1wgPort,
|
||||
WGPrivKey: peer1Key.String(),
|
||||
MTU: DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
iface1, err := NewWGIFace(optsPeer1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -432,7 +502,12 @@ func Test_ConnectPeers(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
peer1endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", peer1wgPort))
|
||||
localIP, err := getLocalIP()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
peer1endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", localIP, peer1wgPort))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -444,7 +519,17 @@ func Test_ConnectPeers(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, peer2wgPort, peer2Key.String(), DefaultMTU, newNet, nil, nil)
|
||||
|
||||
optsPeer2 := WGIFaceOpts{
|
||||
IFaceName: peer2ifaceName,
|
||||
Address: peer2wgIP,
|
||||
WGPort: peer2wgPort,
|
||||
WGPrivKey: peer2Key.String(),
|
||||
MTU: DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
|
||||
iface2, err := NewWGIFace(optsPeer2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -458,7 +543,7 @@ func Test_ConnectPeers(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
peer2endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", peer2wgPort))
|
||||
peer2endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", localIP, peer2wgPort))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -527,3 +612,28 @@ func getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) {
|
||||
}
|
||||
return wgtypes.Peer{}, fmt.Errorf("peer not found")
|
||||
}
|
||||
|
||||
func getLocalIP() (string, error) {
|
||||
// Get all interfaces
|
||||
addrs, err := net.InterfaceAddrs()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for _, addr := range addrs {
|
||||
ipNet, ok := addr.(*net.IPNet)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if ipNet.IP.IsLoopback() {
|
||||
continue
|
||||
}
|
||||
|
||||
if ipNet.IP.To4() == nil {
|
||||
continue
|
||||
}
|
||||
return ipNet.IP.String(), nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("no local IP found")
|
||||
}
|
||||
|
@ -1,49 +0,0 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package iface
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
|
||||
"github.com/pion/transport/v3"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
|
||||
wgAddress, err := device.ParseWGAddress(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wgIFace := &WGIface{}
|
||||
|
||||
// move the kernel/usp/netstack preference evaluation to upper layer
|
||||
if netstack.IsEnabled() {
|
||||
wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
|
||||
wgIFace.userspaceBind = true
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
if device.WireGuardModuleIsLoaded() {
|
||||
wgIFace.tun = device.NewKernelDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet)
|
||||
wgIFace.userspaceBind = false
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
if !device.ModuleTunIsLoaded() {
|
||||
return nil, fmt.Errorf("couldn't check or load tun module")
|
||||
}
|
||||
wgIFace.tun = device.NewUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, nil)
|
||||
wgIFace.userspaceBind = true
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
// CreateOnAndroid this function make sense on mobile only
|
||||
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
||||
return fmt.Errorf("CreateOnAndroid function has not implemented on %s platform", runtime.GOOS)
|
||||
}
|
@ -1,41 +0,0 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/pion/transport/v3"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
|
||||
wgAddress, err := device.ParseWGAddress(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wgIFace := &WGIface{
|
||||
userspaceBind: true,
|
||||
}
|
||||
|
||||
if netstack.IsEnabled() {
|
||||
wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn)
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
// CreateOnAndroid this function make sense on mobile only
|
||||
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
||||
return fmt.Errorf("this function has not implemented on non mobile")
|
||||
}
|
||||
|
||||
// GetInterfaceGUIDString returns an interface GUID. This is useful on Windows only
|
||||
func (w *WGIface) GetInterfaceGUIDString() (string, error) {
|
||||
return w.tun.(*device.TunDevice).GetInterfaceGUIDString()
|
||||
}
|
@ -11,6 +11,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
type IWGIface interface {
|
||||
@ -22,6 +23,7 @@ type IWGIface interface {
|
||||
ToInterface() *net.Interface
|
||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(newAddr string) error
|
||||
GetProxy() wgproxy.Proxy
|
||||
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
RemovePeer(peerKey string) error
|
||||
AddAllowedIP(peerKey string, allowedIP string) error
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
type IWGIface interface {
|
||||
@ -20,6 +21,7 @@ type IWGIface interface {
|
||||
ToInterface() *net.Interface
|
||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(newAddr string) error
|
||||
GetProxy() wgproxy.Proxy
|
||||
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
RemovePeer(peerKey string) error
|
||||
AddAllowedIP(peerKey string, allowedIP string) error
|
||||
|
137
client/iface/wgproxy/bind/proxy.go
Normal file
137
client/iface/wgproxy/bind/proxy.go
Normal file
@ -0,0 +1,137 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
)
|
||||
|
||||
type ProxyBind struct {
|
||||
Bind *bind.ICEBind
|
||||
|
||||
wgAddr *net.UDPAddr
|
||||
wgEndpoint *bind.Endpoint
|
||||
remoteConn net.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
closeMu sync.Mutex
|
||||
closed bool
|
||||
|
||||
pausedMu sync.Mutex
|
||||
paused bool
|
||||
isStarted bool
|
||||
}
|
||||
|
||||
// AddTurnConn adds a new connection to the bind.
|
||||
// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the
|
||||
// WireGuard configuration.
|
||||
func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error {
|
||||
addr, err := p.Bind.SetEndpoint(nbAddr, remoteConn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.wgAddr = addr
|
||||
p.wgEndpoint = addrToEndpoint(addr)
|
||||
p.remoteConn = remoteConn
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
return err
|
||||
|
||||
}
|
||||
func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
|
||||
return p.wgAddr
|
||||
}
|
||||
|
||||
func (p *ProxyBind) Work() {
|
||||
if p.remoteConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
p.paused = false
|
||||
p.pausedMu.Unlock()
|
||||
|
||||
// Start the proxy only once
|
||||
if !p.isStarted {
|
||||
p.isStarted = true
|
||||
go p.proxyToLocal(p.ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProxyBind) Pause() {
|
||||
if p.remoteConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
p.paused = true
|
||||
p.pausedMu.Unlock()
|
||||
}
|
||||
|
||||
func (p *ProxyBind) CloseConn() error {
|
||||
if p.cancel == nil {
|
||||
return fmt.Errorf("proxy not started")
|
||||
}
|
||||
return p.close()
|
||||
}
|
||||
|
||||
func (p *ProxyBind) close() error {
|
||||
p.closeMu.Lock()
|
||||
defer p.closeMu.Unlock()
|
||||
|
||||
if p.closed {
|
||||
return nil
|
||||
}
|
||||
p.closed = true
|
||||
|
||||
p.cancel()
|
||||
|
||||
p.Bind.RemoveEndpoint(p.wgAddr)
|
||||
|
||||
return p.remoteConn.Close()
|
||||
}
|
||||
|
||||
func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
||||
defer func() {
|
||||
if err := p.close(); err != nil {
|
||||
log.Warnf("failed to close remote conn: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
n, err := p.remoteConn.Read(buf)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
if p.paused {
|
||||
p.pausedMu.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
msg := bind.RecvMessage{
|
||||
Endpoint: p.wgEndpoint,
|
||||
Buffer: buf[:n],
|
||||
}
|
||||
p.Bind.RecvChan <- msg
|
||||
p.pausedMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
|
||||
ip, _ := netip.AddrFromSlice(addr.IP.To4())
|
||||
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port))
|
||||
return &bind.Endpoint{AddrPort: addrPort}
|
||||
}
|
@ -5,9 +5,9 @@ import (
|
||||
"net"
|
||||
)
|
||||
|
||||
const (
|
||||
var (
|
||||
portRangeStart = 3128
|
||||
portRangeEnd = 3228
|
||||
portRangeEnd = portRangeStart + 100
|
||||
)
|
||||
|
||||
type portLookup struct {
|
@ -17,6 +17,9 @@ func Test_portLookup_searchFreePort(t *testing.T) {
|
||||
func Test_portLookup_on_allocated(t *testing.T) {
|
||||
pl := portLookup{}
|
||||
|
||||
portRangeStart = 4128
|
||||
portRangeEnd = portRangeStart + 100
|
||||
|
||||
allocatedPort, err := allocatePort(portRangeStart)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
@ -119,7 +119,7 @@ func (p *WGEBPFProxy) Free() error {
|
||||
p.ctxCancel()
|
||||
|
||||
var result *multierror.Error
|
||||
if p.conn != nil { // p.conn will be nil if we have failed to listen
|
||||
if p.conn != nil {
|
||||
if err := p.conn.Close(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
@ -28,7 +28,7 @@ type ProxyWrapper struct {
|
||||
isStarted bool
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) error {
|
||||
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
|
||||
addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add turn conn: %w", err)
|
49
client/iface/wgproxy/factory_kernel.go
Normal file
49
client/iface/wgproxy/factory_kernel.go
Normal file
@ -0,0 +1,49 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
|
||||
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
|
||||
)
|
||||
|
||||
type KernelFactory struct {
|
||||
wgPort int
|
||||
|
||||
ebpfProxy *ebpf.WGEBPFProxy
|
||||
}
|
||||
|
||||
func NewKernelFactory(wgPort int) *KernelFactory {
|
||||
f := &KernelFactory{
|
||||
wgPort: wgPort,
|
||||
}
|
||||
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
|
||||
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
|
||||
return f
|
||||
}
|
||||
log.Infof("WireGuard Proxy Factory will produce eBPF proxy")
|
||||
f.ebpfProxy = ebpfProxy
|
||||
return f
|
||||
}
|
||||
|
||||
func (w *KernelFactory) GetProxy() Proxy {
|
||||
if w.ebpfProxy == nil {
|
||||
return udpProxy.NewWGUDPProxy(w.wgPort)
|
||||
}
|
||||
|
||||
return &ebpf.ProxyWrapper{
|
||||
WgeBPFProxy: w.ebpfProxy,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *KernelFactory) Free() error {
|
||||
if w.ebpfProxy == nil {
|
||||
return nil
|
||||
}
|
||||
return w.ebpfProxy.Free()
|
||||
}
|
29
client/iface/wgproxy/factory_kernel_freebsd.go
Normal file
29
client/iface/wgproxy/factory_kernel_freebsd.go
Normal file
@ -0,0 +1,29 @@
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
|
||||
)
|
||||
|
||||
// KernelFactory todo: check eBPF support on FreeBSD
|
||||
type KernelFactory struct {
|
||||
wgPort int
|
||||
}
|
||||
|
||||
func NewKernelFactory(wgPort int) *KernelFactory {
|
||||
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
|
||||
f := &KernelFactory{
|
||||
wgPort: wgPort,
|
||||
}
|
||||
|
||||
return f
|
||||
}
|
||||
|
||||
func (w *KernelFactory) GetProxy() Proxy {
|
||||
return udpProxy.NewWGUDPProxy(w.wgPort)
|
||||
}
|
||||
|
||||
func (w *KernelFactory) Free() error {
|
||||
return nil
|
||||
}
|
30
client/iface/wgproxy/factory_usp.go
Normal file
30
client/iface/wgproxy/factory_usp.go
Normal file
@ -0,0 +1,30 @@
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
|
||||
)
|
||||
|
||||
type USPFactory struct {
|
||||
bind *bind.ICEBind
|
||||
}
|
||||
|
||||
func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory {
|
||||
log.Infof("WireGuard Proxy Factory will produce bind proxy")
|
||||
f := &USPFactory{
|
||||
bind: iceBind,
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
func (w *USPFactory) GetProxy() Proxy {
|
||||
return &proxyBind.ProxyBind{
|
||||
Bind: w.bind,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *USPFactory) Free() error {
|
||||
return nil
|
||||
}
|
15
client/iface/wgproxy/proxy.go
Normal file
15
client/iface/wgproxy/proxy.go
Normal file
@ -0,0 +1,15 @@
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
// Proxy is a transfer layer between the relayed connection and the WireGuard
|
||||
type Proxy interface {
|
||||
AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error
|
||||
EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint
|
||||
Work() // Work start or resume the proxy
|
||||
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
|
||||
CloseConn() error
|
||||
}
|
56
client/iface/wgproxy/proxy_linux_test.go
Normal file
56
client/iface/wgproxy/proxy_linux_test.go
Normal file
@ -0,0 +1,56 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
|
||||
)
|
||||
|
||||
func TestProxyCloseByRemoteConnEBPF(t *testing.T) {
|
||||
if os.Getenv("GITHUB_ACTIONS") != "true" {
|
||||
t.Skip("Skipping test as it requires root privileges")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(51831)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %s", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
proxy Proxy
|
||||
}{
|
||||
{
|
||||
name: "ebpf proxy",
|
||||
proxy: &ebpf.ProxyWrapper{
|
||||
WgeBPFProxy: ebpfProxy,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
relayedConn := newMockConn()
|
||||
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn)
|
||||
if err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
|
||||
_ = relayedConn.Close()
|
||||
if err := tt.proxy.CloseConn(); err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -11,8 +11,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/wgproxy/ebpf"
|
||||
"github.com/netbirdio/netbird/client/internal/wgproxy/usp"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
|
||||
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@ -84,7 +84,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "userspace proxy",
|
||||
proxy: usp.NewWGUserSpaceProxy(51830),
|
||||
proxy: udpProxy.NewWGUDPProxy(51830),
|
||||
},
|
||||
}
|
||||
|
||||
@ -114,7 +114,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
relayedConn := newMockConn()
|
||||
err := tt.proxy.AddTurnConn(ctx, relayedConn)
|
||||
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn)
|
||||
if err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
@ -1,19 +1,21 @@
|
||||
package usp
|
||||
package udp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/errors"
|
||||
cerrors "github.com/netbirdio/netbird/client/errors"
|
||||
)
|
||||
|
||||
// WGUserSpaceProxy proxies
|
||||
type WGUserSpaceProxy struct {
|
||||
// WGUDPProxy proxies
|
||||
type WGUDPProxy struct {
|
||||
localWGListenPort int
|
||||
|
||||
remoteConn net.Conn
|
||||
@ -28,10 +30,10 @@ type WGUserSpaceProxy struct {
|
||||
isStarted bool
|
||||
}
|
||||
|
||||
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation
|
||||
func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
|
||||
// NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation
|
||||
func NewWGUDPProxy(wgPort int) *WGUDPProxy {
|
||||
log.Debugf("Initializing new user space proxy with port %d", wgPort)
|
||||
p := &WGUserSpaceProxy{
|
||||
p := &WGUDPProxy{
|
||||
localWGListenPort: wgPort,
|
||||
}
|
||||
return p
|
||||
@ -42,7 +44,7 @@ func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
|
||||
// the connection is complete, an error is returned. Once successfully
|
||||
// connected, any expiration of the context will not affect the
|
||||
// connection.
|
||||
func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) error {
|
||||
func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
|
||||
dialer := net.Dialer{}
|
||||
localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
||||
if err != nil {
|
||||
@ -57,7 +59,7 @@ func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn)
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *WGUserSpaceProxy) EndpointAddr() *net.UDPAddr {
|
||||
func (p *WGUDPProxy) EndpointAddr() *net.UDPAddr {
|
||||
if p.localConn == nil {
|
||||
return nil
|
||||
}
|
||||
@ -66,7 +68,7 @@ func (p *WGUserSpaceProxy) EndpointAddr() *net.UDPAddr {
|
||||
}
|
||||
|
||||
// Work starts the proxy or resumes it if it was paused
|
||||
func (p *WGUserSpaceProxy) Work() {
|
||||
func (p *WGUDPProxy) Work() {
|
||||
if p.remoteConn == nil {
|
||||
return
|
||||
}
|
||||
@ -83,7 +85,7 @@ func (p *WGUserSpaceProxy) Work() {
|
||||
}
|
||||
|
||||
// Pause pauses the proxy from receiving data from the remote peer
|
||||
func (p *WGUserSpaceProxy) Pause() {
|
||||
func (p *WGUDPProxy) Pause() {
|
||||
if p.remoteConn == nil {
|
||||
return
|
||||
}
|
||||
@ -94,14 +96,14 @@ func (p *WGUserSpaceProxy) Pause() {
|
||||
}
|
||||
|
||||
// CloseConn close the localConn
|
||||
func (p *WGUserSpaceProxy) CloseConn() error {
|
||||
func (p *WGUDPProxy) CloseConn() error {
|
||||
if p.cancel == nil {
|
||||
return fmt.Errorf("proxy not started")
|
||||
}
|
||||
return p.close()
|
||||
}
|
||||
|
||||
func (p *WGUserSpaceProxy) close() error {
|
||||
func (p *WGUDPProxy) close() error {
|
||||
p.closeMu.Lock()
|
||||
defer p.closeMu.Unlock()
|
||||
|
||||
@ -121,11 +123,11 @@ func (p *WGUserSpaceProxy) close() error {
|
||||
if err := p.localConn.Close(); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
|
||||
}
|
||||
return errors.FormatErrorOrNil(result)
|
||||
return cerrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
// proxyToRemote proxies from Wireguard to the RemoteKey
|
||||
func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) {
|
||||
func (p *WGUDPProxy) proxyToRemote(ctx context.Context) {
|
||||
defer func() {
|
||||
if err := p.close(); err != nil {
|
||||
log.Warnf("error in proxy to remote loop: %s", err)
|
||||
@ -157,21 +159,19 @@ func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) {
|
||||
|
||||
// proxyToLocal proxies from the Remote peer to local WireGuard
|
||||
// if the proxy is paused it will drain the remote conn and drop the packets
|
||||
func (p *WGUserSpaceProxy) proxyToLocal(ctx context.Context) {
|
||||
func (p *WGUDPProxy) proxyToLocal(ctx context.Context) {
|
||||
defer func() {
|
||||
if err := p.close(); err != nil {
|
||||
log.Warnf("error in proxy to local loop: %s", err)
|
||||
if !errors.Is(err, io.EOF) {
|
||||
log.Warnf("error in proxy to local loop: %s", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
n, err := p.remoteConn.Read(buf)
|
||||
n, err := p.remoteConnRead(ctx, buf)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -193,3 +193,15 @@ func (p *WGUserSpaceProxy) proxyToLocal(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *WGUDPProxy) remoteConnRead(ctx context.Context, buf []byte) (n int, err error) {
|
||||
n, err = p.remoteConn.Read(buf)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.LocalAddr(), err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
@ -3,6 +3,7 @@ package acl
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@ -10,14 +11,18 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
var ErrSourceRangesEmpty = errors.New("sources range is empty")
|
||||
|
||||
// Manager is a ACL rules manager
|
||||
type Manager interface {
|
||||
ApplyFiltering(networkMap *mgmProto.NetworkMap)
|
||||
@ -167,31 +172,40 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
}
|
||||
|
||||
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error {
|
||||
var newRouteRules = make(map[id.RuleID]struct{})
|
||||
newRouteRules := make(map[id.RuleID]struct{}, len(rules))
|
||||
var merr *multierror.Error
|
||||
|
||||
// Apply new rules - firewall manager will return existing rule ID if already present
|
||||
for _, rule := range rules {
|
||||
id, err := d.applyRouteACL(rule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply route ACL: %w", err)
|
||||
if errors.Is(err, ErrSourceRangesEmpty) {
|
||||
log.Debugf("skipping empty rule with destination %s: %v", rule.Destination, err)
|
||||
} else {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add route rule: %w", err))
|
||||
}
|
||||
continue
|
||||
}
|
||||
newRouteRules[id] = struct{}{}
|
||||
}
|
||||
|
||||
// Clean up old firewall rules
|
||||
for id := range d.routeRules {
|
||||
if _, ok := newRouteRules[id]; !ok {
|
||||
if _, exists := newRouteRules[id]; !exists {
|
||||
if err := d.firewall.DeleteRouteRule(id); err != nil {
|
||||
log.Errorf("failed to delete route firewall rule: %v", err)
|
||||
continue
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete route rule: %w", err))
|
||||
}
|
||||
delete(d.routeRules, id)
|
||||
// implicitly deleted from the map
|
||||
}
|
||||
}
|
||||
|
||||
d.routeRules = newRouteRules
|
||||
return nil
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) {
|
||||
if len(rule.SourceRanges) == 0 {
|
||||
return "", fmt.Errorf("source ranges is empty")
|
||||
return "", ErrSourceRangesEmpty
|
||||
}
|
||||
|
||||
var sources []netip.Prefix
|
||||
|
@ -1,7 +1,6 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
@ -52,13 +51,13 @@ func TestDefaultManager(t *testing.T) {
|
||||
}).AnyTimes()
|
||||
|
||||
// we receive one rule from the management so for testing purposes ignore it
|
||||
fw, err := firewall.NewFirewall(context.Background(), ifaceMock)
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil)
|
||||
if err != nil {
|
||||
t.Errorf("create firewall: %v", err)
|
||||
return
|
||||
}
|
||||
defer func(fw manager.Manager) {
|
||||
_ = fw.Reset()
|
||||
_ = fw.Reset(nil)
|
||||
}(fw)
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
@ -345,13 +344,13 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
}).AnyTimes()
|
||||
|
||||
// we receive one rule from the management so for testing purposes ignore it
|
||||
fw, err := firewall.NewFirewall(context.Background(), ifaceMock)
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil)
|
||||
if err != nil {
|
||||
t.Errorf("create firewall: %v", err)
|
||||
return
|
||||
}
|
||||
defer func(fw manager.Manager) {
|
||||
_ = fw.Reset()
|
||||
_ = fw.Reset(nil)
|
||||
}(fw)
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
|
@ -62,10 +62,7 @@ func (c *ConnectClient) Run() error {
|
||||
}
|
||||
|
||||
// RunWithProbes runs the client's main logic with probes attached
|
||||
func (c *ConnectClient) RunWithProbes(
|
||||
probes *ProbeHolder,
|
||||
runningChan chan error,
|
||||
) error {
|
||||
func (c *ConnectClient) RunWithProbes(probes *ProbeHolder, runningChan chan error) error {
|
||||
return c.run(MobileDependency{}, probes, runningChan)
|
||||
}
|
||||
|
||||
@ -104,11 +101,7 @@ func (c *ConnectClient) RunOniOS(
|
||||
return c.run(mobileDependency, nil, nil)
|
||||
}
|
||||
|
||||
func (c *ConnectClient) run(
|
||||
mobileDependency MobileDependency,
|
||||
probes *ProbeHolder,
|
||||
runningChan chan error,
|
||||
) error {
|
||||
func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHolder, runningChan chan error) error {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
|
||||
@ -117,12 +110,6 @@ func (c *ConnectClient) run(
|
||||
|
||||
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
|
||||
|
||||
// Check if client was not shut down in a clean way and restore DNS config if required.
|
||||
// Otherwise, we might not be able to connect to the management server to retrieve new config.
|
||||
if err := dns.CheckUncleanShutdown(c.config.WgIface); err != nil {
|
||||
log.Errorf("checking unclean shutdown error: %s", err)
|
||||
}
|
||||
|
||||
backOff := &backoff.ExponentialBackOff{
|
||||
InitialInterval: time.Second,
|
||||
RandomizationFactor: 1,
|
||||
@ -358,7 +345,11 @@ func (c *ConnectClient) Stop() error {
|
||||
if c.engine == nil {
|
||||
return nil
|
||||
}
|
||||
return c.engine.Stop()
|
||||
if err := c.engine.Stop(); err != nil {
|
||||
return fmt.Errorf("stop engine: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ConnectClient) isContextCancelled() bool {
|
||||
|
@ -1,6 +1,5 @@
|
||||
package dns
|
||||
|
||||
const (
|
||||
fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf"
|
||||
fileUncleanShutdownManagerTypeLocation = "/var/db/netbird/manager"
|
||||
fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf"
|
||||
)
|
||||
|
@ -3,6 +3,5 @@
|
||||
package dns
|
||||
|
||||
const (
|
||||
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
|
||||
fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager"
|
||||
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
|
||||
)
|
||||
|
@ -9,6 +9,8 @@ import (
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -20,7 +22,7 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
type repairConfFn func([]string, string, *resolvConf) error
|
||||
type repairConfFn func([]string, string, *resolvConf, *statemanager.Manager) error
|
||||
|
||||
type repair struct {
|
||||
operationFile string
|
||||
@ -40,7 +42,7 @@ func newRepair(operationFile string, updateFn repairConfFn) *repair {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP string) {
|
||||
func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP string, stateManager *statemanager.Manager) {
|
||||
if f.inotify != nil {
|
||||
return
|
||||
}
|
||||
@ -81,7 +83,7 @@ func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP strin
|
||||
log.Errorf("failed to rm inotify watch for resolv.conf: %s", err)
|
||||
}
|
||||
|
||||
err = f.updateFn(nbSearchDomains, nbNameserverIP, rConf)
|
||||
err = f.updateFn(nbSearchDomains, nbNameserverIP, rConf, stateManager)
|
||||
if err != nil {
|
||||
log.Errorf("failed to repair resolv.conf: %v", err)
|
||||
}
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@ -104,14 +105,14 @@ nameserver 8.8.8.8`,
|
||||
|
||||
var changed bool
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
updateFn := func([]string, string, *resolvConf) error {
|
||||
updateFn := func([]string, string, *resolvConf, *statemanager.Manager) error {
|
||||
changed = true
|
||||
cancel()
|
||||
return nil
|
||||
}
|
||||
|
||||
r := newRepair(operationFile, updateFn)
|
||||
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1")
|
||||
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1", nil)
|
||||
|
||||
err = os.WriteFile(operationFile, []byte(tt.touchedConfContent), 0755)
|
||||
if err != nil {
|
||||
@ -151,14 +152,14 @@ searchdomain netbird.cloud something`
|
||||
|
||||
var changed bool
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
updateFn := func([]string, string, *resolvConf) error {
|
||||
updateFn := func([]string, string, *resolvConf, *statemanager.Manager) error {
|
||||
changed = true
|
||||
cancel()
|
||||
return nil
|
||||
}
|
||||
|
||||
r := newRepair(tmpLink, updateFn)
|
||||
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1")
|
||||
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1", nil)
|
||||
|
||||
err = os.WriteFile(tmpLink, []byte(modifyContent), 0755)
|
||||
if err != nil {
|
||||
|
@ -11,6 +11,8 @@ import (
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -36,7 +38,7 @@ type fileConfigurator struct {
|
||||
nbNameserverIP string
|
||||
}
|
||||
|
||||
func newFileConfigurator() (hostManager, error) {
|
||||
func newFileConfigurator() (*fileConfigurator, error) {
|
||||
fc := &fileConfigurator{}
|
||||
fc.repair = newRepair(defaultResolvConfPath, fc.updateConfig)
|
||||
return fc, nil
|
||||
@ -46,7 +48,7 @@ func (f *fileConfigurator) supportCustomPort() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
||||
func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||
backupFileExist := f.isBackupFileExist()
|
||||
if !config.RouteAll {
|
||||
if backupFileExist {
|
||||
@ -76,15 +78,15 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
||||
|
||||
f.repair.stopWatchFileChanges()
|
||||
|
||||
err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf)
|
||||
err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf, stateManager)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f.repair.watchFileChanges(nbSearchDomains, f.nbNameserverIP)
|
||||
f.repair.watchFileChanges(nbSearchDomains, f.nbNameserverIP, stateManager)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf) error {
|
||||
func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf, stateManager *statemanager.Manager) error {
|
||||
searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains)
|
||||
nameServers := generateNsList(nbNameserverIP, cfg)
|
||||
|
||||
@ -107,7 +109,7 @@ func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP
|
||||
log.Infof("created a NetBird managed %s file with the DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList)
|
||||
|
||||
// create another backup for unclean shutdown detection right after overwriting the original resolv.conf
|
||||
if err := createUncleanShutdownIndicator(fileDefaultResolvConfBackupLocation, fileManager, nbNameserverIP); err != nil {
|
||||
if err := createUncleanShutdownIndicator(fileDefaultResolvConfBackupLocation, nbNameserverIP, stateManager); err != nil {
|
||||
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
|
||||
}
|
||||
|
||||
@ -145,10 +147,6 @@ func (f *fileConfigurator) restore() error {
|
||||
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err)
|
||||
}
|
||||
|
||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
||||
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
|
||||
}
|
||||
|
||||
return os.RemoveAll(fileDefaultResolvConfBackupLocation)
|
||||
}
|
||||
|
||||
@ -176,7 +174,7 @@ func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Add
|
||||
return restoreResolvConfFile()
|
||||
}
|
||||
|
||||
log.Info("restoring unclean shutdown: first current nameserver differs from saved nameserver pre-netbird: not restoring")
|
||||
log.Infof("restoring unclean shutdown: first current nameserver differs from saved nameserver pre-netbird: %s (current) vs %s (stored): not restoring", currentDNSAddress, storedDNSAddress)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -192,10 +190,6 @@ func restoreResolvConfFile() error {
|
||||
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation, err)
|
||||
}
|
||||
|
||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
||||
log.Errorf("failed to remove unclean shutdown resolv.conf file: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -5,14 +5,14 @@ import (
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
type hostManager interface {
|
||||
applyDNSConfig(config HostDNSConfig) error
|
||||
applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error
|
||||
restoreHostDNS() error
|
||||
supportCustomPort() bool
|
||||
restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error
|
||||
}
|
||||
|
||||
type SystemDNSSettings struct {
|
||||
@ -35,15 +35,15 @@ type DomainConfig struct {
|
||||
}
|
||||
|
||||
type mockHostConfigurator struct {
|
||||
applyDNSConfigFunc func(config HostDNSConfig) error
|
||||
applyDNSConfigFunc func(config HostDNSConfig, stateManager *statemanager.Manager) error
|
||||
restoreHostDNSFunc func() error
|
||||
supportCustomPortFunc func() bool
|
||||
restoreUncleanShutdownDNSFunc func(*netip.Addr) error
|
||||
}
|
||||
|
||||
func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
||||
func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||
if m.applyDNSConfigFunc != nil {
|
||||
return m.applyDNSConfigFunc(config)
|
||||
return m.applyDNSConfigFunc(config, stateManager)
|
||||
}
|
||||
return fmt.Errorf("method applyDNSSettings is not implemented")
|
||||
}
|
||||
@ -62,16 +62,9 @@ func (m *mockHostConfigurator) supportCustomPort() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *mockHostConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error {
|
||||
if m.restoreUncleanShutdownDNSFunc != nil {
|
||||
return m.restoreUncleanShutdownDNSFunc(storedDNSAddress)
|
||||
}
|
||||
return fmt.Errorf("method restoreUncleanShutdownDNS is not implemented")
|
||||
}
|
||||
|
||||
func newNoopHostMocker() hostManager {
|
||||
return &mockHostConfigurator{
|
||||
applyDNSConfigFunc: func(config HostDNSConfig) error { return nil },
|
||||
applyDNSConfigFunc: func(config HostDNSConfig, stateManager *statemanager.Manager) error { return nil },
|
||||
restoreHostDNSFunc: func() error { return nil },
|
||||
supportCustomPortFunc: func() bool { return true },
|
||||
restoreUncleanShutdownDNSFunc: func(*netip.Addr) error { return nil },
|
||||
|
@ -1,15 +1,17 @@
|
||||
package dns
|
||||
|
||||
import "net/netip"
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
type androidHostManager struct {
|
||||
}
|
||||
|
||||
func newHostManager() (hostManager, error) {
|
||||
func newHostManager() (*androidHostManager, error) {
|
||||
return &androidHostManager{}, nil
|
||||
}
|
||||
|
||||
func (a androidHostManager) applyDNSConfig(config HostDNSConfig) error {
|
||||
func (a androidHostManager) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -20,7 +22,3 @@ func (a androidHostManager) restoreHostDNS() error {
|
||||
func (a androidHostManager) supportCustomPort() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (a androidHostManager) restoreUncleanShutdownDNS(*netip.Addr) error {
|
||||
return nil
|
||||
}
|
||||
|
@ -8,12 +8,13 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -37,7 +38,7 @@ type systemConfigurator struct {
|
||||
systemDNSSettings SystemDNSSettings
|
||||
}
|
||||
|
||||
func newHostManager() (hostManager, error) {
|
||||
func newHostManager() (*systemConfigurator, error) {
|
||||
return &systemConfigurator{
|
||||
createdKeys: make(map[string]struct{}),
|
||||
}, nil
|
||||
@ -47,12 +48,11 @@ func (s *systemConfigurator) supportCustomPort() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
||||
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||
var err error
|
||||
|
||||
// create a file for unclean shutdown detection
|
||||
if err := createUncleanShutdownIndicator(); err != nil {
|
||||
log.Errorf("failed to create unclean shutdown file: %s", err)
|
||||
if err := stateManager.UpdateState(&ShutdownState{}); err != nil {
|
||||
log.Errorf("failed to update shutdown state: %s", err)
|
||||
}
|
||||
|
||||
var (
|
||||
@ -123,10 +123,6 @@ func (s *systemConfigurator) restoreHostDNS() error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
||||
log.Errorf("failed to remove unclean shutdown file: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -320,7 +316,7 @@ func (s *systemConfigurator) getPrimaryService() (string, string, error) {
|
||||
return primaryService, router, nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
|
||||
func (s *systemConfigurator) restoreUncleanShutdownDNS() error {
|
||||
if err := s.restoreHostDNS(); err != nil {
|
||||
return fmt.Errorf("restoring dns via scutil: %w", err)
|
||||
}
|
||||
|
@ -3,9 +3,10 @@ package dns
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
type iosHostManager struct {
|
||||
@ -13,13 +14,13 @@ type iosHostManager struct {
|
||||
config HostDNSConfig
|
||||
}
|
||||
|
||||
func newHostManager(dnsManager IosDnsManager) (hostManager, error) {
|
||||
func newHostManager(dnsManager IosDnsManager) (*iosHostManager, error) {
|
||||
return &iosHostManager{
|
||||
dnsManager: dnsManager,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a iosHostManager) applyDNSConfig(config HostDNSConfig) error {
|
||||
func (a iosHostManager) applyDNSConfig(config HostDNSConfig, _ *statemanager.Manager) error {
|
||||
jsonData, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal: %w", err)
|
||||
@ -37,7 +38,3 @@ func (a iosHostManager) restoreHostDNS() error {
|
||||
func (a iosHostManager) supportCustomPort() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (a iosHostManager) restoreUncleanShutdownDNS(*netip.Addr) error {
|
||||
return nil
|
||||
}
|
||||
|
@ -4,9 +4,9 @@ package dns
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
@ -21,27 +21,8 @@ const (
|
||||
resolvConfManager
|
||||
)
|
||||
|
||||
var ErrUnknownOsManagerType = errors.New("unknown os manager type")
|
||||
|
||||
type osManagerType int
|
||||
|
||||
func newOsManagerType(osManager string) (osManagerType, error) {
|
||||
switch osManager {
|
||||
case "netbird":
|
||||
return fileManager, nil
|
||||
case "file":
|
||||
return netbirdManager, nil
|
||||
case "networkManager":
|
||||
return networkManager, nil
|
||||
case "systemd":
|
||||
return systemdManager, nil
|
||||
case "resolvconf":
|
||||
return resolvConfManager, nil
|
||||
default:
|
||||
return 0, ErrUnknownOsManagerType
|
||||
}
|
||||
}
|
||||
|
||||
func (t osManagerType) String() string {
|
||||
switch t {
|
||||
case netbirdManager:
|
||||
@ -59,6 +40,11 @@ func (t osManagerType) String() string {
|
||||
}
|
||||
}
|
||||
|
||||
type restoreHostManager interface {
|
||||
hostManager
|
||||
restoreUncleanShutdownDNS(*netip.Addr) error
|
||||
}
|
||||
|
||||
func newHostManager(wgInterface string) (hostManager, error) {
|
||||
osManager, err := getOSDNSManagerType()
|
||||
if err != nil {
|
||||
@ -69,7 +55,7 @@ func newHostManager(wgInterface string) (hostManager, error) {
|
||||
return newHostManagerFromType(wgInterface, osManager)
|
||||
}
|
||||
|
||||
func newHostManagerFromType(wgInterface string, osManager osManagerType) (hostManager, error) {
|
||||
func newHostManagerFromType(wgInterface string, osManager osManagerType) (restoreHostManager, error) {
|
||||
switch osManager {
|
||||
case networkManager:
|
||||
return newNetworkManagerDbusConfigurator(wgInterface)
|
||||
|
@ -3,11 +3,12 @@ package dns
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -31,7 +32,7 @@ type registryConfigurator struct {
|
||||
routingAll bool
|
||||
}
|
||||
|
||||
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
||||
func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
||||
guid, err := wgInterface.GetInterfaceGUIDString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -39,7 +40,7 @@ func newHostManager(wgInterface WGIface) (hostManager, error) {
|
||||
return newHostManagerWithGuid(guid)
|
||||
}
|
||||
|
||||
func newHostManagerWithGuid(guid string) (hostManager, error) {
|
||||
func newHostManagerWithGuid(guid string) (*registryConfigurator, error) {
|
||||
return ®istryConfigurator{
|
||||
guid: guid,
|
||||
}, nil
|
||||
@ -49,7 +50,7 @@ func (r *registryConfigurator) supportCustomPort() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
||||
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||
var err error
|
||||
if config.RouteAll {
|
||||
err = r.addDNSSetupForAll(config.ServerIP)
|
||||
@ -65,9 +66,8 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
||||
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
||||
}
|
||||
|
||||
// create a file for unclean shutdown detection
|
||||
if err := createUncleanShutdownIndicator(r.guid); err != nil {
|
||||
log.Errorf("failed to create unclean shutdown file: %s", err)
|
||||
if err := stateManager.UpdateState(&ShutdownState{Guid: r.guid}); err != nil {
|
||||
log.Errorf("failed to update shutdown state: %s", err)
|
||||
}
|
||||
|
||||
var (
|
||||
@ -160,10 +160,6 @@ func (r *registryConfigurator) restoreHostDNS() error {
|
||||
return fmt.Errorf("remove interface registry key: %w", err)
|
||||
}
|
||||
|
||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
||||
log.Errorf("failed to remove unclean shutdown file: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -221,7 +217,7 @@ func (r *registryConfigurator) getInterfaceRegistryKey() (registry.Key, error) {
|
||||
return regKey, nil
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
|
||||
func (r *registryConfigurator) restoreUncleanShutdownDNS() error {
|
||||
if err := r.restoreHostDNS(); err != nil {
|
||||
return fmt.Errorf("restoring dns via registry: %w", err)
|
||||
}
|
||||
|
@ -16,6 +16,7 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbversion "github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@ -53,6 +54,7 @@ var supportedNetworkManagerVersionConstraints = []string{
|
||||
type networkManagerDbusConfigurator struct {
|
||||
dbusLinkObject dbus.ObjectPath
|
||||
routingAll bool
|
||||
ifaceName string
|
||||
}
|
||||
|
||||
// the types below are based on dbus specification, each field is mapped to a dbus type
|
||||
@ -77,7 +79,7 @@ func (s networkManagerConnSettings) cleanDeprecatedSettings() {
|
||||
}
|
||||
}
|
||||
|
||||
func newNetworkManagerDbusConfigurator(wgInterface string) (hostManager, error) {
|
||||
func newNetworkManagerDbusConfigurator(wgInterface string) (*networkManagerDbusConfigurator, error) {
|
||||
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get nm dbus: %w", err)
|
||||
@ -93,6 +95,7 @@ func newNetworkManagerDbusConfigurator(wgInterface string) (hostManager, error)
|
||||
|
||||
return &networkManagerDbusConfigurator{
|
||||
dbusLinkObject: dbus.ObjectPath(s),
|
||||
ifaceName: wgInterface,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -100,7 +103,7 @@ func (n *networkManagerDbusConfigurator) supportCustomPort() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
||||
func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||
connSettings, configVersion, err := n.getAppliedConnectionSettings()
|
||||
if err != nil {
|
||||
return fmt.Errorf("retrieving the applied connection settings, error: %w", err)
|
||||
@ -151,10 +154,12 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) er
|
||||
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority)
|
||||
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList)
|
||||
|
||||
// create a backup for unclean shutdown detection before adding domains, as these might end up in the resolv.conf file.
|
||||
// The file content itself is not important for network-manager restoration
|
||||
if err := createUncleanShutdownIndicator(defaultResolvConfPath, networkManager, dnsIP.String()); err != nil {
|
||||
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
|
||||
state := &ShutdownState{
|
||||
ManagerType: networkManager,
|
||||
WgIface: n.ifaceName,
|
||||
}
|
||||
if err := stateManager.UpdateState(state); err != nil {
|
||||
log.Errorf("failed to update shutdown state: %s", err)
|
||||
}
|
||||
|
||||
log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains)
|
||||
@ -171,10 +176,6 @@ func (n *networkManagerDbusConfigurator) restoreHostDNS() error {
|
||||
return fmt.Errorf("delete connection settings: %w", err)
|
||||
}
|
||||
|
||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
||||
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -9,6 +9,8 @@ import (
|
||||
"os/exec"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
const resolvconfCommand = "resolvconf"
|
||||
@ -22,7 +24,7 @@ type resolvconf struct {
|
||||
}
|
||||
|
||||
// supported "openresolv" only
|
||||
func newResolvConfConfigurator(wgInterface string) (hostManager, error) {
|
||||
func newResolvConfConfigurator(wgInterface string) (*resolvconf, error) {
|
||||
resolvConfEntries, err := parseDefaultResolvConf()
|
||||
if err != nil {
|
||||
log.Errorf("could not read original search domains from %s: %s", defaultResolvConfPath, err)
|
||||
@ -40,7 +42,7 @@ func (r *resolvconf) supportCustomPort() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error {
|
||||
func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||
var err error
|
||||
if !config.RouteAll {
|
||||
err = r.restoreHostDNS()
|
||||
@ -60,9 +62,12 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error {
|
||||
append([]string{config.ServerIP}, r.originalNameServers...),
|
||||
options)
|
||||
|
||||
// create a backup for unclean shutdown detection before the resolv.conf is changed
|
||||
if err := createUncleanShutdownIndicator(defaultResolvConfPath, resolvConfManager, config.ServerIP); err != nil {
|
||||
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
|
||||
state := &ShutdownState{
|
||||
ManagerType: resolvConfManager,
|
||||
WgIface: r.ifaceName,
|
||||
}
|
||||
if err := stateManager.UpdateState(state); err != nil {
|
||||
log.Errorf("failed to update shutdown state: %s", err)
|
||||
}
|
||||
|
||||
err = r.applyConfig(buf)
|
||||
@ -79,11 +84,7 @@ func (r *resolvconf) restoreHostDNS() error {
|
||||
cmd := exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName)
|
||||
_, err := cmd.Output()
|
||||
if err != nil {
|
||||
return fmt.Errorf("removing resolvconf configuration for %s interface, error: %w", r.ifaceName, err)
|
||||
}
|
||||
|
||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
||||
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
|
||||
return fmt.Errorf("removing resolvconf configuration for %s interface: %w", r.ifaceName, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -95,7 +96,7 @@ func (r *resolvconf) applyConfig(content bytes.Buffer) error {
|
||||
cmd.Stdin = &content
|
||||
_, err := cmd.Output()
|
||||
if err != nil {
|
||||
return fmt.Errorf("applying resolvconf configuration for %s interface, error: %w", r.ifaceName, err)
|
||||
return fmt.Errorf("applying resolvconf configuration for %s interface: %w", r.ifaceName, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/mitchellh/hashstructure/v2"
|
||||
@ -14,6 +15,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
@ -63,6 +65,7 @@ type DefaultServer struct {
|
||||
iosDnsManager IosDnsManager
|
||||
|
||||
statusRecorder *peer.Status
|
||||
stateManager *statemanager.Manager
|
||||
}
|
||||
|
||||
type handlerWithStop interface {
|
||||
@ -77,12 +80,7 @@ type muxUpdate struct {
|
||||
}
|
||||
|
||||
// NewDefaultServer returns a new dns server
|
||||
func NewDefaultServer(
|
||||
ctx context.Context,
|
||||
wgInterface WGIface,
|
||||
customAddress string,
|
||||
statusRecorder *peer.Status,
|
||||
) (*DefaultServer, error) {
|
||||
func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string, statusRecorder *peer.Status, stateManager *statemanager.Manager) (*DefaultServer, error) {
|
||||
var addrPort *netip.AddrPort
|
||||
if customAddress != "" {
|
||||
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
||||
@ -99,7 +97,7 @@ func NewDefaultServer(
|
||||
dnsService = newServiceViaListener(wgInterface, addrPort)
|
||||
}
|
||||
|
||||
return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder), nil
|
||||
return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager), nil
|
||||
}
|
||||
|
||||
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
|
||||
@ -112,7 +110,7 @@ func NewDefaultServerPermanentUpstream(
|
||||
statusRecorder *peer.Status,
|
||||
) *DefaultServer {
|
||||
log.Debugf("host dns address list is: %v", hostsDnsList)
|
||||
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder)
|
||||
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil)
|
||||
ds.hostsDNSHolder.set(hostsDnsList)
|
||||
ds.permanent = true
|
||||
ds.addHostRootZone()
|
||||
@ -130,12 +128,12 @@ func NewDefaultServerIos(
|
||||
iosDnsManager IosDnsManager,
|
||||
statusRecorder *peer.Status,
|
||||
) *DefaultServer {
|
||||
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder)
|
||||
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil)
|
||||
ds.iosDnsManager = iosDnsManager
|
||||
return ds
|
||||
}
|
||||
|
||||
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status) *DefaultServer {
|
||||
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status, stateManager *statemanager.Manager) *DefaultServer {
|
||||
ctx, stop := context.WithCancel(ctx)
|
||||
defaultServer := &DefaultServer{
|
||||
ctx: ctx,
|
||||
@ -147,6 +145,7 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
|
||||
},
|
||||
wgInterface: wgInterface,
|
||||
statusRecorder: statusRecorder,
|
||||
stateManager: stateManager,
|
||||
hostsDNSHolder: newHostsDNSHolder(),
|
||||
}
|
||||
|
||||
@ -169,6 +168,7 @@ func (s *DefaultServer) Initialize() (err error) {
|
||||
}
|
||||
}
|
||||
|
||||
s.stateManager.RegisterState(&ShutdownState{})
|
||||
s.hostManager, err = s.initialize()
|
||||
if err != nil {
|
||||
return fmt.Errorf("initialize: %w", err)
|
||||
@ -191,9 +191,10 @@ func (s *DefaultServer) Stop() {
|
||||
s.ctxCancel()
|
||||
|
||||
if s.hostManager != nil {
|
||||
err := s.hostManager.restoreHostDNS()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
if err := s.hostManager.restoreHostDNS(); err != nil {
|
||||
log.Error("failed to restore host DNS settings: ", err)
|
||||
} else if err := s.stateManager.DeleteState(&ShutdownState{}); err != nil {
|
||||
log.Errorf("failed to delete shutdown dns state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -318,10 +319,17 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
hostUpdate.RouteAll = false
|
||||
}
|
||||
|
||||
if err = s.hostManager.applyDNSConfig(hostUpdate); err != nil {
|
||||
if err = s.hostManager.applyDNSConfig(hostUpdate, s.stateManager); err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
// persist dns state right away
|
||||
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
if err := s.stateManager.PersistState(ctx); err != nil {
|
||||
log.Errorf("Failed to persist dns state: %v", err)
|
||||
}
|
||||
|
||||
if s.searchDomainNotifier != nil {
|
||||
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
|
||||
}
|
||||
@ -521,10 +529,17 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
|
||||
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
|
||||
}
|
||||
|
||||
// persist dns state right away
|
||||
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
if err := s.stateManager.PersistState(ctx); err != nil {
|
||||
l.Errorf("Failed to persist dns state: %v", err)
|
||||
}
|
||||
|
||||
if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 {
|
||||
s.addHostRootZone()
|
||||
}
|
||||
@ -551,7 +566,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
s.currentConfig.RouteAll = true
|
||||
s.service.RegisterMux(nbdns.RootZone, handler)
|
||||
}
|
||||
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
|
||||
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
||||
}
|
||||
|
||||
|
@ -20,6 +20,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
@ -267,7 +268,17 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: fmt.Sprintf("utun230%d", n),
|
||||
Address: fmt.Sprintf("100.66.100.%d/32", n+1),
|
||||
WGPort: 33100,
|
||||
WGPrivKey: privKey.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
|
||||
wgIface, err := iface.NewWGIFace(opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -281,7 +292,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
t.Log(err)
|
||||
}
|
||||
}()
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{})
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -345,7 +356,15 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
}
|
||||
|
||||
privKey, _ := wgtypes.GeneratePrivateKey()
|
||||
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: "utun2301",
|
||||
Address: "100.66.100.1/32",
|
||||
WGPort: 33100,
|
||||
WGPrivKey: privKey.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
wgIface, err := iface.NewWGIFace(opts)
|
||||
if err != nil {
|
||||
t.Errorf("build interface wireguard: %v", err)
|
||||
return
|
||||
@ -382,7 +401,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{})
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil)
|
||||
if err != nil {
|
||||
t.Errorf("create DNS server: %v", err)
|
||||
return
|
||||
@ -477,7 +496,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{})
|
||||
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("%v", err)
|
||||
}
|
||||
@ -536,6 +555,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
||||
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||
hostManager := &mockHostConfigurator{}
|
||||
server := DefaultServer{
|
||||
ctx: context.Background(),
|
||||
service: NewServiceViaMemory(&mocWGIface{}),
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
@ -552,7 +572,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||
}
|
||||
|
||||
var domainsUpdate string
|
||||
hostManager.applyDNSConfigFunc = func(config HostDNSConfig) error {
|
||||
hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error {
|
||||
domains := []string{}
|
||||
for _, item := range config.Domains {
|
||||
if item.Disabled {
|
||||
@ -803,7 +823,17 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
||||
}
|
||||
|
||||
privKey, _ := wgtypes.GeneratePrivateKey()
|
||||
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: "utun2301",
|
||||
Address: "100.66.100.2/24",
|
||||
WGPort: 33100,
|
||||
WGPrivKey: privKey.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
|
||||
wgIface, err := iface.NewWGIFace(opts)
|
||||
if err != nil {
|
||||
t.Fatalf("build interface wireguard: %v", err)
|
||||
return nil, err
|
||||
|
@ -1,5 +1,5 @@
|
||||
package dns
|
||||
|
||||
func (s *DefaultServer) initialize() (manager hostManager, err error) {
|
||||
func (s *DefaultServer) initialize() (hostManager, error) {
|
||||
return newHostManager(s.wgInterface)
|
||||
}
|
||||
|
@ -7,7 +7,7 @@ import (
|
||||
|
||||
var errNotImplemented = errors.New("not implemented")
|
||||
|
||||
func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) {
|
||||
func newSystemdDbusConfigurator(string) (restoreHostManager, error) {
|
||||
return nil, fmt.Errorf("systemd dns management: %w on freebsd", errNotImplemented)
|
||||
}
|
||||
|
||||
|
@ -15,6 +15,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
@ -38,6 +39,7 @@ const (
|
||||
type systemdDbusConfigurator struct {
|
||||
dbusLinkObject dbus.ObjectPath
|
||||
routingAll bool
|
||||
ifaceName string
|
||||
}
|
||||
|
||||
// the types below are based on dbus specification, each field is mapped to a dbus type
|
||||
@ -55,7 +57,7 @@ type systemdDbusLinkDomainsInput struct {
|
||||
MatchOnly bool
|
||||
}
|
||||
|
||||
func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) {
|
||||
func newSystemdDbusConfigurator(wgInterface string) (*systemdDbusConfigurator, error) {
|
||||
iface, err := net.InterfaceByName(wgInterface)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get interface: %w", err)
|
||||
@ -77,6 +79,7 @@ func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) {
|
||||
|
||||
return &systemdDbusConfigurator{
|
||||
dbusLinkObject: dbus.ObjectPath(s),
|
||||
ifaceName: wgInterface,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -84,7 +87,7 @@ func (s *systemdDbusConfigurator) supportCustomPort() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
||||
func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||
parsedIP, err := netip.ParseAddr(config.ServerIP)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to parse ip address, error: %w", err)
|
||||
@ -135,10 +138,12 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
||||
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
|
||||
}
|
||||
|
||||
// create a backup for unclean shutdown detection before adding domains, as these might end up in the resolv.conf file.
|
||||
// The file content itself is not important for systemd restoration
|
||||
if err := createUncleanShutdownIndicator(defaultResolvConfPath, systemdManager, parsedIP.String()); err != nil {
|
||||
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
|
||||
state := &ShutdownState{
|
||||
ManagerType: systemdManager,
|
||||
WgIface: s.ifaceName,
|
||||
}
|
||||
if err := stateManager.UpdateState(state); err != nil {
|
||||
log.Errorf("failed to update shutdown state: %s", err)
|
||||
}
|
||||
|
||||
log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains)
|
||||
@ -174,10 +179,6 @@ func (s *systemdDbusConfigurator) restoreHostDNS() error {
|
||||
return fmt.Errorf("unable to revert link configuration, got error: %w", err)
|
||||
}
|
||||
|
||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
||||
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
|
||||
}
|
||||
|
||||
return s.flushCaches()
|
||||
}
|
||||
|
||||
|
@ -1,5 +0,0 @@
|
||||
package dns
|
||||
|
||||
func CheckUncleanShutdown(string) error {
|
||||
return nil
|
||||
}
|
@ -3,57 +3,25 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const fileUncleanShutdownFileLocation = "/var/lib/netbird/unclean_shutdown_dns"
|
||||
type ShutdownState struct {
|
||||
}
|
||||
|
||||
func CheckUncleanShutdown(string) error {
|
||||
if _, err := os.Stat(fileUncleanShutdownFileLocation); err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
// no file -> clean shutdown
|
||||
return nil
|
||||
} else {
|
||||
return fmt.Errorf("state: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Warnf("detected unclean shutdown, file %s exists. Restoring unclean shutdown dns settings.", fileUncleanShutdownFileLocation)
|
||||
func (s *ShutdownState) Name() string {
|
||||
return "dns_state"
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
manager, err := newHostManager()
|
||||
if err != nil {
|
||||
return fmt.Errorf("create host manager: %w", err)
|
||||
}
|
||||
|
||||
if err := manager.restoreUncleanShutdownDNS(nil); err != nil {
|
||||
return fmt.Errorf("restore unclean shutdown backup: %w", err)
|
||||
if err := manager.restoreUncleanShutdownDNS(); err != nil {
|
||||
return fmt.Errorf("restore unclean shutdown dns: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func createUncleanShutdownIndicator() error {
|
||||
dir := filepath.Dir(fileUncleanShutdownFileLocation)
|
||||
if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil {
|
||||
return fmt.Errorf("create dir %s: %w", dir, err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(fileUncleanShutdownFileLocation, nil, 0644); err != nil { //nolint:gosec
|
||||
return fmt.Errorf("create %s: %w", fileUncleanShutdownFileLocation, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeUncleanShutdownIndicator() error {
|
||||
if err := os.Remove(fileUncleanShutdownFileLocation); err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||
return fmt.Errorf("remove %s: %w", fileUncleanShutdownFileLocation, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -1,5 +0,0 @@
|
||||
package dns
|
||||
|
||||
func CheckUncleanShutdown(string) error {
|
||||
return nil
|
||||
}
|
14
client/internal/dns/unclean_shutdown_mobile.go
Normal file
14
client/internal/dns/unclean_shutdown_mobile.go
Normal file
@ -0,0 +1,14 @@
|
||||
//go:build ios || android
|
||||
|
||||
package dns
|
||||
|
||||
type ShutdownState struct {
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Name() string {
|
||||
return "dns_state"
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
return nil
|
||||
}
|
@ -3,66 +3,44 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
func CheckUncleanShutdown(wgIface string) error {
|
||||
if _, err := os.Stat(fileUncleanShutdownResolvConfLocation); err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
// no file -> clean shutdown
|
||||
return nil
|
||||
} else {
|
||||
return fmt.Errorf("state: %w", err)
|
||||
}
|
||||
}
|
||||
type ShutdownState struct {
|
||||
ManagerType osManagerType
|
||||
DNSAddress netip.Addr
|
||||
WgIface string
|
||||
}
|
||||
|
||||
log.Warnf("detected unclean shutdown, file %s exists", fileUncleanShutdownResolvConfLocation)
|
||||
func (s *ShutdownState) Name() string {
|
||||
return "dns_state"
|
||||
}
|
||||
|
||||
managerData, err := os.ReadFile(fileUncleanShutdownManagerTypeLocation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read %s: %w", fileUncleanShutdownManagerTypeLocation, err)
|
||||
}
|
||||
|
||||
managerFields := strings.Split(string(managerData), ",")
|
||||
if len(managerFields) < 2 {
|
||||
return errors.New("split manager data: insufficient number of fields")
|
||||
}
|
||||
osManagerTypeStr, dnsAddressStr := managerFields[0], managerFields[1]
|
||||
|
||||
dnsAddress, err := netip.ParseAddr(dnsAddressStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse dns address %s failed: %w", dnsAddressStr, err)
|
||||
}
|
||||
|
||||
log.Warnf("restoring unclean shutdown dns settings via previously detected manager: %s", osManagerTypeStr)
|
||||
|
||||
// determine os manager type, so we can invoke the respective restore action
|
||||
osManagerType, err := newOsManagerType(osManagerTypeStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("detect previous host manager: %w", err)
|
||||
}
|
||||
|
||||
manager, err := newHostManagerFromType(wgIface, osManagerType)
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
manager, err := newHostManagerFromType(s.WgIface, s.ManagerType)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create previous host manager: %w", err)
|
||||
}
|
||||
|
||||
if err := manager.restoreUncleanShutdownDNS(&dnsAddress); err != nil {
|
||||
return fmt.Errorf("restore unclean shutdown backup: %w", err)
|
||||
if err := manager.restoreUncleanShutdownDNS(&s.DNSAddress); err != nil {
|
||||
return fmt.Errorf("restore unclean shutdown dns: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func createUncleanShutdownIndicator(sourcePath string, managerType osManagerType, dnsAddress string) error {
|
||||
// TODO: move file contents to state manager
|
||||
func createUncleanShutdownIndicator(sourcePath string, dnsAddressStr string, stateManager *statemanager.Manager) error {
|
||||
dnsAddress, err := netip.ParseAddr(dnsAddressStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse dns address %s: %w", dnsAddressStr, err)
|
||||
}
|
||||
|
||||
dir := filepath.Dir(fileUncleanShutdownResolvConfLocation)
|
||||
if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil {
|
||||
return fmt.Errorf("create dir %s: %w", dir, err)
|
||||
@ -72,20 +50,13 @@ func createUncleanShutdownIndicator(sourcePath string, managerType osManagerType
|
||||
return fmt.Errorf("create %s: %w", sourcePath, err)
|
||||
}
|
||||
|
||||
managerData := fmt.Sprintf("%s,%s", managerType, dnsAddress)
|
||||
|
||||
if err := os.WriteFile(fileUncleanShutdownManagerTypeLocation, []byte(managerData), 0644); err != nil { //nolint:gosec
|
||||
return fmt.Errorf("create %s: %w", fileUncleanShutdownManagerTypeLocation, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeUncleanShutdownIndicator() error {
|
||||
if err := os.Remove(fileUncleanShutdownResolvConfLocation); err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||
return fmt.Errorf("remove %s: %w", fileUncleanShutdownResolvConfLocation, err)
|
||||
}
|
||||
if err := os.Remove(fileUncleanShutdownManagerTypeLocation); err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||
return fmt.Errorf("remove %s: %w", fileUncleanShutdownManagerTypeLocation, err)
|
||||
state := &ShutdownState{
|
||||
ManagerType: fileManager,
|
||||
DNSAddress: dnsAddress,
|
||||
}
|
||||
if err := stateManager.UpdateState(state); err != nil {
|
||||
return fmt.Errorf("update state: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -1,75 +1,26 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
netbirdProgramDataLocation = "Netbird"
|
||||
fileUncleanShutdownFile = "unclean_shutdown_dns.txt"
|
||||
)
|
||||
type ShutdownState struct {
|
||||
Guid string
|
||||
}
|
||||
|
||||
func CheckUncleanShutdown(string) error {
|
||||
file := getUncleanShutdownFile()
|
||||
func (s *ShutdownState) Name() string {
|
||||
return "dns_state"
|
||||
}
|
||||
|
||||
if _, err := os.Stat(file); err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
// no file -> clean shutdown
|
||||
return nil
|
||||
} else {
|
||||
return fmt.Errorf("state: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
logrus.Warnf("detected unclean shutdown, file %s exists. Restoring unclean shutdown dns settings.", file)
|
||||
|
||||
guid, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read %s: %w", file, err)
|
||||
}
|
||||
|
||||
manager, err := newHostManagerWithGuid(string(guid))
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
manager, err := newHostManagerWithGuid(s.Guid)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create host manager: %w", err)
|
||||
}
|
||||
|
||||
if err := manager.restoreUncleanShutdownDNS(nil); err != nil {
|
||||
return fmt.Errorf("restore unclean shutdown backup: %w", err)
|
||||
if err := manager.restoreUncleanShutdownDNS(); err != nil {
|
||||
return fmt.Errorf("restore unclean shutdown dns: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func createUncleanShutdownIndicator(guid string) error {
|
||||
file := getUncleanShutdownFile()
|
||||
|
||||
dir := filepath.Dir(file)
|
||||
if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil {
|
||||
return fmt.Errorf("create dir %s: %w", dir, err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(file, []byte(guid), 0600); err != nil {
|
||||
return fmt.Errorf("create %s: %w", file, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeUncleanShutdownIndicator() error {
|
||||
file := getUncleanShutdownFile()
|
||||
|
||||
if err := os.Remove(file); err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||
return fmt.Errorf("remove %s: %w", file, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getUncleanShutdownFile() string {
|
||||
return filepath.Join(os.Getenv("PROGRAMDATA"), netbirdProgramDataLocation, fileUncleanShutdownFile)
|
||||
}
|
||||
|
@ -23,19 +23,21 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/internal/acl"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/relay"
|
||||
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/wgproxy"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
@ -141,8 +143,7 @@ type Engine struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
wgInterface iface.IWGIface
|
||||
wgProxyFactory *wgproxy.Factory
|
||||
wgInterface iface.IWGIface
|
||||
|
||||
udpMux *bind.UniversalUDPMuxDefault
|
||||
|
||||
@ -168,6 +169,8 @@ type Engine struct {
|
||||
checks []*mgmProto.Checks
|
||||
|
||||
relayManager *relayClient.Manager
|
||||
stateManager *statemanager.Manager
|
||||
srWatcher *guard.SRWatcher
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
@ -215,7 +218,7 @@ func NewEngineWithProbes(
|
||||
probes *ProbeHolder,
|
||||
checks []*mgmProto.Checks,
|
||||
) *Engine {
|
||||
return &Engine{
|
||||
engine := &Engine{
|
||||
clientCtx: clientCtx,
|
||||
clientCancel: clientCancel,
|
||||
signal: signalClient,
|
||||
@ -234,6 +237,11 @@ func NewEngineWithProbes(
|
||||
probes: probes,
|
||||
checks: checks,
|
||||
}
|
||||
if path := statemanager.GetDefaultStatePath(); path != "" {
|
||||
engine.stateManager = statemanager.New(path)
|
||||
}
|
||||
|
||||
return engine
|
||||
}
|
||||
|
||||
func (e *Engine) Stop() error {
|
||||
@ -255,7 +263,11 @@ func (e *Engine) Stop() error {
|
||||
e.stopDNSServer()
|
||||
|
||||
if e.routeManager != nil {
|
||||
e.routeManager.Stop()
|
||||
e.routeManager.Stop(e.stateManager)
|
||||
}
|
||||
|
||||
if e.srWatcher != nil {
|
||||
e.srWatcher.Close()
|
||||
}
|
||||
|
||||
err := e.removeAllPeers()
|
||||
@ -277,6 +289,17 @@ func (e *Engine) Stop() error {
|
||||
|
||||
e.close()
|
||||
log.Infof("stopped Netbird Engine")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := e.stateManager.Stop(ctx); err != nil {
|
||||
return fmt.Errorf("failed to stop state manager: %w", err)
|
||||
}
|
||||
if err := e.stateManager.PersistState(ctx); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -299,9 +322,6 @@ func (e *Engine) Start() error {
|
||||
}
|
||||
e.wgInterface = wgIface
|
||||
|
||||
userspace := e.wgInterface.IsUserspaceBind()
|
||||
e.wgProxyFactory = wgproxy.NewFactory(userspace, e.config.WgPort)
|
||||
|
||||
if e.config.RosenpassEnabled {
|
||||
log.Infof("rosenpass is enabled")
|
||||
if e.config.RosenpassPermissive {
|
||||
@ -319,6 +339,8 @@ func (e *Engine) Start() error {
|
||||
}
|
||||
}
|
||||
|
||||
e.stateManager.Start()
|
||||
|
||||
initialRoutes, dnsServer, err := e.newDnsServer()
|
||||
if err != nil {
|
||||
e.close()
|
||||
@ -327,7 +349,7 @@ func (e *Engine) Start() error {
|
||||
e.dnsServer = dnsServer
|
||||
|
||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes)
|
||||
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
|
||||
beforePeerHook, afterPeerHook, err := e.routeManager.Init(e.stateManager)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to initialize route manager: %s", err)
|
||||
} else {
|
||||
@ -344,7 +366,7 @@ func (e *Engine) Start() error {
|
||||
return fmt.Errorf("create wg interface: %w", err)
|
||||
}
|
||||
|
||||
e.firewall, err = firewall.NewFirewall(e.ctx, e.wgInterface)
|
||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager)
|
||||
if err != nil {
|
||||
log.Errorf("failed creating firewall manager: %s", err)
|
||||
}
|
||||
@ -374,6 +396,18 @@ func (e *Engine) Start() error {
|
||||
return fmt.Errorf("initialize dns server: %w", err)
|
||||
}
|
||||
|
||||
iceCfg := icemaker.Config{
|
||||
StunTurn: &e.stunTurn,
|
||||
InterfaceBlackList: e.config.IFaceBlackList,
|
||||
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
||||
UDPMux: e.udpMux.UDPMuxDefault,
|
||||
UDPMuxSrflx: e.udpMux,
|
||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||
}
|
||||
|
||||
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
|
||||
e.srWatcher.Start()
|
||||
|
||||
e.receiveSignalEvents()
|
||||
e.receiveManagementEvents()
|
||||
e.receiveProbeEvents()
|
||||
@ -956,7 +990,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
|
||||
LocalWgPort: e.config.WgPort,
|
||||
RosenpassPubKey: e.getRosenpassPubKey(),
|
||||
RosenpassAddr: e.getRosenpassAddr(),
|
||||
ICEConfig: peer.ICEConfig{
|
||||
ICEConfig: icemaker.Config{
|
||||
StunTurn: &e.stunTurn,
|
||||
InterfaceBlackList: e.config.IFaceBlackList,
|
||||
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
||||
@ -966,7 +1000,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
|
||||
},
|
||||
}
|
||||
|
||||
peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.wgProxyFactory, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager)
|
||||
peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -1117,12 +1151,6 @@ func (e *Engine) parseNATExternalIPMappings() []string {
|
||||
}
|
||||
|
||||
func (e *Engine) close() {
|
||||
if e.wgProxyFactory != nil {
|
||||
if err := e.wgProxyFactory.Free(); err != nil {
|
||||
log.Errorf("failed closing ebpf proxy: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
||||
if e.wgInterface != nil {
|
||||
if err := e.wgInterface.Close(); err != nil {
|
||||
@ -1139,7 +1167,7 @@ func (e *Engine) close() {
|
||||
}
|
||||
|
||||
if e.firewall != nil {
|
||||
err := e.firewall.Reset()
|
||||
err := e.firewall.Reset(e.stateManager)
|
||||
if err != nil {
|
||||
log.Warnf("failed to reset firewall: %s", err)
|
||||
}
|
||||
@ -1167,21 +1195,29 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
|
||||
log.Errorf("failed to create pion's stdnet: %s", err)
|
||||
}
|
||||
|
||||
var mArgs *device.MobileIFaceArguments
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: e.config.WgIfaceName,
|
||||
Address: e.config.WgAddr,
|
||||
WGPort: e.config.WgPort,
|
||||
WGPrivKey: e.config.WgPrivateKey.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: transportNet,
|
||||
FilterFn: e.addrViaRoutes,
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "android":
|
||||
mArgs = &device.MobileIFaceArguments{
|
||||
opts.MobileArgs = &device.MobileIFaceArguments{
|
||||
TunAdapter: e.mobileDep.TunAdapter,
|
||||
TunFd: int(e.mobileDep.FileDescriptor),
|
||||
}
|
||||
case "ios":
|
||||
mArgs = &device.MobileIFaceArguments{
|
||||
opts.MobileArgs = &device.MobileIFaceArguments{
|
||||
TunFd: int(e.mobileDep.FileDescriptor),
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
return iface.NewWGIFace(e.config.WgIfaceName, e.config.WgAddr, e.config.WgPort, e.config.WgPrivateKey.String(), iface.DefaultMTU, transportNet, mArgs, e.addrViaRoutes)
|
||||
return iface.NewWGIFace(opts)
|
||||
}
|
||||
|
||||
func (e *Engine) wgInterfaceCreate() (err error) {
|
||||
@ -1222,10 +1258,11 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
|
||||
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder)
|
||||
return nil, dnsServer, nil
|
||||
default:
|
||||
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder)
|
||||
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return nil, dnsServer, nil
|
||||
}
|
||||
}
|
||||
|
@ -29,6 +29,8 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
@ -258,6 +260,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
}
|
||||
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
|
||||
engine.ctx = ctx
|
||||
engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{})
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
@ -602,7 +605,16 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: wgIfaceName,
|
||||
Address: wgAddr,
|
||||
WGPort: engine.config.WgPort,
|
||||
WGPrivKey: key.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
engine.wgInterface, err = iface.NewWGIFace(opts)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
input := struct {
|
||||
inputSerial uint64
|
||||
@ -774,7 +786,15 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, 33100, key.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: wgIfaceName,
|
||||
Address: wgAddr,
|
||||
WGPort: 33100,
|
||||
WGPrivKey: key.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
engine.wgInterface, err = iface.NewWGIFace(opts)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
|
||||
mockRouteManager := &routemanager.MockManager{
|
||||
|
@ -10,15 +10,16 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/pion/ice/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/internal/wgproxy"
|
||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
@ -32,8 +33,6 @@ const (
|
||||
connPriorityRelay ConnPriority = 1
|
||||
connPriorityICETurn ConnPriority = 1
|
||||
connPriorityICEP2P ConnPriority = 2
|
||||
|
||||
reconnectMaxElapsedTime = 30 * time.Minute
|
||||
)
|
||||
|
||||
type WgConfig struct {
|
||||
@ -63,7 +62,7 @@ type ConnConfig struct {
|
||||
RosenpassAddr string
|
||||
|
||||
// ICEConfig ICE protocol configuration
|
||||
ICEConfig ICEConfig
|
||||
ICEConfig icemaker.Config
|
||||
}
|
||||
|
||||
type WorkerCallbacks struct {
|
||||
@ -81,11 +80,10 @@ type Conn struct {
|
||||
ctxCancel context.CancelFunc
|
||||
config ConnConfig
|
||||
statusRecorder *Status
|
||||
wgProxyFactory *wgproxy.Factory
|
||||
signaler *Signaler
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
relayManager *relayClient.Manager
|
||||
allowedIPsIP string
|
||||
allowedIP net.IP
|
||||
allowedNet string
|
||||
handshaker *Handshaker
|
||||
|
||||
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
|
||||
@ -107,17 +105,13 @@ type Conn struct {
|
||||
wgProxyICE wgproxy.Proxy
|
||||
wgProxyRelay wgproxy.Proxy
|
||||
|
||||
// for reconnection operations
|
||||
iCEDisconnected chan bool
|
||||
relayDisconnected chan bool
|
||||
connMonitor *ConnMonitor
|
||||
reconnectCh <-chan struct{}
|
||||
guard *guard.Guard
|
||||
}
|
||||
|
||||
// NewConn creates a new not opened Conn to the remote peer.
|
||||
// To establish a connection run Conn.Open
|
||||
func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.Factory, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager) (*Conn, error) {
|
||||
_, allowedIPsIP, err := net.ParseCIDR(config.WgConfig.AllowedIps)
|
||||
func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher) (*Conn, error) {
|
||||
allowedIP, allowedNet, err := net.ParseCIDR(config.WgConfig.AllowedIps)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse allowedIPS: %v", err)
|
||||
return nil, err
|
||||
@ -132,26 +126,14 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
|
||||
ctxCancel: ctxCancel,
|
||||
config: config,
|
||||
statusRecorder: statusRecorder,
|
||||
wgProxyFactory: wgProxyFactory,
|
||||
signaler: signaler,
|
||||
iFaceDiscover: iFaceDiscover,
|
||||
relayManager: relayManager,
|
||||
allowedIPsIP: allowedIPsIP.String(),
|
||||
allowedIP: allowedIP,
|
||||
allowedNet: allowedNet.String(),
|
||||
statusRelay: NewAtomicConnStatus(),
|
||||
statusICE: NewAtomicConnStatus(),
|
||||
|
||||
iCEDisconnected: make(chan bool, 1),
|
||||
relayDisconnected: make(chan bool, 1),
|
||||
}
|
||||
|
||||
conn.connMonitor, conn.reconnectCh = NewConnMonitor(
|
||||
signaler,
|
||||
iFaceDiscover,
|
||||
config,
|
||||
conn.relayDisconnected,
|
||||
conn.iCEDisconnected,
|
||||
)
|
||||
|
||||
rFns := WorkerRelayCallbacks{
|
||||
OnConnReady: conn.relayConnectionIsReady,
|
||||
OnDisconnected: conn.onWorkerRelayStateDisconnected,
|
||||
@ -162,7 +144,8 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
|
||||
OnStatusChanged: conn.onWorkerICEStateDisconnected,
|
||||
}
|
||||
|
||||
conn.workerRelay = NewWorkerRelay(connLog, config, relayManager, rFns)
|
||||
ctrl := isController(config)
|
||||
conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, relayManager, rFns)
|
||||
|
||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||
conn.workerICE, err = NewWorkerICE(ctx, connLog, config, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally, wFns)
|
||||
@ -177,6 +160,8 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
|
||||
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
|
||||
}
|
||||
|
||||
conn.guard = guard.NewGuard(connLog, ctrl, conn.isConnectedOnAllWay, config.Timeout, srWatcher)
|
||||
|
||||
go conn.handshaker.Listen()
|
||||
|
||||
return conn, nil
|
||||
@ -187,6 +172,7 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
|
||||
// be used.
|
||||
func (conn *Conn) Open() {
|
||||
conn.log.Debugf("open connection to peer")
|
||||
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
conn.opened = true
|
||||
@ -203,24 +189,19 @@ func (conn *Conn) Open() {
|
||||
conn.log.Warnf("error while updating the state err: %v", err)
|
||||
}
|
||||
|
||||
go conn.startHandshakeAndReconnect()
|
||||
go conn.startHandshakeAndReconnect(conn.ctx)
|
||||
}
|
||||
|
||||
func (conn *Conn) startHandshakeAndReconnect() {
|
||||
conn.waitInitialRandomSleepTime()
|
||||
func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) {
|
||||
conn.waitInitialRandomSleepTime(ctx)
|
||||
|
||||
err := conn.handshaker.sendOffer()
|
||||
if err != nil {
|
||||
conn.log.Errorf("failed to send initial offer: %v", err)
|
||||
}
|
||||
|
||||
go conn.connMonitor.Start(conn.ctx)
|
||||
|
||||
if conn.workerRelay.IsController() {
|
||||
conn.reconnectLoopWithRetry()
|
||||
} else {
|
||||
conn.reconnectLoopForOnDisconnectedEvent()
|
||||
}
|
||||
go conn.guard.Start(ctx)
|
||||
go conn.listenGuardEvent(ctx)
|
||||
}
|
||||
|
||||
// Close closes this peer Conn issuing a close event to the Conn closeCh
|
||||
@ -319,104 +300,6 @@ func (conn *Conn) GetKey() string {
|
||||
return conn.config.Key
|
||||
}
|
||||
|
||||
func (conn *Conn) reconnectLoopWithRetry() {
|
||||
// Give chance to the peer to establish the initial connection.
|
||||
// With it, we can decrease to send necessary offer
|
||||
select {
|
||||
case <-conn.ctx.Done():
|
||||
return
|
||||
case <-time.After(3 * time.Second):
|
||||
}
|
||||
|
||||
ticker := conn.prepareExponentTicker()
|
||||
defer ticker.Stop()
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
for {
|
||||
select {
|
||||
case t := <-ticker.C:
|
||||
if t.IsZero() {
|
||||
// in case if the ticker has been canceled by context then avoid the temporary loop
|
||||
return
|
||||
}
|
||||
|
||||
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
|
||||
if conn.statusRelay.Get() == StatusDisconnected || conn.statusICE.Get() == StatusDisconnected {
|
||||
conn.log.Tracef("connectivity guard timedout, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE)
|
||||
}
|
||||
} else {
|
||||
if conn.statusICE.Get() == StatusDisconnected {
|
||||
conn.log.Tracef("connectivity guard timedout, ice state: %s", conn.statusICE)
|
||||
}
|
||||
}
|
||||
|
||||
// checks if there is peer connection is established via relay or ice
|
||||
if conn.isConnected() {
|
||||
continue
|
||||
}
|
||||
|
||||
err := conn.handshaker.sendOffer()
|
||||
if err != nil {
|
||||
conn.log.Errorf("failed to do handshake: %v", err)
|
||||
}
|
||||
|
||||
case <-conn.reconnectCh:
|
||||
ticker.Stop()
|
||||
ticker = conn.prepareExponentTicker()
|
||||
|
||||
case <-conn.ctx.Done():
|
||||
conn.log.Debugf("context is done, stop reconnect loop")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) prepareExponentTicker() *backoff.Ticker {
|
||||
bo := backoff.WithContext(&backoff.ExponentialBackOff{
|
||||
InitialInterval: 800 * time.Millisecond,
|
||||
RandomizationFactor: 0.1,
|
||||
Multiplier: 2,
|
||||
MaxInterval: conn.config.Timeout,
|
||||
MaxElapsedTime: reconnectMaxElapsedTime,
|
||||
Stop: backoff.Stop,
|
||||
Clock: backoff.SystemClock,
|
||||
}, conn.ctx)
|
||||
|
||||
ticker := backoff.NewTicker(bo)
|
||||
<-ticker.C // consume the initial tick what is happening right after the ticker has been created
|
||||
|
||||
return ticker
|
||||
}
|
||||
|
||||
// reconnectLoopForOnDisconnectedEvent is used when the peer is not a controller and it should reconnect to the peer
|
||||
// when the connection is lost. It will try to establish a connection only once time if before the connection was established
|
||||
// It track separately the ice and relay connection status. Just because a lover priority connection reestablished it does not
|
||||
// mean that to switch to it. We always force to use the higher priority connection.
|
||||
func (conn *Conn) reconnectLoopForOnDisconnectedEvent() {
|
||||
for {
|
||||
select {
|
||||
case changed := <-conn.relayDisconnected:
|
||||
if !changed {
|
||||
continue
|
||||
}
|
||||
conn.log.Debugf("Relay state changed, try to send new offer")
|
||||
case changed := <-conn.iCEDisconnected:
|
||||
if !changed {
|
||||
continue
|
||||
}
|
||||
conn.log.Debugf("ICE state changed, try to send new offer")
|
||||
case <-conn.ctx.Done():
|
||||
conn.log.Debugf("context is done, stop reconnect loop")
|
||||
return
|
||||
}
|
||||
|
||||
err := conn.handshaker.SendOffer()
|
||||
if err != nil {
|
||||
conn.log.Errorf("failed to do handshake: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
|
||||
func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) {
|
||||
conn.mu.Lock()
|
||||
@ -516,7 +399,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
|
||||
changed := conn.statusICE.Get() != newState && newState != StatusConnecting
|
||||
conn.statusICE.Set(newState)
|
||||
|
||||
conn.notifyReconnectLoopICEDisconnected(changed)
|
||||
conn.guard.SetICEConnDisconnected(changed)
|
||||
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
@ -607,7 +490,7 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
|
||||
|
||||
changed := conn.statusRelay.Get() != StatusDisconnected
|
||||
conn.statusRelay.Set(StatusDisconnected)
|
||||
conn.notifyReconnectLoopRelayDisconnected(changed)
|
||||
conn.guard.SetRelayedConnDisconnected(changed)
|
||||
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
@ -620,6 +503,20 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) listenGuardEvent(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case <-conn.guard.Reconnect:
|
||||
conn.log.Debugf("send offer to peer")
|
||||
if err := conn.handshaker.SendOffer(); err != nil {
|
||||
conn.log.Errorf("failed to send offer: %v", err)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr) error {
|
||||
return conn.config.WgConfig.WgInterface.UpdatePeer(
|
||||
conn.config.WgConfig.RemoteKey,
|
||||
@ -692,11 +589,11 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd
|
||||
}
|
||||
|
||||
if conn.onConnected != nil {
|
||||
conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedIPsIP, remoteRosenpassAddr)
|
||||
conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedNet, remoteRosenpassAddr)
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) waitInitialRandomSleepTime() {
|
||||
func (conn *Conn) waitInitialRandomSleepTime(ctx context.Context) {
|
||||
minWait := 100
|
||||
maxWait := 800
|
||||
duration := time.Duration(rand.Intn(maxWait-minWait)+minWait) * time.Millisecond
|
||||
@ -705,7 +602,7 @@ func (conn *Conn) waitInitialRandomSleepTime() {
|
||||
defer timeout.Stop()
|
||||
|
||||
select {
|
||||
case <-conn.ctx.Done():
|
||||
case <-ctx.Done():
|
||||
case <-timeout.C:
|
||||
}
|
||||
}
|
||||
@ -734,11 +631,17 @@ func (conn *Conn) evalStatus() ConnStatus {
|
||||
return StatusDisconnected
|
||||
}
|
||||
|
||||
func (conn *Conn) isConnected() bool {
|
||||
func (conn *Conn) isConnectedOnAllWay() (connected bool) {
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
if conn.statusICE.Get() != StatusConnected && conn.statusICE.Get() != StatusConnecting {
|
||||
defer func() {
|
||||
if !connected {
|
||||
conn.logTraceConnState()
|
||||
}
|
||||
}()
|
||||
|
||||
if conn.statusICE.Get() == StatusDisconnected {
|
||||
return false
|
||||
}
|
||||
|
||||
@ -783,8 +686,13 @@ func (conn *Conn) freeUpConnID() {
|
||||
|
||||
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
|
||||
conn.log.Debugf("setup proxied WireGuard connection")
|
||||
wgProxy := conn.wgProxyFactory.GetProxy()
|
||||
if err := wgProxy.AddTurnConn(conn.ctx, remoteConn); err != nil {
|
||||
udpAddr := &net.UDPAddr{
|
||||
IP: conn.allowedIP,
|
||||
Port: conn.config.WgConfig.WgListenPort,
|
||||
}
|
||||
|
||||
wgProxy := conn.config.WgConfig.WgInterface.GetProxy()
|
||||
if err := wgProxy.AddTurnConn(conn.ctx, udpAddr, remoteConn); err != nil {
|
||||
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
@ -803,20 +711,6 @@ func (conn *Conn) removeWgPeer() error {
|
||||
return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
||||
}
|
||||
|
||||
func (conn *Conn) notifyReconnectLoopRelayDisconnected(changed bool) {
|
||||
select {
|
||||
case conn.relayDisconnected <- changed:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) notifyReconnectLoopICEDisconnected(changed bool) {
|
||||
select {
|
||||
case conn.iCEDisconnected <- changed:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
|
||||
conn.log.Warnf("Failed to update wg peer configuration: %v", err)
|
||||
if wgProxy != nil {
|
||||
@ -829,6 +723,18 @@ func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) logTraceConnState() {
|
||||
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
|
||||
conn.log.Tracef("connectivity guard check, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE)
|
||||
} else {
|
||||
conn.log.Tracef("connectivity guard check, ice state: %s", conn.statusICE)
|
||||
}
|
||||
}
|
||||
|
||||
func isController(config ConnConfig) bool {
|
||||
return config.LocalKey > config.Key
|
||||
}
|
||||
|
||||
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
|
||||
return remoteRosenpassPubKey != nil
|
||||
}
|
||||
|
@ -1,212 +0,0 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pion/ice/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
const (
|
||||
signalerMonitorPeriod = 5 * time.Second
|
||||
candidatesMonitorPeriod = 5 * time.Minute
|
||||
candidateGatheringTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
type ConnMonitor struct {
|
||||
signaler *Signaler
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
config ConnConfig
|
||||
relayDisconnected chan bool
|
||||
iCEDisconnected chan bool
|
||||
reconnectCh chan struct{}
|
||||
currentCandidates []ice.Candidate
|
||||
candidatesMu sync.Mutex
|
||||
}
|
||||
|
||||
func NewConnMonitor(signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, config ConnConfig, relayDisconnected, iCEDisconnected chan bool) (*ConnMonitor, <-chan struct{}) {
|
||||
reconnectCh := make(chan struct{}, 1)
|
||||
cm := &ConnMonitor{
|
||||
signaler: signaler,
|
||||
iFaceDiscover: iFaceDiscover,
|
||||
config: config,
|
||||
relayDisconnected: relayDisconnected,
|
||||
iCEDisconnected: iCEDisconnected,
|
||||
reconnectCh: reconnectCh,
|
||||
}
|
||||
return cm, reconnectCh
|
||||
}
|
||||
|
||||
func (cm *ConnMonitor) Start(ctx context.Context) {
|
||||
signalerReady := make(chan struct{}, 1)
|
||||
go cm.monitorSignalerReady(ctx, signalerReady)
|
||||
|
||||
localCandidatesChanged := make(chan struct{}, 1)
|
||||
go cm.monitorLocalCandidatesChanged(ctx, localCandidatesChanged)
|
||||
|
||||
for {
|
||||
select {
|
||||
case changed := <-cm.relayDisconnected:
|
||||
if !changed {
|
||||
continue
|
||||
}
|
||||
log.Debugf("Relay state changed, triggering reconnect")
|
||||
cm.triggerReconnect()
|
||||
|
||||
case changed := <-cm.iCEDisconnected:
|
||||
if !changed {
|
||||
continue
|
||||
}
|
||||
log.Debugf("ICE state changed, triggering reconnect")
|
||||
cm.triggerReconnect()
|
||||
|
||||
case <-signalerReady:
|
||||
log.Debugf("Signaler became ready, triggering reconnect")
|
||||
cm.triggerReconnect()
|
||||
|
||||
case <-localCandidatesChanged:
|
||||
log.Debugf("Local candidates changed, triggering reconnect")
|
||||
cm.triggerReconnect()
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *ConnMonitor) monitorSignalerReady(ctx context.Context, signalerReady chan<- struct{}) {
|
||||
if cm.signaler == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(signalerMonitorPeriod)
|
||||
defer ticker.Stop()
|
||||
|
||||
lastReady := true
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
currentReady := cm.signaler.Ready()
|
||||
if !lastReady && currentReady {
|
||||
select {
|
||||
case signalerReady <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
lastReady = currentReady
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *ConnMonitor) monitorLocalCandidatesChanged(ctx context.Context, localCandidatesChanged chan<- struct{}) {
|
||||
ufrag, pwd, err := generateICECredentials()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to generate ICE credentials: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(candidatesMonitorPeriod)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := cm.handleCandidateTick(ctx, localCandidatesChanged, ufrag, pwd); err != nil {
|
||||
log.Warnf("Failed to handle candidate tick: %v", err)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *ConnMonitor) handleCandidateTick(ctx context.Context, localCandidatesChanged chan<- struct{}, ufrag string, pwd string) error {
|
||||
log.Debugf("Gathering ICE candidates")
|
||||
|
||||
transportNet, err := newStdNet(cm.iFaceDiscover, cm.config.ICEConfig.InterfaceBlackList)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create pion's stdnet: %s", err)
|
||||
}
|
||||
|
||||
agent, err := newAgent(cm.config, transportNet, candidateTypesP2P(), ufrag, pwd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create ICE agent: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := agent.Close(); err != nil {
|
||||
log.Warnf("Failed to close ICE agent: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
gatherDone := make(chan struct{})
|
||||
err = agent.OnCandidate(func(c ice.Candidate) {
|
||||
log.Tracef("Got candidate: %v", c)
|
||||
if c == nil {
|
||||
close(gatherDone)
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("set ICE candidate handler: %w", err)
|
||||
}
|
||||
|
||||
if err := agent.GatherCandidates(); err != nil {
|
||||
return fmt.Errorf("gather ICE candidates: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, candidateGatheringTimeout)
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("wait for gathering: %w", ctx.Err())
|
||||
case <-gatherDone:
|
||||
}
|
||||
|
||||
candidates, err := agent.GetLocalCandidates()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get local candidates: %w", err)
|
||||
}
|
||||
log.Tracef("Got candidates: %v", candidates)
|
||||
|
||||
if changed := cm.updateCandidates(candidates); changed {
|
||||
select {
|
||||
case localCandidatesChanged <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cm *ConnMonitor) updateCandidates(newCandidates []ice.Candidate) bool {
|
||||
cm.candidatesMu.Lock()
|
||||
defer cm.candidatesMu.Unlock()
|
||||
|
||||
if len(cm.currentCandidates) != len(newCandidates) {
|
||||
cm.currentCandidates = newCandidates
|
||||
return true
|
||||
}
|
||||
|
||||
for i, candidate := range cm.currentCandidates {
|
||||
if candidate.Address() != newCandidates[i].Address() {
|
||||
cm.currentCandidates = newCandidates
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (cm *ConnMonitor) triggerReconnect() {
|
||||
select {
|
||||
case cm.reconnectCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
@ -10,8 +10,9 @@ import (
|
||||
"github.com/magiconair/properties/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/internal/wgproxy"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@ -20,7 +21,7 @@ var connConf = ConnConfig{
|
||||
LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
Timeout: time.Second,
|
||||
LocalWgPort: 51820,
|
||||
ICEConfig: ICEConfig{
|
||||
ICEConfig: ice.Config{
|
||||
InterfaceBlackList: nil,
|
||||
},
|
||||
}
|
||||
@ -44,11 +45,8 @@ func TestNewConn_interfaceFilter(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_GetKey(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
conn, err := NewConn(context.Background(), connConf, nil, wgProxyFactory, nil, nil, nil)
|
||||
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
||||
conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil, swWatcher)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -59,11 +57,8 @@ func TestConn_GetKey(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil)
|
||||
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
||||
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -96,11 +91,8 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil)
|
||||
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
||||
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -132,11 +124,8 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
func TestConn_Status(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil)
|
||||
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
||||
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
194
client/internal/peer/guard/guard.go
Normal file
194
client/internal/peer/guard/guard.go
Normal file
@ -0,0 +1,194 @@
|
||||
package guard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
reconnectMaxElapsedTime = 30 * time.Minute
|
||||
)
|
||||
|
||||
type isConnectedFunc func() bool
|
||||
|
||||
// Guard is responsible for the reconnection logic.
|
||||
// It will trigger to send an offer to the peer then has connection issues.
|
||||
// Watch these events:
|
||||
// - Relay client reconnected to home server
|
||||
// - Signal server connection state changed
|
||||
// - ICE connection disconnected
|
||||
// - Relayed connection disconnected
|
||||
// - ICE candidate changes
|
||||
type Guard struct {
|
||||
Reconnect chan struct{}
|
||||
log *log.Entry
|
||||
isController bool
|
||||
isConnectedOnAllWay isConnectedFunc
|
||||
timeout time.Duration
|
||||
srWatcher *SRWatcher
|
||||
relayedConnDisconnected chan bool
|
||||
iCEConnDisconnected chan bool
|
||||
}
|
||||
|
||||
func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
|
||||
return &Guard{
|
||||
Reconnect: make(chan struct{}, 1),
|
||||
log: log,
|
||||
isController: isController,
|
||||
isConnectedOnAllWay: isConnectedFn,
|
||||
timeout: timeout,
|
||||
srWatcher: srWatcher,
|
||||
relayedConnDisconnected: make(chan bool, 1),
|
||||
iCEConnDisconnected: make(chan bool, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Guard) Start(ctx context.Context) {
|
||||
if g.isController {
|
||||
g.reconnectLoopWithRetry(ctx)
|
||||
} else {
|
||||
g.listenForDisconnectEvents(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Guard) SetRelayedConnDisconnected(changed bool) {
|
||||
select {
|
||||
case g.relayedConnDisconnected <- changed:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Guard) SetICEConnDisconnected(changed bool) {
|
||||
select {
|
||||
case g.iCEConnDisconnected <- changed:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// reconnectLoopWithRetry periodically check (max 30 min) the connection status.
|
||||
// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported
|
||||
func (g *Guard) reconnectLoopWithRetry(ctx context.Context) {
|
||||
waitForInitialConnectionTry(ctx)
|
||||
|
||||
srReconnectedChan := g.srWatcher.NewListener()
|
||||
defer g.srWatcher.RemoveListener(srReconnectedChan)
|
||||
|
||||
ticker := g.prepareExponentTicker(ctx)
|
||||
defer ticker.Stop()
|
||||
|
||||
tickerChannel := ticker.C
|
||||
|
||||
g.log.Infof("start reconnect loop...")
|
||||
for {
|
||||
select {
|
||||
case t := <-tickerChannel:
|
||||
if t.IsZero() {
|
||||
g.log.Infof("retry timed out, stop periodic offer sending")
|
||||
// after backoff timeout the ticker.C will be closed. We need to a dummy channel to avoid loop
|
||||
tickerChannel = make(<-chan time.Time)
|
||||
continue
|
||||
}
|
||||
|
||||
if !g.isConnectedOnAllWay() {
|
||||
g.triggerOfferSending()
|
||||
}
|
||||
|
||||
case changed := <-g.relayedConnDisconnected:
|
||||
if !changed {
|
||||
continue
|
||||
}
|
||||
g.log.Debugf("Relay connection changed, reset reconnection ticker")
|
||||
ticker.Stop()
|
||||
ticker = g.prepareExponentTicker(ctx)
|
||||
tickerChannel = ticker.C
|
||||
|
||||
case changed := <-g.iCEConnDisconnected:
|
||||
if !changed {
|
||||
continue
|
||||
}
|
||||
g.log.Debugf("ICE connection changed, reset reconnection ticker")
|
||||
ticker.Stop()
|
||||
ticker = g.prepareExponentTicker(ctx)
|
||||
tickerChannel = ticker.C
|
||||
|
||||
case <-srReconnectedChan:
|
||||
g.log.Debugf("has network changes, reset reconnection ticker")
|
||||
ticker.Stop()
|
||||
ticker = g.prepareExponentTicker(ctx)
|
||||
tickerChannel = ticker.C
|
||||
|
||||
case <-ctx.Done():
|
||||
g.log.Debugf("context is done, stop reconnect loop")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// listenForDisconnectEvents is used when the peer is not a controller and it should reconnect to the peer
|
||||
// when the connection is lost. It will try to establish a connection only once time if before the connection was established
|
||||
// It track separately the ice and relay connection status. Just because a lower priority connection reestablished it does not
|
||||
// mean that to switch to it. We always force to use the higher priority connection.
|
||||
func (g *Guard) listenForDisconnectEvents(ctx context.Context) {
|
||||
srReconnectedChan := g.srWatcher.NewListener()
|
||||
defer g.srWatcher.RemoveListener(srReconnectedChan)
|
||||
|
||||
g.log.Infof("start listen for reconnect events...")
|
||||
for {
|
||||
select {
|
||||
case changed := <-g.relayedConnDisconnected:
|
||||
if !changed {
|
||||
continue
|
||||
}
|
||||
g.log.Debugf("Relay connection changed, triggering reconnect")
|
||||
g.triggerOfferSending()
|
||||
case changed := <-g.iCEConnDisconnected:
|
||||
if !changed {
|
||||
continue
|
||||
}
|
||||
g.log.Debugf("ICE state changed, try to send new offer")
|
||||
g.triggerOfferSending()
|
||||
case <-srReconnectedChan:
|
||||
g.triggerOfferSending()
|
||||
case <-ctx.Done():
|
||||
g.log.Debugf("context is done, stop reconnect loop")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
|
||||
bo := backoff.WithContext(&backoff.ExponentialBackOff{
|
||||
InitialInterval: 800 * time.Millisecond,
|
||||
RandomizationFactor: 0.1,
|
||||
Multiplier: 2,
|
||||
MaxInterval: g.timeout,
|
||||
MaxElapsedTime: reconnectMaxElapsedTime,
|
||||
Stop: backoff.Stop,
|
||||
Clock: backoff.SystemClock,
|
||||
}, ctx)
|
||||
|
||||
ticker := backoff.NewTicker(bo)
|
||||
<-ticker.C // consume the initial tick what is happening right after the ticker has been created
|
||||
|
||||
return ticker
|
||||
}
|
||||
|
||||
func (g *Guard) triggerOfferSending() {
|
||||
select {
|
||||
case g.Reconnect <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// Give chance to the peer to establish the initial connection.
|
||||
// With it, we can decrease to send necessary offer
|
||||
func waitForInitialConnectionTry(ctx context.Context) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(3 * time.Second):
|
||||
}
|
||||
}
|
135
client/internal/peer/guard/ice_monitor.go
Normal file
135
client/internal/peer/guard/ice_monitor.go
Normal file
@ -0,0 +1,135 @@
|
||||
package guard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pion/ice/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
const (
|
||||
candidatesMonitorPeriod = 5 * time.Minute
|
||||
candidateGatheringTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
type ICEMonitor struct {
|
||||
ReconnectCh chan struct{}
|
||||
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
iceConfig icemaker.Config
|
||||
|
||||
currentCandidates []ice.Candidate
|
||||
candidatesMu sync.Mutex
|
||||
}
|
||||
|
||||
func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config) *ICEMonitor {
|
||||
cm := &ICEMonitor{
|
||||
ReconnectCh: make(chan struct{}, 1),
|
||||
iFaceDiscover: iFaceDiscover,
|
||||
iceConfig: config,
|
||||
}
|
||||
return cm
|
||||
}
|
||||
|
||||
func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) {
|
||||
ufrag, pwd, err := icemaker.GenerateICECredentials()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to generate ICE credentials: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(candidatesMonitorPeriod)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
changed, err := cm.handleCandidateTick(ctx, ufrag, pwd)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to check ICE changes: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if changed {
|
||||
onChanged()
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *ICEMonitor) handleCandidateTick(ctx context.Context, ufrag string, pwd string) (bool, error) {
|
||||
log.Debugf("Gathering ICE candidates")
|
||||
|
||||
agent, err := icemaker.NewAgent(cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("create ICE agent: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := agent.Close(); err != nil {
|
||||
log.Warnf("Failed to close ICE agent: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
gatherDone := make(chan struct{})
|
||||
err = agent.OnCandidate(func(c ice.Candidate) {
|
||||
log.Tracef("Got candidate: %v", c)
|
||||
if c == nil {
|
||||
close(gatherDone)
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("set ICE candidate handler: %w", err)
|
||||
}
|
||||
|
||||
if err := agent.GatherCandidates(); err != nil {
|
||||
return false, fmt.Errorf("gather ICE candidates: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, candidateGatheringTimeout)
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false, fmt.Errorf("wait for gathering timed out")
|
||||
case <-gatherDone:
|
||||
}
|
||||
|
||||
candidates, err := agent.GetLocalCandidates()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get local candidates: %w", err)
|
||||
}
|
||||
log.Tracef("Got candidates: %v", candidates)
|
||||
|
||||
return cm.updateCandidates(candidates), nil
|
||||
}
|
||||
|
||||
func (cm *ICEMonitor) updateCandidates(newCandidates []ice.Candidate) bool {
|
||||
cm.candidatesMu.Lock()
|
||||
defer cm.candidatesMu.Unlock()
|
||||
|
||||
if len(cm.currentCandidates) != len(newCandidates) {
|
||||
cm.currentCandidates = newCandidates
|
||||
return true
|
||||
}
|
||||
|
||||
for i, candidate := range cm.currentCandidates {
|
||||
if candidate.Address() != newCandidates[i].Address() {
|
||||
cm.currentCandidates = newCandidates
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func candidateTypesP2P() []ice.CandidateType {
|
||||
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive}
|
||||
}
|
119
client/internal/peer/guard/sr_watcher.go
Normal file
119
client/internal/peer/guard/sr_watcher.go
Normal file
@ -0,0 +1,119 @@
|
||||
package guard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
type chNotifier interface {
|
||||
SetOnReconnectedListener(func())
|
||||
Ready() bool
|
||||
}
|
||||
|
||||
type SRWatcher struct {
|
||||
signalClient chNotifier
|
||||
relayManager chNotifier
|
||||
|
||||
listeners map[chan struct{}]struct{}
|
||||
mu sync.Mutex
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
iceConfig ice.Config
|
||||
|
||||
cancelIceMonitor context.CancelFunc
|
||||
}
|
||||
|
||||
// NewSRWatcher creates a new SRWatcher. This watcher will notify the listeners when the ICE candidates change or the
|
||||
// Relay connection is reconnected or the Signal client reconnected.
|
||||
func NewSRWatcher(signalClient chNotifier, relayManager chNotifier, iFaceDiscover stdnet.ExternalIFaceDiscover, iceConfig ice.Config) *SRWatcher {
|
||||
srw := &SRWatcher{
|
||||
signalClient: signalClient,
|
||||
relayManager: relayManager,
|
||||
iFaceDiscover: iFaceDiscover,
|
||||
iceConfig: iceConfig,
|
||||
listeners: make(map[chan struct{}]struct{}),
|
||||
}
|
||||
return srw
|
||||
}
|
||||
|
||||
func (w *SRWatcher) Start() {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
if w.cancelIceMonitor != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
w.cancelIceMonitor = cancel
|
||||
|
||||
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig)
|
||||
go iceMonitor.Start(ctx, w.onICEChanged)
|
||||
w.signalClient.SetOnReconnectedListener(w.onReconnected)
|
||||
w.relayManager.SetOnReconnectedListener(w.onReconnected)
|
||||
|
||||
}
|
||||
|
||||
func (w *SRWatcher) Close() {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
if w.cancelIceMonitor == nil {
|
||||
return
|
||||
}
|
||||
w.cancelIceMonitor()
|
||||
w.signalClient.SetOnReconnectedListener(nil)
|
||||
w.relayManager.SetOnReconnectedListener(nil)
|
||||
}
|
||||
|
||||
func (w *SRWatcher) NewListener() chan struct{} {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
listenerChan := make(chan struct{}, 1)
|
||||
w.listeners[listenerChan] = struct{}{}
|
||||
return listenerChan
|
||||
}
|
||||
|
||||
func (w *SRWatcher) RemoveListener(listenerChan chan struct{}) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
delete(w.listeners, listenerChan)
|
||||
close(listenerChan)
|
||||
}
|
||||
|
||||
func (w *SRWatcher) onICEChanged() {
|
||||
if !w.signalClient.Ready() {
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("network changes detected by ICE agent")
|
||||
w.notify()
|
||||
}
|
||||
|
||||
func (w *SRWatcher) onReconnected() {
|
||||
if !w.signalClient.Ready() {
|
||||
return
|
||||
}
|
||||
if !w.relayManager.Ready() {
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("reconnected to Signal or Relay server")
|
||||
w.notify()
|
||||
}
|
||||
|
||||
func (w *SRWatcher) notify() {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
for listener := range w.listeners {
|
||||
select {
|
||||
case listener <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user