Merge branch 'main' into refactor-get-account-by-token

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga 2024-10-30 22:34:59 +03:00
commit 901d283114
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
184 changed files with 8046 additions and 2469 deletions

View File

@ -49,7 +49,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./...
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./...
test_client_on_docker:
runs-on: ubuntu-20.04
@ -79,9 +79,6 @@ jobs:
- name: check git status
run: git --no-pager diff --exit-code
- name: Generate Iface Test bin
run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./client/iface/
- name: Generate Shared Sock Test bin
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
@ -98,7 +95,7 @@ jobs:
run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal
- name: Generate Peer Test bin
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/...
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/
- run: chmod +x *testing.bin
@ -106,7 +103,7 @@ jobs:
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Iface tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/iface --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/iface-testing.bin -test.timeout 5m -test.parallel 1
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/netbird -v /tmp/cache:/tmp/cache -v /tmp/modcache:/tmp/modcache -w /netbird -e GOCACHE=/tmp/cache -e GOMODCACHE=/tmp/modcache -e CGO_ENABLED=0 golang:1.23-alpine go test -test.timeout 5m -test.parallel 1 ./client/iface/...
- name: Run RouteManager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1

View File

@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif
ignore_words_list: erro,clienta,hastable,iif,groupd
skip: go.mod,go.sum
only_warn: 1
golangci:

View File

@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.0.15"
SIGN_PIPE_VER: "v0.0.16"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"

View File

@ -680,7 +680,7 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
statusEval := false
ipEval := false
nameEval := false
nameEval := true
if statusFilter != "" {
lowerStatusFilter := strings.ToLower(statusFilter)
@ -700,11 +700,13 @@ func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
if len(prefixNamesFilter) > 0 {
for prefixNameFilter := range prefixNamesFilterMap {
if !strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
nameEval = true
if strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
nameEval = false
break
}
}
} else {
nameEval = false
}
return statusEval || ipEval || nameEval

View File

@ -3,7 +3,6 @@
package firewall
import (
"context"
"fmt"
"runtime"
@ -11,10 +10,11 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// NewFirewall creates a firewall manager instance
func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) {
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) {
if !iface.IsUserspaceBind() {
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
}

View File

@ -3,7 +3,7 @@
package firewall
import (
"context"
"errors"
"fmt"
"os"
@ -15,6 +15,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
const (
@ -32,54 +33,65 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type
type FWType int
func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) {
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
// on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall
var fm firewall.Manager
var errFw error
fm, err := createNativeFirewall(iface, stateManager)
if !iface.IsUserspaceBind() {
return fm, err
}
if err != nil {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
}
return createUserspaceFirewall(iface, fm)
}
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
fm, err := createFW(iface)
if err != nil {
return nil, fmt.Errorf("create firewall: %s", err)
}
if err = fm.Init(stateManager); err != nil {
return nil, fmt.Errorf("init firewall: %s", err)
}
return fm, nil
}
func createFW(iface IFaceMapper) (firewall.Manager, error) {
switch check() {
case IPTABLES:
log.Info("creating an iptables firewall manager")
fm, errFw = nbiptables.Create(context, iface)
if errFw != nil {
log.Errorf("failed to create iptables manager: %s", errFw)
}
return nbiptables.Create(iface)
case NFTABLES:
log.Info("creating an nftables firewall manager")
fm, errFw = nbnftables.Create(context, iface)
if errFw != nil {
log.Errorf("failed to create nftables manager: %s", errFw)
}
return nbnftables.Create(iface)
default:
errFw = fmt.Errorf("no firewall manager found")
log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
return nil, errors.New("no firewall manager found")
}
}
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) {
var errUsp error
if fm != nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
} else {
fm, errUsp = uspfilter.Create(iface)
}
if iface.IsUserspaceBind() {
var errUsp error
if errFw == nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
} else {
fm, errUsp = uspfilter.Create(iface)
}
if errUsp != nil {
log.Debugf("failed to create userspace filtering firewall: %s", errUsp)
return nil, errUsp
}
if err := fm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
return fm, nil
if errUsp != nil {
return nil, fmt.Errorf("create userspace firewall: %s", errUsp)
}
if errFw != nil {
return nil, errFw
if err := fm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
return fm, nil
}

View File

@ -11,6 +11,7 @@ import (
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
)
@ -22,6 +23,8 @@ const (
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
)
type aclEntries map[string][][]string
type entry struct {
spec []string
position int
@ -32,9 +35,11 @@ type aclManager struct {
wgIface iFaceMapper
routingFwChainName string
entries map[string][][]string
entries aclEntries
optionalEntries map[string][]entry
ipsetStore *ipsetStore
stateManager *statemanager.Manager
}
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
@ -48,24 +53,30 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi
ipsetStore: newIpsetStore(),
}
err := ipset.Init()
if err != nil {
return nil, fmt.Errorf("failed to init ipset: %w", err)
if err := ipset.Init(); err != nil {
return nil, fmt.Errorf("init ipset: %w", err)
}
return m, nil
}
func (m *aclManager) init(stateManager *statemanager.Manager) error {
m.stateManager = stateManager
m.seedInitialEntries()
m.seedInitialOptionalEntries()
err = m.cleanChains()
if err != nil {
return nil, err
if err := m.cleanChains(); err != nil {
return fmt.Errorf("clean chains: %w", err)
}
err = m.createDefaultChains()
if err != nil {
return nil, err
if err := m.createDefaultChains(); err != nil {
return fmt.Errorf("create default chains: %w", err)
}
return m, nil
m.updateState()
return nil
}
func (m *aclManager) AddPeerFiltering(
@ -146,6 +157,8 @@ func (m *aclManager) AddPeerFiltering(
chain: chain,
}
m.updateState()
return []firewall.Rule{rule}, nil
}
@ -180,15 +193,23 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
}
}
err := m.iptablesClient.Delete(tableName, r.chain, r.specs...)
if err != nil {
log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err)
if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err)
}
return err
m.updateState()
return nil
}
func (m *aclManager) Reset() error {
return m.cleanChains()
if err := m.cleanChains(); err != nil {
return fmt.Errorf("clean chains: %w", err)
}
m.updateState()
return nil
}
// todo write less destructive cleanup mechanism
@ -348,6 +369,32 @@ func (m *aclManager) appendToEntries(chainName string, spec []string) {
m.entries[chainName] = append(m.entries[chainName], spec)
}
func (m *aclManager) updateState() {
if m.stateManager == nil {
return
}
var currentState *ShutdownState
if existing := m.stateManager.GetState(currentState); existing != nil {
if existingState, ok := existing.(*ShutdownState); ok {
currentState = existingState
}
}
if currentState == nil {
currentState = &ShutdownState{}
}
currentState.Lock()
defer currentState.Unlock()
currentState.ACLEntries = m.entries
currentState.ACLIPsetStore = m.ipsetStore
if err := m.stateManager.UpdateState(currentState); err != nil {
log.Errorf("failed to update state: %v", err)
}
}
// filterRuleSpecs returns the specs of a filtering rule
func filterRuleSpecs(
ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string,

View File

@ -8,10 +8,13 @@ import (
"sync"
"github.com/coreos/go-iptables/iptables"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// Manager of iptables firewall
@ -33,10 +36,10 @@ type iFaceMapper interface {
}
// Create iptables firewall manager
func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
func Create(wgIface iFaceMapper) (*Manager, error) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil {
return nil, fmt.Errorf("iptables is not installed in the system or not supported")
return nil, fmt.Errorf("init iptables: %w", err)
}
m := &Manager{
@ -44,20 +47,49 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
ipv4Client: iptablesClient,
}
m.router, err = newRouter(context, iptablesClient, wgIface)
m.router, err = newRouter(iptablesClient, wgIface)
if err != nil {
log.Debugf("failed to initialize route related chains: %s", err)
return nil, err
return nil, fmt.Errorf("create router: %w", err)
}
m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD)
if err != nil {
log.Debugf("failed to initialize ACL manager: %s", err)
return nil, err
return nil, fmt.Errorf("create acl manager: %w", err)
}
return m, nil
}
func (m *Manager) Init(stateManager *statemanager.Manager) error {
state := &ShutdownState{
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(),
},
}
stateManager.RegisterState(state)
if err := stateManager.UpdateState(state); err != nil {
log.Errorf("failed to update state: %v", err)
}
if err := m.router.init(stateManager); err != nil {
return fmt.Errorf("router init: %w", err)
}
if err := m.aclMgr.init(stateManager); err != nil {
// TODO: cleanup router
return fmt.Errorf("acl manager init: %w", err)
}
// persist early to ensure cleanup of chains
if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
return nil
}
// AddPeerFiltering adds a rule to the firewall
//
// Comment will be ignored because some system this feature is not supported
@ -133,20 +165,27 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
}
// Reset firewall to the default state
func (m *Manager) Reset() error {
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
errAcl := m.aclMgr.Reset()
if errAcl != nil {
log.Errorf("failed to clean up ACL rules from firewall: %s", errAcl)
var merr *multierror.Error
if err := m.aclMgr.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
}
errMgr := m.router.Reset()
if errMgr != nil {
log.Errorf("failed to clean up router rules from firewall: %s", errMgr)
return errMgr
if err := m.router.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
}
return errAcl
// attempt to delete state only if all other operations succeeded
if merr == nil {
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete state: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
// AllowNetbird allows netbird interface traffic

View File

@ -1,7 +1,6 @@
package iptables
import (
"context"
"fmt"
"net"
"testing"
@ -56,13 +55,14 @@ func TestIptablesManager(t *testing.T) {
require.NoError(t, err)
// just check on the local interface
manager, err := Create(context.Background(), ifaceMock)
manager, err := Create(ifaceMock)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second)
defer func() {
err := manager.Reset()
err := manager.Reset(nil)
require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second)
@ -122,7 +122,7 @@ func TestIptablesManager(t *testing.T) {
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
require.NoError(t, err, "failed to add rule")
err = manager.Reset()
err = manager.Reset(nil)
require.NoError(t, err, "failed to reset")
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
@ -154,13 +154,14 @@ func TestIptablesManagerIPSet(t *testing.T) {
}
// just check on the local interface
manager, err := Create(context.Background(), mock)
manager, err := Create(mock)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second)
defer func() {
err := manager.Reset()
err := manager.Reset(nil)
require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second)
@ -219,7 +220,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
})
t.Run("reset check", func(t *testing.T) {
err = manager.Reset()
err = manager.Reset(nil)
require.NoError(t, err, "failed to reset")
})
}
@ -251,12 +252,13 @@ func TestIptablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface
manager, err := Create(context.Background(), mock)
manager, err := Create(mock)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second)
defer func() {
err := manager.Reset()
err := manager.Reset(nil)
require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second)

View File

@ -3,7 +3,6 @@
package iptables
import (
"context"
"fmt"
"net/netip"
"strconv"
@ -18,6 +17,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
const (
@ -48,28 +48,31 @@ type routeFilteringRuleParams struct {
SetName string
}
type routeRules map[string][]string
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
type router struct {
ctx context.Context
stop context.CancelFunc
iptablesClient *iptables.IPTables
rules map[string][]string
ipsetCounter *refcounter.Counter[string, []netip.Prefix, struct{}]
rules routeRules
ipsetCounter *ipsetCounter
wgIface iFaceMapper
legacyManagement bool
stateManager *statemanager.Manager
}
func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
ctx, cancel := context.WithCancel(parentCtx)
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
r := &router{
ctx: ctx,
stop: cancel,
iptablesClient: iptablesClient,
rules: make(map[string][]string),
wgIface: wgIface,
}
r.ipsetCounter = refcounter.New(
r.createIpSet,
func(name string, sources []netip.Prefix) (struct{}, error) {
return struct{}{}, r.createIpSet(name, sources)
},
func(name string, _ struct{}) error {
return r.deleteIpSet(name)
},
@ -79,16 +82,23 @@ func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgI
return nil, fmt.Errorf("init ipset: %w", err)
}
err := r.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("cleanup routing rules: %s", err)
return nil, err
return r, nil
}
func (r *router) init(stateManager *statemanager.Manager) error {
r.stateManager = stateManager
if err := r.cleanUpDefaultForwardRules(); err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
err = r.createContainers()
if err != nil {
log.Errorf("create containers for route: %s", err)
if err := r.createContainers(); err != nil {
return fmt.Errorf("create containers: %w", err)
}
return r, err
r.updateState()
return nil
}
func (r *router) AddRouteFiltering(
@ -129,6 +139,8 @@ func (r *router) AddRouteFiltering(
r.rules[string(ruleKey)] = rule
r.updateState()
return ruleKey, nil
}
@ -152,6 +164,8 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
log.Debugf("route rule %s not found", ruleKey)
}
r.updateState()
return nil
}
@ -164,18 +178,18 @@ func (r *router) findSetNameInRule(rule []string) string {
return ""
}
func (r *router) createIpSet(setName string, sources []netip.Prefix) (struct{}, error) {
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
return struct{}{}, fmt.Errorf("create set %s: %w", setName, err)
return fmt.Errorf("create set %s: %w", setName, err)
}
for _, prefix := range sources {
if err := ipset.AddPrefix(setName, prefix); err != nil {
return struct{}{}, fmt.Errorf("add element to set %s: %w", setName, err)
return fmt.Errorf("add element to set %s: %w", setName, err)
}
}
return struct{}{}, nil
return nil
}
func (r *router) deleteIpSet(setName string) error {
@ -206,6 +220,8 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
return fmt.Errorf("add inverse nat rule: %w", err)
}
r.updateState()
return nil
}
@ -223,6 +239,8 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
return fmt.Errorf("remove legacy routing rule: %w", err)
}
r.updateState()
return nil
}
@ -278,8 +296,13 @@ func (r *router) RemoveAllLegacyRouteRules() error {
}
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
} else {
delete(r.rules, k)
}
}
r.updateState()
return nberrors.FormatErrorOrNil(merr)
}
@ -294,6 +317,8 @@ func (r *router) Reset() error {
merr = multierror.Append(merr, err)
}
r.updateState()
return nberrors.FormatErrorOrNil(merr)
}
@ -431,6 +456,32 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
return nil
}
func (r *router) updateState() {
if r.stateManager == nil {
return
}
var currentState *ShutdownState
if existing := r.stateManager.GetState(currentState); existing != nil {
if existingState, ok := existing.(*ShutdownState); ok {
currentState = existingState
}
}
if currentState == nil {
currentState = &ShutdownState{}
}
currentState.Lock()
defer currentState.Unlock()
currentState.RouteRules = r.rules
currentState.RouteIPsetCounter = r.ipsetCounter
if err := r.stateManager.UpdateState(currentState); err != nil {
log.Errorf("failed to update state: %v", err)
}
}
func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string {
intdir := "-i"
lointdir := "-o"

View File

@ -3,7 +3,6 @@
package iptables
import (
"context"
"net/netip"
"os/exec"
"testing"
@ -30,8 +29,9 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client")
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
manager, err := newRouter(iptablesClient, ifaceMock)
require.NoError(t, err, "should return a valid iptables manager")
require.NoError(t, manager.init(nil))
defer func() {
_ = manager.Reset()
@ -74,8 +74,9 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client")
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
manager, err := newRouter(iptablesClient, ifaceMock)
require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil))
defer func() {
err := manager.Reset()
@ -132,8 +133,9 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
t.Run(testCase.Name, func(t *testing.T) {
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
manager, err := newRouter(iptablesClient, ifaceMock)
require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil))
defer func() {
_ = manager.Reset()
}()
@ -183,8 +185,9 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "Failed to create iptables client")
r, err := newRouter(context.Background(), iptablesClient, ifaceMock)
r, err := newRouter(iptablesClient, ifaceMock)
require.NoError(t, err, "Failed to create router manager")
require.NoError(t, r.init(nil))
defer func() {
err := r.Reset()

View File

@ -1,14 +1,16 @@
package iptables
import "encoding/json"
type ipList struct {
ips map[string]struct{}
}
func newIpList(ip string) ipList {
func newIpList(ip string) *ipList {
ips := make(map[string]struct{})
ips[ip] = struct{}{}
return ipList{
return &ipList{
ips: ips,
}
}
@ -17,27 +19,47 @@ func (s *ipList) addIP(ip string) {
s.ips[ip] = struct{}{}
}
// MarshalJSON implements json.Marshaler
func (s *ipList) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
IPs map[string]struct{} `json:"ips"`
}{
IPs: s.ips,
})
}
// UnmarshalJSON implements json.Unmarshaler
func (s *ipList) UnmarshalJSON(data []byte) error {
temp := struct {
IPs map[string]struct{} `json:"ips"`
}{}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
s.ips = temp.IPs
return nil
}
type ipsetStore struct {
ipsets map[string]ipList // ipsetName -> ruleset
ipsets map[string]*ipList
}
func newIpsetStore() *ipsetStore {
return &ipsetStore{
ipsets: make(map[string]ipList),
ipsets: make(map[string]*ipList),
}
}
func (s *ipsetStore) ipset(ipsetName string) (ipList, bool) {
func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) {
r, ok := s.ipsets[ipsetName]
return r, ok
}
func (s *ipsetStore) addIpList(ipsetName string, list ipList) {
func (s *ipsetStore) addIpList(ipsetName string, list *ipList) {
s.ipsets[ipsetName] = list
}
func (s *ipsetStore) deleteIpset(ipsetName string) {
s.ipsets[ipsetName] = ipList{}
delete(s.ipsets, ipsetName)
}
@ -48,3 +70,24 @@ func (s *ipsetStore) ipsetNames() []string {
}
return names
}
// MarshalJSON implements json.Marshaler
func (s *ipsetStore) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
IPSets map[string]*ipList `json:"ipsets"`
}{
IPSets: s.ipsets,
})
}
// UnmarshalJSON implements json.Unmarshaler
func (s *ipsetStore) UnmarshalJSON(data []byte) error {
temp := struct {
IPSets map[string]*ipList `json:"ipsets"`
}{}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
s.ipsets = temp.IPSets
return nil
}

View File

@ -0,0 +1,70 @@
package iptables
import (
"fmt"
"sync"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress iface.WGAddress `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
}
func (i *InterfaceState) Name() string {
return i.NameStr
}
func (i *InterfaceState) Address() device.WGAddress {
return i.WGAddress
}
func (i *InterfaceState) IsUserspaceBind() bool {
return i.UserspaceBind
}
type ShutdownState struct {
sync.Mutex
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
RouteRules routeRules `json:"route_rules,omitempty"`
RouteIPsetCounter *ipsetCounter `json:"route_ipset_counter,omitempty"`
ACLEntries aclEntries `json:"acl_entries,omitempty"`
ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"`
}
func (s *ShutdownState) Name() string {
return "iptables_state"
}
func (s *ShutdownState) Cleanup() error {
ipt, err := Create(s.InterfaceState)
if err != nil {
return fmt.Errorf("create iptables manager: %w", err)
}
if s.RouteRules != nil {
ipt.router.rules = s.RouteRules
}
if s.RouteIPsetCounter != nil {
ipt.router.ipsetCounter.LoadData(s.RouteIPsetCounter)
}
if s.ACLEntries != nil {
ipt.aclMgr.entries = s.ACLEntries
}
if s.ACLIPsetStore != nil {
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
}
if err := ipt.Reset(nil); err != nil {
return fmt.Errorf("reset iptables manager: %w", err)
}
return nil
}

View File

@ -10,6 +10,8 @@ import (
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
const (
@ -52,6 +54,8 @@ const (
// It declares methods which handle actions required by the
// Netbird client for ACL and routing functionality
type Manager interface {
Init(stateManager *statemanager.Manager) error
// AllowNetbird allows netbird interface traffic
AllowNetbird() error
@ -91,7 +95,7 @@ type Manager interface {
SetLegacyManagement(legacy bool) error
// Reset firewall to the default state
Reset() error
Reset(stateManager *statemanager.Manager) error
// Flush the changes to firewall controller
Flush() error

View File

@ -17,7 +17,6 @@ import (
"golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
nbnet "github.com/netbirdio/netbird/util/net"
)
@ -56,13 +55,6 @@ type AclManager struct {
rules map[string]*Rule
}
// iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
}
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) {
// sConn is used for creating sets and adding/removing elements from them
// it's differ then rConn (which does create new conn for each flush operation)
@ -70,10 +62,10 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam
// overloads netlink with high amount of rules ( > 10000)
sConn, err := nftables.New(nftables.AsLasting())
if err != nil {
return nil, err
return nil, fmt.Errorf("create nf conn: %w", err)
}
m := &AclManager{
return &AclManager{
rConn: &nftables.Conn{},
sConn: sConn,
wgIface: wgIface,
@ -82,14 +74,12 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam
ipsetStore: newIpsetStore(),
rules: make(map[string]*Rule),
}
}, nil
}
err = m.createDefaultChains()
if err != nil {
return nil, err
}
return m, nil
func (m *AclManager) init(workTable *nftables.Table) error {
m.workTable = workTable
return m.createDefaultChains()
}
// AddPeerFiltering rule to the firewall

View File

@ -14,6 +14,8 @@ import (
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
const (
@ -24,6 +26,13 @@ const (
chainNameInput = "INPUT"
)
// iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
}
// Manager of iptables firewall
type Manager struct {
mutex sync.Mutex
@ -35,30 +44,68 @@ type Manager struct {
}
// Create nftables firewall manager
func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
func Create(wgIface iFaceMapper) (*Manager, error) {
m := &Manager{
rConn: &nftables.Conn{},
wgIface: wgIface,
}
workTable, err := m.createWorkTable()
if err != nil {
return nil, err
}
workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}
m.router, err = newRouter(context, workTable, wgIface)
var err error
m.router, err = newRouter(workTable, wgIface)
if err != nil {
return nil, err
return nil, fmt.Errorf("create router: %w", err)
}
m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw)
if err != nil {
return nil, err
return nil, fmt.Errorf("create acl manager: %w", err)
}
return m, nil
}
// Init nftables firewall manager
func (m *Manager) Init(stateManager *statemanager.Manager) error {
workTable, err := m.createWorkTable()
if err != nil {
return fmt.Errorf("create work table: %w", err)
}
if err := m.router.init(workTable); err != nil {
return fmt.Errorf("router init: %w", err)
}
if err := m.aclManager.init(workTable); err != nil {
// TODO: cleanup router
return fmt.Errorf("acl manager init: %w", err)
}
stateManager.RegisterState(&ShutdownState{})
// We only need to record minimal interface state for potential recreation.
// Unlike iptables, which requires tracking individual rules, nftables maintains
// a known state (our netbird table plus a few static rules). This allows for easy
// cleanup using Reset() without needing to store specific rules.
if err := stateManager.UpdateState(&ShutdownState{
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(),
},
}); err != nil {
log.Errorf("failed to update state: %v", err)
}
// persist early
if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
return nil
}
// AddPeerFiltering rule to the firewall
//
// If comment argument is empty firewall manager should set
@ -183,68 +230,84 @@ func (m *Manager) AllowNetbird() error {
// SetLegacyManagement sets the route manager to use legacy management
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
oldLegacy := m.router.legacyManagement
return firewall.SetLegacyManagement(m.router, isLegacy)
}
if oldLegacy != isLegacy {
m.router.legacyManagement = isLegacy
log.Debugf("Set legacy management to %v", isLegacy)
// Reset firewall to the default state
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if err := m.resetNetbirdInputRules(); err != nil {
return fmt.Errorf("reset netbird input rules: %v", err)
}
// client reconnected to a newer mgmt, we need to cleanup the legacy rules
if !isLegacy && oldLegacy {
if err := m.router.RemoveAllLegacyRouteRules(); err != nil {
return fmt.Errorf("remove legacy routing rules: %v", err)
}
if err := m.router.Reset(); err != nil {
return fmt.Errorf("reset router: %v", err)
}
log.Debugf("Legacy routing rules removed")
if err := m.cleanupNetbirdTables(); err != nil {
return fmt.Errorf("cleanup netbird tables: %v", err)
}
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
return fmt.Errorf("delete state: %v", err)
}
return nil
}
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
func (m *Manager) resetNetbirdInputRules() error {
chains, err := m.rConn.ListChains()
if err != nil {
return fmt.Errorf("list of chains: %w", err)
return fmt.Errorf("list chains: %w", err)
}
m.deleteNetbirdInputRules(chains)
return nil
}
func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
for _, c := range chains {
// delete Netbird allow input traffic rule if it exists
if c.Table.Name == "filter" && c.Name == "INPUT" {
rules, err := m.rConn.GetRules(c.Table, c)
if err != nil {
log.Errorf("get rules for chain %q: %v", c.Name, err)
continue
}
for _, r := range rules {
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
if err := m.rConn.DelRule(r); err != nil {
log.Errorf("delete rule: %v", err)
}
}
m.deleteMatchingRules(rules)
}
}
}
func (m *Manager) deleteMatchingRules(rules []*nftables.Rule) {
for _, r := range rules {
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
if err := m.rConn.DelRule(r); err != nil {
log.Errorf("delete rule: %v", err)
}
}
}
}
if err := m.router.Reset(); err != nil {
return fmt.Errorf("reset forward rules: %v", err)
}
func (m *Manager) cleanupNetbirdTables() error {
tables, err := m.rConn.ListTables()
if err != nil {
return fmt.Errorf("list of tables: %w", err)
return fmt.Errorf("list tables: %w", err)
}
for _, t := range tables {
if t.Name == tableNameNetbird {
m.rConn.DelTable(t)
}
}
return m.rConn.Flush()
return nil
}
// Flush rule/chain/set operations from the buffer

View File

@ -1,7 +1,6 @@
package nftables
import (
"context"
"fmt"
"net"
"net/netip"
@ -58,12 +57,13 @@ func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestNftablesManager(t *testing.T) {
// just check on the local interface
manager, err := Create(context.Background(), ifaceMock)
manager, err := Create(ifaceMock)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second * 3)
defer func() {
err = manager.Reset()
err = manager.Reset(nil)
require.NoError(t, err, "failed to reset")
time.Sleep(time.Second)
}()
@ -169,7 +169,7 @@ func TestNftablesManager(t *testing.T) {
// established rule remains
require.Len(t, rules, 1, "expected 1 rules after deletion")
err = manager.Reset()
err = manager.Reset(nil)
require.NoError(t, err, "failed to reset")
}
@ -192,12 +192,13 @@ func TestNFtablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface
manager, err := Create(context.Background(), mock)
manager, err := Create(mock)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second * 3)
defer func() {
if err := manager.Reset(); err != nil {
if err := manager.Reset(nil); err != nil {
t.Errorf("clear the manager state: %v", err)
}
time.Sleep(time.Second)

View File

@ -2,7 +2,6 @@ package nftables
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
@ -40,8 +39,6 @@ var (
)
type router struct {
ctx context.Context
stop context.CancelFunc
conn *nftables.Conn
workTable *nftables.Table
filterTable *nftables.Table
@ -54,12 +51,8 @@ type router struct {
legacyManagement bool
}
func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
ctx, cancel := context.WithCancel(parentCtx)
func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
r := &router{
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{},
workTable: workTable,
chains: make(map[string]*nftables.Chain),
@ -78,20 +71,25 @@ func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFa
if errors.Is(err, errFilterTableNotFound) {
log.Warnf("table 'filter' not found for forward rules")
} else {
return nil, err
return nil, fmt.Errorf("load filter table: %w", err)
}
}
err = r.removeAcceptForwardRules()
if err != nil {
return r, nil
}
func (r *router) init(workTable *nftables.Table) error {
r.workTable = workTable
if err := r.removeAcceptForwardRules(); err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
err = r.createContainers()
if err != nil {
log.Errorf("failed to create containers for route: %s", err)
if err := r.createContainers(); err != nil {
return fmt.Errorf("create containers: %w", err)
}
return r, err
return nil
}
// Reset cleans existing nftables default forward rules from the system
@ -553,7 +551,10 @@ func (r *router) RemoveAllLegacyRouteRules() error {
}
if err := r.conn.DelRule(rule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
} else {
delete(r.rules, k)
}
}
return nberrors.FormatErrorOrNil(merr)
}

View File

@ -3,7 +3,6 @@
package nftables
import (
"context"
"encoding/binary"
"net/netip"
"os/exec"
@ -40,8 +39,9 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
for _, testCase := range test.InsertRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
manager, err := newRouter(context.TODO(), table, ifaceMock)
manager, err := newRouter(table, ifaceMock)
require.NoError(t, err, "failed to create router")
require.NoError(t, manager.init(table))
nftablesTestingClient := &nftables.Conn{}
@ -142,8 +142,9 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
for _, testCase := range test.RemoveRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
manager, err := newRouter(context.TODO(), table, ifaceMock)
manager, err := newRouter(table, ifaceMock)
require.NoError(t, err, "failed to create router")
require.NoError(t, manager.init(table))
nftablesTestingClient := &nftables.Conn{}
@ -210,8 +211,9 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
defer deleteWorkTable()
r, err := newRouter(context.Background(), workTable, ifaceMock)
r, err := newRouter(workTable, ifaceMock)
require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable))
defer func(r *router) {
require.NoError(t, r.Reset(), "Failed to reset rules")
@ -376,8 +378,9 @@ func TestNftablesCreateIpSet(t *testing.T) {
defer deleteWorkTable()
r, err := newRouter(context.Background(), workTable, ifaceMock)
r, err := newRouter(workTable, ifaceMock)
require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable))
defer func() {
require.NoError(t, r.Reset(), "Failed to reset router")

View File

@ -0,0 +1 @@
package nftables

View File

@ -0,0 +1,47 @@
package nftables
import (
"fmt"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress iface.WGAddress `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
}
func (i *InterfaceState) Name() string {
return i.NameStr
}
func (i *InterfaceState) Address() device.WGAddress {
return i.WGAddress
}
func (i *InterfaceState) IsUserspaceBind() bool {
return i.UserspaceBind
}
type ShutdownState struct {
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
}
func (s *ShutdownState) Name() string {
return "nftables_state"
}
func (s *ShutdownState) Cleanup() error {
nft, err := Create(s.InterfaceState)
if err != nil {
return fmt.Errorf("create nftables manager: %w", err)
}
if err := nft.Reset(nil); err != nil {
return fmt.Errorf("reset nftables manager: %w", err)
}
return nil
}

View File

@ -2,8 +2,10 @@
package uspfilter
import "github.com/netbirdio/netbird/client/internal/statemanager"
// Reset firewall to the default state
func (m *Manager) Reset() error {
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@ -11,7 +13,7 @@ func (m *Manager) Reset() error {
m.incomingRules = make(map[string]RuleSet)
if m.nativeFirewall != nil {
return m.nativeFirewall.Reset()
return m.nativeFirewall.Reset(stateManager)
}
return nil
}

View File

@ -6,6 +6,8 @@ import (
"syscall"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
type action string
@ -17,7 +19,7 @@ const (
)
// Reset firewall to the default state
func (m *Manager) Reset() error {
func (m *Manager) Reset(*statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()

View File

@ -14,6 +14,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
const layerTypeAll = 0
@ -97,6 +98,10 @@ func create(iface IFaceMapper) (*Manager, error) {
return m, nil
}
func (m *Manager) Init(*statemanager.Manager) error {
return nil
}
func (m *Manager) IsServerRouteSupported() bool {
if m.nativeFirewall == nil {
return false
@ -190,7 +195,7 @@ func (m *Manager) AddPeerFiltering(
return []firewall.Rule{&r}, nil
}
func (m *Manager) AddRouteFiltering(sources [] netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action ) (firewall.Rule, error) {
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
if m.nativeFirewall == nil {
return nil, errRouteNotSupported
}
@ -232,8 +237,11 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
}
// SetLegacyManagement doesn't need to be implemented for this manager
func (m *Manager) SetLegacyManagement(_ bool) error {
return nil
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
if m.nativeFirewall == nil {
return errRouteNotSupported
}
return m.nativeFirewall.SetLegacyManagement(isLegacy)
}
// Flush doesn't need to be implemented for this manager

View File

@ -259,7 +259,7 @@ func TestManagerReset(t *testing.T) {
return
}
err = m.Reset()
err = m.Reset(nil)
if err != nil {
t.Errorf("failed to reset Manager: %v", err)
return
@ -330,7 +330,7 @@ func TestNotMatchByIP(t *testing.T) {
return
}
if err = m.Reset(); err != nil {
if err = m.Reset(nil); err != nil {
t.Errorf("failed to reset Manager: %v", err)
return
}
@ -396,7 +396,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
time.Sleep(time.Second)
defer func() {
if err := manager.Reset(); err != nil {
if err := manager.Reset(nil); err != nil {
t.Errorf("clear the manager state: %v", err)
}
time.Sleep(time.Second)

View File

@ -1,142 +0,0 @@
package bind
import (
"fmt"
"net"
"runtime"
"sync"
"github.com/pion/stun/v2"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.org/x/net/ipv4"
wgConn "golang.zx2c4.com/wireguard/conn"
)
type receiverCreator struct {
iceBind *ICEBind
}
func (rc receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn)
}
type ICEBind struct {
*wgConn.StdNetBind
muUDPMux sync.Mutex
transportNet transport.Net
udpMux *UniversalUDPMuxDefault
filterFn FilterFn
}
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
ib := &ICEBind{
transportNet: transportNet,
filterFn: filterFn,
}
rc := receiverCreator{
ib,
}
ib.StdNetBind = wgConn.NewStdNetBindWithReceiverCreator(rc)
return ib
}
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
if s.udpMux == nil {
return nil, fmt.Errorf("ICEBind has not been initialized yet")
}
return s.udpMux, nil
}
func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{
UDPConn: conn,
Net: s.transportNet,
FilterFn: s.filterFn,
},
)
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
msgs := ipv4MsgsPool.Get().(*[]ipv4.Message)
defer ipv4MsgsPool.Put(msgs)
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
}
var numMsgs int
if runtime.GOOS == "linux" {
numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
// todo: handle err
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
if ok {
sizes[i] = 0
} else {
sizes[i] = msg.N
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
}
return numMsgs, nil
}
}
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
for i := range buffers {
if !stun.IsMessage(buffers[i]) {
continue
}
msg, err := s.parseSTUNMessage(buffers[i][:n])
if err != nil {
buffers[i] = []byte{}
return true, err
}
muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
if muxErr != nil {
log.Warnf("failed to handle STUN packet")
}
buffers[i] = []byte{}
return true, nil
}
return false, nil
}
func (s *ICEBind) parseSTUNMessage(raw []byte) (*stun.Message, error) {
msg := &stun.Message{
Raw: raw,
}
if err := msg.Decode(); err != nil {
return nil, err
}
return msg, nil
}

View File

@ -0,0 +1,5 @@
package bind
import wgConn "golang.zx2c4.com/wireguard/conn"
type Endpoint = wgConn.StdNetEndpoint

View File

@ -0,0 +1,275 @@
package bind
import (
"fmt"
"net"
"net/netip"
"runtime"
"strings"
"sync"
"github.com/pion/stun/v2"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.org/x/net/ipv4"
wgConn "golang.zx2c4.com/wireguard/conn"
)
type RecvMessage struct {
Endpoint *Endpoint
Buffer []byte
}
type receiverCreator struct {
iceBind *ICEBind
}
func (rc receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn)
}
// ICEBind is a bind implementation with two main features:
// 1. filter out STUN messages and handle them
// 2. forward the received packets to the WireGuard interface from the relayed connection
//
// ICEBind.endpoints var is a map that stores the connection for each relayed peer. Fake address is just an IP address
// without port, in the format of 127.1.x.x where x.x is the last two octets of the peer address. We try to avoid to
// use the port because in the Send function the wgConn.Endpoint the port info is not exported.
type ICEBind struct {
*wgConn.StdNetBind
RecvChan chan RecvMessage
transportNet transport.Net
filterFn FilterFn
endpoints map[netip.Addr]net.Conn
endpointsMu sync.Mutex
// every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a
// new closed channel. With the closedChanMu we can safely close the channel and create a new one
closedChan chan struct{}
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
closed bool
muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault
}
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{
StdNetBind: b,
RecvChan: make(chan RecvMessage, 1),
transportNet: transportNet,
filterFn: filterFn,
endpoints: make(map[netip.Addr]net.Conn),
closedChan: make(chan struct{}),
closed: true,
}
rc := receiverCreator{
ib,
}
ib.StdNetBind = wgConn.NewStdNetBindWithReceiverCreator(rc)
return ib
}
func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
s.closed = false
s.closedChanMu.Lock()
s.closedChan = make(chan struct{})
s.closedChanMu.Unlock()
fns, port, err := s.StdNetBind.Open(uport)
if err != nil {
return nil, 0, err
}
fns = append(fns, s.receiveRelayed)
return fns, port, nil
}
func (s *ICEBind) Close() error {
if s.closed {
return nil
}
s.closed = true
close(s.closedChan)
return s.StdNetBind.Close()
}
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
if s.udpMux == nil {
return nil, fmt.Errorf("ICEBind has not been initialized yet")
}
return s.udpMux, nil
}
func (b *ICEBind) SetEndpoint(peerAddress *net.UDPAddr, conn net.Conn) (*net.UDPAddr, error) {
fakeUDPAddr, err := fakeAddress(peerAddress)
if err != nil {
return nil, err
}
// force IPv4
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
if !ok {
return nil, fmt.Errorf("failed to convert IP to netip.Addr")
}
b.endpointsMu.Lock()
b.endpoints[fakeAddr] = conn
b.endpointsMu.Unlock()
return fakeUDPAddr, nil
}
func (b *ICEBind) RemoveEndpoint(fakeUDPAddr *net.UDPAddr) {
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
if !ok {
log.Warnf("failed to convert IP to netip.Addr")
return
}
b.endpointsMu.Lock()
defer b.endpointsMu.Unlock()
delete(b.endpoints, fakeAddr)
}
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
b.endpointsMu.Lock()
conn, ok := b.endpoints[ep.DstIP()]
b.endpointsMu.Unlock()
if !ok {
return b.StdNetBind.Send(bufs, ep)
}
for _, buf := range bufs {
if _, err := conn.Write(buf); err != nil {
return err
}
}
return nil
}
func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{
UDPConn: conn,
Net: s.transportNet,
FilterFn: s.filterFn,
},
)
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
msgs := ipv4MsgsPool.Get().(*[]ipv4.Message)
defer ipv4MsgsPool.Put(msgs)
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
}
var numMsgs int
if runtime.GOOS == "linux" {
numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
// todo: handle err
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
if ok {
sizes[i] = 0
} else {
sizes[i] = msg.N
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
}
return numMsgs, nil
}
}
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
for i := range buffers {
if !stun.IsMessage(buffers[i]) {
continue
}
msg, err := s.parseSTUNMessage(buffers[i][:n])
if err != nil {
buffers[i] = []byte{}
return true, err
}
muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
if muxErr != nil {
log.Warnf("failed to handle STUN packet")
}
buffers[i] = []byte{}
return true, nil
}
return false, nil
}
func (s *ICEBind) parseSTUNMessage(raw []byte) (*stun.Message, error) {
msg := &stun.Message{
Raw: raw,
}
if err := msg.Decode(); err != nil {
return nil, err
}
return msg, nil
}
// receiveRelayed is a receive function that is used to receive packets from the relayed connection and forward to the
// WireGuard. Critical part is do not block if the Closed() has been called.
func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
c.closedChanMu.RLock()
defer c.closedChanMu.RUnlock()
select {
case <-c.closedChan:
return 0, net.ErrClosed
case msg, ok := <-c.RecvChan:
if !ok {
return 0, net.ErrClosed
}
copy(buffs[0], msg.Buffer)
sizes[0] = len(msg.Buffer)
eps[0] = wgConn.Endpoint(msg.Endpoint)
return 1, nil
}
}
// fakeAddress returns a fake address that is used to as an identifier for the peer.
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
octets := strings.Split(peerAddress.IP.String(), ".")
if len(octets) != 4 {
return nil, fmt.Errorf("invalid IP format")
}
newAddr := &net.UDPAddr{
IP: net.ParseIP(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])),
Port: peerAddress.Port,
}
return newAddr, nil
}

View File

@ -5,7 +5,6 @@ package device
import (
"strings"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/device"
@ -31,13 +30,13 @@ type WGTunDevice struct {
configurer WGConfigurer
}
func NewTunDevice(address WGAddress, port int, key string, mtu int, transportNet transport.Net, tunAdapter TunAdapter, filterFn bind.FilterFn) *WGTunDevice {
func NewTunDevice(address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
return &WGTunDevice{
address: address,
port: port,
key: key,
mtu: mtu,
iceBind: bind.NewICEBind(transportNet, filterFn),
iceBind: iceBind,
tunAdapter: tunAdapter,
}
}

View File

@ -6,7 +6,6 @@ import (
"fmt"
"os/exec"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
@ -29,14 +28,14 @@ type TunDevice struct {
configurer WGConfigurer
}
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *TunDevice {
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
return &TunDevice{
name: name,
address: address,
port: port,
key: key,
mtu: mtu,
iceBind: bind.NewICEBind(transportNet, filterFn),
iceBind: iceBind,
}
}

View File

@ -6,7 +6,6 @@ package device
import (
"os"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/device"
@ -30,13 +29,13 @@ type TunDevice struct {
configurer WGConfigurer
}
func NewTunDevice(name string, address WGAddress, port int, key string, transportNet transport.Net, tunFd int, filterFn bind.FilterFn) *TunDevice {
func NewTunDevice(name string, address WGAddress, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice {
return &TunDevice{
name: name,
address: address,
port: port,
key: key,
iceBind: bind.NewICEBind(transportNet, filterFn),
iceBind: iceBind,
tunFd: tunFd,
}
}

View File

@ -6,7 +6,6 @@ package device
import (
"fmt"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
@ -31,7 +30,7 @@ type TunNetstackDevice struct {
configurer WGConfigurer
}
func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string, filterFn bind.FilterFn) *TunNetstackDevice {
func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
return &TunNetstackDevice{
name: name,
address: address,
@ -39,7 +38,7 @@ func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, m
key: key,
mtu: mtu,
listenAddress: listenAddress,
iceBind: bind.NewICEBind(transportNet, filterFn),
iceBind: iceBind,
}
}

View File

@ -7,7 +7,6 @@ import (
"os"
"runtime"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
@ -30,7 +29,7 @@ type USPDevice struct {
configurer WGConfigurer
}
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *USPDevice {
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
log.Infof("using userspace bind mode")
checkUser()
@ -41,7 +40,8 @@ func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int,
port: port,
key: key,
mtu: mtu,
iceBind: bind.NewICEBind(transportNet, filterFn)}
iceBind: iceBind,
}
}
func (t *USPDevice) Create() (WGConfigurer, error) {

View File

@ -4,7 +4,6 @@ import (
"fmt"
"net/netip"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/device"
@ -32,14 +31,14 @@ type TunDevice struct {
configurer WGConfigurer
}
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *TunDevice {
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
return &TunDevice{
name: name,
address: address,
port: port,
key: key,
mtu: mtu,
iceBind: bind.NewICEBind(transportNet, filterFn),
iceBind: iceBind,
}
}

View File

@ -6,12 +6,16 @@ import (
"sync"
"time"
"github.com/hashicorp/go-multierror"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
const (
@ -22,14 +26,35 @@ const (
type WGAddress = device.WGAddress
type wgProxyFactory interface {
GetProxy() wgproxy.Proxy
Free() error
}
type WGIFaceOpts struct {
IFaceName string
Address string
WGPort int
WGPrivKey string
MTU int
MobileArgs *device.MobileIFaceArguments
TransportNet transport.Net
FilterFn bind.FilterFn
}
// WGIface represents an interface instance
type WGIface struct {
tun WGTunDevice
userspaceBind bool
mu sync.Mutex
configurer device.WGConfigurer
filter device.PacketFilter
configurer device.WGConfigurer
filter device.PacketFilter
wgProxyFactory wgProxyFactory
}
func (w *WGIface) GetProxy() wgproxy.Proxy {
return w.wgProxyFactory.GetProxy()
}
// IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind
@ -124,22 +149,26 @@ func (w *WGIface) Close() error {
w.mu.Lock()
defer w.mu.Unlock()
err := w.tun.Close()
if err != nil {
return fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err)
var result *multierror.Error
if err := w.wgProxyFactory.Free(); err != nil {
result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err))
}
err = w.waitUntilRemoved()
if err != nil {
if err := w.tun.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
}
if err := w.waitUntilRemoved(); err != nil {
log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
err = w.Destroy()
if err != nil {
return fmt.Errorf("failed to remove WireGuard interface %s: %w", w.Name(), err)
if err := w.Destroy(); err != nil {
result = multierror.Append(result, fmt.Errorf("failed to remove WireGuard interface %s: %w", w.Name(), err))
return errors.FormatErrorOrNil(result)
}
log.Infof("interface %s successfully removed", w.Name())
}
return nil
return errors.FormatErrorOrNil(result)
}
// SetFilter sets packet filters for the userspace implementation

View File

@ -1,43 +0,0 @@
package iface
import (
"fmt"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{
tun: device.NewTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter, filterFn),
userspaceBind: true,
}
return wgIFace, nil
}
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error {
w.mu.Lock()
defer w.mu.Unlock()
cfgr, err := w.tun.Create(routes, dns, searchDomains)
if err != nil {
return err
}
w.configurer = cfgr
return nil
}
// Create this function make sense on mobile only
func (w *WGIface) Create() error {
return fmt.Errorf("this function has not implemented on this platform")
}

View File

@ -2,6 +2,8 @@
package iface
import "fmt"
// Create creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
// this function is different on Android
@ -17,3 +19,8 @@ func (w *WGIface) Create() error {
w.configurer = cfgr
return nil
}
// CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("this function has not implemented on non mobile")
}

View File

@ -0,0 +1,24 @@
package iface
import (
"fmt"
)
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error {
w.mu.Lock()
defer w.mu.Unlock()
cfgr, err := w.tun.Create(routes, dns, searchDomains)
if err != nil {
return err
}
w.configurer = cfgr
return nil
}
// Create this function make sense on mobile only
func (w *WGIface) Create() error {
return fmt.Errorf("this function has not implemented on this platform")
}

View File

@ -7,39 +7,8 @@ import (
"time"
"github.com/cenkalti/backoff/v4"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, _ *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{
userspaceBind: true,
}
if netstack.IsEnabled() {
wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
return wgIFace, nil
}
wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn)
return wgIFace, nil
}
// CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("this function has not implemented on this platform")
}
// Create creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
// this function is different on Android
@ -65,3 +34,8 @@ func (w *WGIface) Create() error {
return backoff.Retry(operation, backOff)
}
// CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("this function has not implemented on this platform")
}

View File

@ -0,0 +1,10 @@
package iface
import (
"github.com/netbirdio/netbird/client/iface/device"
)
// GetInterfaceGUIDString returns an interface GUID. This is useful on Windows only
func (w *WGIface) GetInterfaceGUIDString() (string, error) {
return w.tun.(*device.TunDevice).GetInterfaceGUIDString()
}

View File

@ -1,31 +0,0 @@
//go:build ios
package iface
import (
"fmt"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{
tun: device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, transportNet, args.TunFd, filterFn),
userspaceBind: true,
}
return wgIFace, nil
}
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("this function has not implemented on this platform")
}

View File

@ -9,6 +9,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
type MockWGIface struct {
@ -30,6 +31,7 @@ type MockWGIface struct {
GetDeviceFunc func() *device.FilteredDevice
GetStatsFunc func(peerKey string) (configurer.WGStats, error)
GetInterfaceGUIDStringFunc func() (string, error)
GetProxyFunc func() wgproxy.Proxy
}
func (m *MockWGIface) GetInterfaceGUIDString() (string, error) {
@ -103,3 +105,8 @@ func (m *MockWGIface) GetDevice() *device.FilteredDevice {
func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
return m.GetStatsFunc(peerKey)
}
func (m *MockWGIface) GetProxy() wgproxy.Proxy {
//TODO implement me
panic("implement me")
}

View File

@ -0,0 +1,24 @@
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace := &WGIface{
userspaceBind: true,
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
}
return wgIFace, nil
}

View File

@ -0,0 +1,34 @@
//go:build !ios
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
var tun WGTunDevice
if netstack.IsEnabled() {
tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
} else {
tun = device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
}
wgIFace := &WGIface{
userspaceBind: true,
tun: tun,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
}
return wgIFace, nil
}

View File

@ -0,0 +1,26 @@
//go:build ios
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace := &WGIface{
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd),
userspaceBind: true,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
}
return wgIFace, nil
}

View File

@ -0,0 +1,45 @@
//go:build (linux && !android) || freebsd
package iface
import (
"fmt"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{}
if netstack.IsEnabled() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
return wgIFace, nil
}
if device.WireGuardModuleIsLoaded() {
wgIFace.tun = device.NewKernelDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, opts.TransportNet)
wgIFace.wgProxyFactory = wgproxy.NewKernelFactory(opts.WGPort)
return wgIFace, nil
}
if device.ModuleTunIsLoaded() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
return wgIFace, nil
}
return nil, fmt.Errorf("couldn't check or load tun module")
}

View File

@ -0,0 +1,32 @@
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
var tun WGTunDevice
if netstack.IsEnabled() {
tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
} else {
tun = device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
}
wgIFace := &WGIface{
userspaceBind: true,
tun: tun,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
}
return wgIFace, nil
}

View File

@ -45,7 +45,16 @@ func TestWGIface_UpdateAddr(t *testing.T) {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, addr, wgPort, key, DefaultMTU, newNet, nil, nil)
opts := WGIFaceOpts{
IFaceName: ifaceName,
Address: addr,
WGPort: wgPort,
WGPrivKey: key,
MTU: DefaultMTU,
TransportNet: newNet,
}
iface, err := NewWGIFace(opts)
if err != nil {
t.Fatal(err)
}
@ -118,7 +127,16 @@ func Test_CreateInterface(t *testing.T) {
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil)
opts := WGIFaceOpts{
IFaceName: ifaceName,
Address: wgIP,
WGPort: 33100,
WGPrivKey: key,
MTU: DefaultMTU,
TransportNet: newNet,
}
iface, err := NewWGIFace(opts)
if err != nil {
t.Fatal(err)
}
@ -153,7 +171,16 @@ func Test_Close(t *testing.T) {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil)
opts := WGIFaceOpts{
IFaceName: ifaceName,
Address: wgIP,
WGPort: wgPort,
WGPrivKey: key,
MTU: DefaultMTU,
TransportNet: newNet,
}
iface, err := NewWGIFace(opts)
if err != nil {
t.Fatal(err)
}
@ -189,7 +216,16 @@ func TestRecreation(t *testing.T) {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil)
opts := WGIFaceOpts{
IFaceName: ifaceName,
Address: wgIP,
WGPort: wgPort,
WGPrivKey: key,
MTU: DefaultMTU,
TransportNet: newNet,
}
iface, err := NewWGIFace(opts)
if err != nil {
t.Fatal(err)
}
@ -252,7 +288,15 @@ func Test_ConfigureInterface(t *testing.T) {
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil)
opts := WGIFaceOpts{
IFaceName: ifaceName,
Address: wgIP,
WGPort: wgPort,
WGPrivKey: key,
MTU: DefaultMTU,
TransportNet: newNet,
}
iface, err := NewWGIFace(opts)
if err != nil {
t.Fatal(err)
}
@ -300,7 +344,16 @@ func Test_UpdatePeer(t *testing.T) {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil)
opts := WGIFaceOpts{
IFaceName: ifaceName,
Address: wgIP,
WGPort: 33100,
WGPrivKey: key,
MTU: DefaultMTU,
TransportNet: newNet,
}
iface, err := NewWGIFace(opts)
if err != nil {
t.Fatal(err)
}
@ -361,7 +414,16 @@ func Test_RemovePeer(t *testing.T) {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil)
opts := WGIFaceOpts{
IFaceName: ifaceName,
Address: wgIP,
WGPort: 33100,
WGPrivKey: key,
MTU: DefaultMTU,
TransportNet: newNet,
}
iface, err := NewWGIFace(opts)
if err != nil {
t.Fatal(err)
}
@ -418,7 +480,15 @@ func Test_ConnectPeers(t *testing.T) {
guid := fmt.Sprintf("{%s}", uuid.New().String())
device.CustomWindowsGUIDString = strings.ToLower(guid)
iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, peer1wgPort, peer1Key.String(), DefaultMTU, newNet, nil, nil)
optsPeer1 := WGIFaceOpts{
IFaceName: peer1ifaceName,
Address: peer1wgIP,
WGPort: peer1wgPort,
WGPrivKey: peer1Key.String(),
MTU: DefaultMTU,
TransportNet: newNet,
}
iface1, err := NewWGIFace(optsPeer1)
if err != nil {
t.Fatal(err)
}
@ -432,7 +502,12 @@ func Test_ConnectPeers(t *testing.T) {
t.Fatal(err)
}
peer1endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", peer1wgPort))
localIP, err := getLocalIP()
if err != nil {
t.Fatal(err)
}
peer1endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", localIP, peer1wgPort))
if err != nil {
t.Fatal(err)
}
@ -444,7 +519,17 @@ func Test_ConnectPeers(t *testing.T) {
if err != nil {
t.Fatal(err)
}
iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, peer2wgPort, peer2Key.String(), DefaultMTU, newNet, nil, nil)
optsPeer2 := WGIFaceOpts{
IFaceName: peer2ifaceName,
Address: peer2wgIP,
WGPort: peer2wgPort,
WGPrivKey: peer2Key.String(),
MTU: DefaultMTU,
TransportNet: newNet,
}
iface2, err := NewWGIFace(optsPeer2)
if err != nil {
t.Fatal(err)
}
@ -458,7 +543,7 @@ func Test_ConnectPeers(t *testing.T) {
t.Fatal(err)
}
peer2endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", peer2wgPort))
peer2endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", localIP, peer2wgPort))
if err != nil {
t.Fatal(err)
}
@ -527,3 +612,28 @@ func getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) {
}
return wgtypes.Peer{}, fmt.Errorf("peer not found")
}
func getLocalIP() (string, error) {
// Get all interfaces
addrs, err := net.InterfaceAddrs()
if err != nil {
return "", err
}
for _, addr := range addrs {
ipNet, ok := addr.(*net.IPNet)
if !ok {
continue
}
if ipNet.IP.IsLoopback() {
continue
}
if ipNet.IP.To4() == nil {
continue
}
return ipNet.IP.String(), nil
}
return "", fmt.Errorf("no local IP found")
}

View File

@ -1,49 +0,0 @@
//go:build (linux && !android) || freebsd
package iface
import (
"fmt"
"runtime"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{}
// move the kernel/usp/netstack preference evaluation to upper layer
if netstack.IsEnabled() {
wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
wgIFace.userspaceBind = true
return wgIFace, nil
}
if device.WireGuardModuleIsLoaded() {
wgIFace.tun = device.NewKernelDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet)
wgIFace.userspaceBind = false
return wgIFace, nil
}
if !device.ModuleTunIsLoaded() {
return nil, fmt.Errorf("couldn't check or load tun module")
}
wgIFace.tun = device.NewUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, nil)
wgIFace.userspaceBind = true
return wgIFace, nil
}
// CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("CreateOnAndroid function has not implemented on %s platform", runtime.GOOS)
}

View File

@ -1,41 +0,0 @@
package iface
import (
"fmt"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{
userspaceBind: true,
}
if netstack.IsEnabled() {
wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
return wgIFace, nil
}
wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn)
return wgIFace, nil
}
// CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("this function has not implemented on non mobile")
}
// GetInterfaceGUIDString returns an interface GUID. This is useful on Windows only
func (w *WGIface) GetInterfaceGUIDString() (string, error) {
return w.tun.(*device.TunDevice).GetInterfaceGUIDString()
}

View File

@ -11,6 +11,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
type IWGIface interface {
@ -22,6 +23,7 @@ type IWGIface interface {
ToInterface() *net.Interface
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(newAddr string) error
GetProxy() wgproxy.Proxy
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP string) error

View File

@ -9,6 +9,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
type IWGIface interface {
@ -20,6 +21,7 @@ type IWGIface interface {
ToInterface() *net.Interface
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(newAddr string) error
GetProxy() wgproxy.Proxy
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP string) error

View File

@ -0,0 +1,137 @@
package bind
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bind"
)
type ProxyBind struct {
Bind *bind.ICEBind
wgAddr *net.UDPAddr
wgEndpoint *bind.Endpoint
remoteConn net.Conn
ctx context.Context
cancel context.CancelFunc
closeMu sync.Mutex
closed bool
pausedMu sync.Mutex
paused bool
isStarted bool
}
// AddTurnConn adds a new connection to the bind.
// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the
// WireGuard configuration.
func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error {
addr, err := p.Bind.SetEndpoint(nbAddr, remoteConn)
if err != nil {
return err
}
p.wgAddr = addr
p.wgEndpoint = addrToEndpoint(addr)
p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx)
return err
}
func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
return p.wgAddr
}
func (p *ProxyBind) Work() {
if p.remoteConn == nil {
return
}
p.pausedMu.Lock()
p.paused = false
p.pausedMu.Unlock()
// Start the proxy only once
if !p.isStarted {
p.isStarted = true
go p.proxyToLocal(p.ctx)
}
}
func (p *ProxyBind) Pause() {
if p.remoteConn == nil {
return
}
p.pausedMu.Lock()
p.paused = true
p.pausedMu.Unlock()
}
func (p *ProxyBind) CloseConn() error {
if p.cancel == nil {
return fmt.Errorf("proxy not started")
}
return p.close()
}
func (p *ProxyBind) close() error {
p.closeMu.Lock()
defer p.closeMu.Unlock()
if p.closed {
return nil
}
p.closed = true
p.cancel()
p.Bind.RemoveEndpoint(p.wgAddr)
return p.remoteConn.Close()
}
func (p *ProxyBind) proxyToLocal(ctx context.Context) {
defer func() {
if err := p.close(); err != nil {
log.Warnf("failed to close remote conn: %s", err)
}
}()
buf := make([]byte, 1500)
for {
n, err := p.remoteConn.Read(buf)
if err != nil {
if ctx.Err() != nil {
return
}
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
return
}
p.pausedMu.Lock()
if p.paused {
p.pausedMu.Unlock()
continue
}
msg := bind.RecvMessage{
Endpoint: p.wgEndpoint,
Buffer: buf[:n],
}
p.Bind.RecvChan <- msg
p.pausedMu.Unlock()
}
}
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
ip, _ := netip.AddrFromSlice(addr.IP.To4())
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port))
return &bind.Endpoint{AddrPort: addrPort}
}

View File

@ -5,9 +5,9 @@ import (
"net"
)
const (
var (
portRangeStart = 3128
portRangeEnd = 3228
portRangeEnd = portRangeStart + 100
)
type portLookup struct {

View File

@ -17,6 +17,9 @@ func Test_portLookup_searchFreePort(t *testing.T) {
func Test_portLookup_on_allocated(t *testing.T) {
pl := portLookup{}
portRangeStart = 4128
portRangeEnd = portRangeStart + 100
allocatedPort, err := allocatePort(portRangeStart)
if err != nil {
t.Fatal(err)

View File

@ -119,7 +119,7 @@ func (p *WGEBPFProxy) Free() error {
p.ctxCancel()
var result *multierror.Error
if p.conn != nil { // p.conn will be nil if we have failed to listen
if p.conn != nil {
if err := p.conn.Close(); err != nil {
result = multierror.Append(result, err)
}

View File

@ -28,7 +28,7 @@ type ProxyWrapper struct {
isStarted bool
}
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) error {
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn)
if err != nil {
return fmt.Errorf("add turn conn: %w", err)

View File

@ -0,0 +1,49 @@
//go:build linux && !android
package wgproxy
import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
)
type KernelFactory struct {
wgPort int
ebpfProxy *ebpf.WGEBPFProxy
}
func NewKernelFactory(wgPort int) *KernelFactory {
f := &KernelFactory{
wgPort: wgPort,
}
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort)
if err := ebpfProxy.Listen(); err != nil {
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
return f
}
log.Infof("WireGuard Proxy Factory will produce eBPF proxy")
f.ebpfProxy = ebpfProxy
return f
}
func (w *KernelFactory) GetProxy() Proxy {
if w.ebpfProxy == nil {
return udpProxy.NewWGUDPProxy(w.wgPort)
}
return &ebpf.ProxyWrapper{
WgeBPFProxy: w.ebpfProxy,
}
}
func (w *KernelFactory) Free() error {
if w.ebpfProxy == nil {
return nil
}
return w.ebpfProxy.Free()
}

View File

@ -0,0 +1,29 @@
package wgproxy
import (
log "github.com/sirupsen/logrus"
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
)
// KernelFactory todo: check eBPF support on FreeBSD
type KernelFactory struct {
wgPort int
}
func NewKernelFactory(wgPort int) *KernelFactory {
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
f := &KernelFactory{
wgPort: wgPort,
}
return f
}
func (w *KernelFactory) GetProxy() Proxy {
return udpProxy.NewWGUDPProxy(w.wgPort)
}
func (w *KernelFactory) Free() error {
return nil
}

View File

@ -0,0 +1,30 @@
package wgproxy
import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bind"
proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
)
type USPFactory struct {
bind *bind.ICEBind
}
func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory {
log.Infof("WireGuard Proxy Factory will produce bind proxy")
f := &USPFactory{
bind: iceBind,
}
return f
}
func (w *USPFactory) GetProxy() Proxy {
return &proxyBind.ProxyBind{
Bind: w.bind,
}
}
func (w *USPFactory) Free() error {
return nil
}

View File

@ -0,0 +1,15 @@
package wgproxy
import (
"context"
"net"
)
// Proxy is a transfer layer between the relayed connection and the WireGuard
type Proxy interface {
AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error
EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint
Work() // Work start or resume the proxy
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
CloseConn() error
}

View File

@ -0,0 +1,56 @@
//go:build linux && !android
package wgproxy
import (
"context"
"os"
"testing"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
)
func TestProxyCloseByRemoteConnEBPF(t *testing.T) {
if os.Getenv("GITHUB_ACTIONS") != "true" {
t.Skip("Skipping test as it requires root privileges")
}
ctx := context.Background()
ebpfProxy := ebpf.NewWGEBPFProxy(51831)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %s", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %s", err)
}
}()
tests := []struct {
name string
proxy Proxy
}{
{
name: "ebpf proxy",
proxy: &ebpf.ProxyWrapper{
WgeBPFProxy: ebpfProxy,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
relayedConn := newMockConn()
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn)
if err != nil {
t.Errorf("error: %v", err)
}
_ = relayedConn.Close()
if err := tt.proxy.CloseConn(); err != nil {
t.Errorf("error: %v", err)
}
})
}
}

View File

@ -11,8 +11,8 @@ import (
"testing"
"time"
"github.com/netbirdio/netbird/client/internal/wgproxy/ebpf"
"github.com/netbirdio/netbird/client/internal/wgproxy/usp"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
"github.com/netbirdio/netbird/util"
)
@ -84,7 +84,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
}{
{
name: "userspace proxy",
proxy: usp.NewWGUserSpaceProxy(51830),
proxy: udpProxy.NewWGUDPProxy(51830),
},
}
@ -114,7 +114,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
relayedConn := newMockConn()
err := tt.proxy.AddTurnConn(ctx, relayedConn)
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn)
if err != nil {
t.Errorf("error: %v", err)
}

View File

@ -1,19 +1,21 @@
package usp
package udp
import (
"context"
"errors"
"fmt"
"io"
"net"
"sync"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/errors"
cerrors "github.com/netbirdio/netbird/client/errors"
)
// WGUserSpaceProxy proxies
type WGUserSpaceProxy struct {
// WGUDPProxy proxies
type WGUDPProxy struct {
localWGListenPort int
remoteConn net.Conn
@ -28,10 +30,10 @@ type WGUserSpaceProxy struct {
isStarted bool
}
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation
func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
// NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation
func NewWGUDPProxy(wgPort int) *WGUDPProxy {
log.Debugf("Initializing new user space proxy with port %d", wgPort)
p := &WGUserSpaceProxy{
p := &WGUDPProxy{
localWGListenPort: wgPort,
}
return p
@ -42,7 +44,7 @@ func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
// the connection is complete, an error is returned. Once successfully
// connected, any expiration of the context will not affect the
// connection.
func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) error {
func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
dialer := net.Dialer{}
localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
if err != nil {
@ -57,7 +59,7 @@ func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn)
return err
}
func (p *WGUserSpaceProxy) EndpointAddr() *net.UDPAddr {
func (p *WGUDPProxy) EndpointAddr() *net.UDPAddr {
if p.localConn == nil {
return nil
}
@ -66,7 +68,7 @@ func (p *WGUserSpaceProxy) EndpointAddr() *net.UDPAddr {
}
// Work starts the proxy or resumes it if it was paused
func (p *WGUserSpaceProxy) Work() {
func (p *WGUDPProxy) Work() {
if p.remoteConn == nil {
return
}
@ -83,7 +85,7 @@ func (p *WGUserSpaceProxy) Work() {
}
// Pause pauses the proxy from receiving data from the remote peer
func (p *WGUserSpaceProxy) Pause() {
func (p *WGUDPProxy) Pause() {
if p.remoteConn == nil {
return
}
@ -94,14 +96,14 @@ func (p *WGUserSpaceProxy) Pause() {
}
// CloseConn close the localConn
func (p *WGUserSpaceProxy) CloseConn() error {
func (p *WGUDPProxy) CloseConn() error {
if p.cancel == nil {
return fmt.Errorf("proxy not started")
}
return p.close()
}
func (p *WGUserSpaceProxy) close() error {
func (p *WGUDPProxy) close() error {
p.closeMu.Lock()
defer p.closeMu.Unlock()
@ -121,11 +123,11 @@ func (p *WGUserSpaceProxy) close() error {
if err := p.localConn.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
}
return errors.FormatErrorOrNil(result)
return cerrors.FormatErrorOrNil(result)
}
// proxyToRemote proxies from Wireguard to the RemoteKey
func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) {
func (p *WGUDPProxy) proxyToRemote(ctx context.Context) {
defer func() {
if err := p.close(); err != nil {
log.Warnf("error in proxy to remote loop: %s", err)
@ -157,21 +159,19 @@ func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) {
// proxyToLocal proxies from the Remote peer to local WireGuard
// if the proxy is paused it will drain the remote conn and drop the packets
func (p *WGUserSpaceProxy) proxyToLocal(ctx context.Context) {
func (p *WGUDPProxy) proxyToLocal(ctx context.Context) {
defer func() {
if err := p.close(); err != nil {
log.Warnf("error in proxy to local loop: %s", err)
if !errors.Is(err, io.EOF) {
log.Warnf("error in proxy to local loop: %s", err)
}
}
}()
buf := make([]byte, 1500)
for {
n, err := p.remoteConn.Read(buf)
n, err := p.remoteConnRead(ctx, buf)
if err != nil {
if ctx.Err() != nil {
return
}
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
return
}
@ -193,3 +193,15 @@ func (p *WGUserSpaceProxy) proxyToLocal(ctx context.Context) {
}
}
}
func (p *WGUDPProxy) remoteConnRead(ctx context.Context, buf []byte) (n int, err error) {
n, err = p.remoteConn.Read(buf)
if err != nil {
if ctx.Err() != nil {
return
}
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.LocalAddr(), err)
return
}
return
}

View File

@ -3,6 +3,7 @@ package acl
import (
"crypto/md5"
"encoding/hex"
"errors"
"fmt"
"net"
"net/netip"
@ -10,14 +11,18 @@ import (
"sync"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/ssh"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
var ErrSourceRangesEmpty = errors.New("sources range is empty")
// Manager is a ACL rules manager
type Manager interface {
ApplyFiltering(networkMap *mgmProto.NetworkMap)
@ -167,31 +172,40 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
}
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error {
var newRouteRules = make(map[id.RuleID]struct{})
newRouteRules := make(map[id.RuleID]struct{}, len(rules))
var merr *multierror.Error
// Apply new rules - firewall manager will return existing rule ID if already present
for _, rule := range rules {
id, err := d.applyRouteACL(rule)
if err != nil {
return fmt.Errorf("apply route ACL: %w", err)
if errors.Is(err, ErrSourceRangesEmpty) {
log.Debugf("skipping empty rule with destination %s: %v", rule.Destination, err)
} else {
merr = multierror.Append(merr, fmt.Errorf("add route rule: %w", err))
}
continue
}
newRouteRules[id] = struct{}{}
}
// Clean up old firewall rules
for id := range d.routeRules {
if _, ok := newRouteRules[id]; !ok {
if _, exists := newRouteRules[id]; !exists {
if err := d.firewall.DeleteRouteRule(id); err != nil {
log.Errorf("failed to delete route firewall rule: %v", err)
continue
merr = multierror.Append(merr, fmt.Errorf("delete route rule: %w", err))
}
delete(d.routeRules, id)
// implicitly deleted from the map
}
}
d.routeRules = newRouteRules
return nil
return nberrors.FormatErrorOrNil(merr)
}
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) {
if len(rule.SourceRanges) == 0 {
return "", fmt.Errorf("source ranges is empty")
return "", ErrSourceRangesEmpty
}
var sources []netip.Prefix

View File

@ -1,7 +1,6 @@
package acl
import (
"context"
"net"
"testing"
@ -52,13 +51,13 @@ func TestDefaultManager(t *testing.T) {
}).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(context.Background(), ifaceMock)
fw, err := firewall.NewFirewall(ifaceMock, nil)
if err != nil {
t.Errorf("create firewall: %v", err)
return
}
defer func(fw manager.Manager) {
_ = fw.Reset()
_ = fw.Reset(nil)
}(fw)
acl := NewDefaultManager(fw)
@ -345,13 +344,13 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
}).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(context.Background(), ifaceMock)
fw, err := firewall.NewFirewall(ifaceMock, nil)
if err != nil {
t.Errorf("create firewall: %v", err)
return
}
defer func(fw manager.Manager) {
_ = fw.Reset()
_ = fw.Reset(nil)
}(fw)
acl := NewDefaultManager(fw)

View File

@ -62,10 +62,7 @@ func (c *ConnectClient) Run() error {
}
// RunWithProbes runs the client's main logic with probes attached
func (c *ConnectClient) RunWithProbes(
probes *ProbeHolder,
runningChan chan error,
) error {
func (c *ConnectClient) RunWithProbes(probes *ProbeHolder, runningChan chan error) error {
return c.run(MobileDependency{}, probes, runningChan)
}
@ -104,11 +101,7 @@ func (c *ConnectClient) RunOniOS(
return c.run(mobileDependency, nil, nil)
}
func (c *ConnectClient) run(
mobileDependency MobileDependency,
probes *ProbeHolder,
runningChan chan error,
) error {
func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHolder, runningChan chan error) error {
defer func() {
if r := recover(); r != nil {
log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
@ -117,12 +110,6 @@ func (c *ConnectClient) run(
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
// Check if client was not shut down in a clean way and restore DNS config if required.
// Otherwise, we might not be able to connect to the management server to retrieve new config.
if err := dns.CheckUncleanShutdown(c.config.WgIface); err != nil {
log.Errorf("checking unclean shutdown error: %s", err)
}
backOff := &backoff.ExponentialBackOff{
InitialInterval: time.Second,
RandomizationFactor: 1,
@ -358,7 +345,11 @@ func (c *ConnectClient) Stop() error {
if c.engine == nil {
return nil
}
return c.engine.Stop()
if err := c.engine.Stop(); err != nil {
return fmt.Errorf("stop engine: %w", err)
}
return nil
}
func (c *ConnectClient) isContextCancelled() bool {

View File

@ -1,6 +1,5 @@
package dns
const (
fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf"
fileUncleanShutdownManagerTypeLocation = "/var/db/netbird/manager"
fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf"
)

View File

@ -3,6 +3,5 @@
package dns
const (
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager"
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
)

View File

@ -9,6 +9,8 @@ import (
"github.com/fsnotify/fsnotify"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
var (
@ -20,7 +22,7 @@ var (
}
)
type repairConfFn func([]string, string, *resolvConf) error
type repairConfFn func([]string, string, *resolvConf, *statemanager.Manager) error
type repair struct {
operationFile string
@ -40,7 +42,7 @@ func newRepair(operationFile string, updateFn repairConfFn) *repair {
}
}
func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP string) {
func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP string, stateManager *statemanager.Manager) {
if f.inotify != nil {
return
}
@ -81,7 +83,7 @@ func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP strin
log.Errorf("failed to rm inotify watch for resolv.conf: %s", err)
}
err = f.updateFn(nbSearchDomains, nbNameserverIP, rConf)
err = f.updateFn(nbSearchDomains, nbNameserverIP, rConf, stateManager)
if err != nil {
log.Errorf("failed to repair resolv.conf: %v", err)
}

View File

@ -9,6 +9,7 @@ import (
"testing"
"time"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/util"
)
@ -104,14 +105,14 @@ nameserver 8.8.8.8`,
var changed bool
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
updateFn := func([]string, string, *resolvConf) error {
updateFn := func([]string, string, *resolvConf, *statemanager.Manager) error {
changed = true
cancel()
return nil
}
r := newRepair(operationFile, updateFn)
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1")
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1", nil)
err = os.WriteFile(operationFile, []byte(tt.touchedConfContent), 0755)
if err != nil {
@ -151,14 +152,14 @@ searchdomain netbird.cloud something`
var changed bool
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
updateFn := func([]string, string, *resolvConf) error {
updateFn := func([]string, string, *resolvConf, *statemanager.Manager) error {
changed = true
cancel()
return nil
}
r := newRepair(tmpLink, updateFn)
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1")
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1", nil)
err = os.WriteFile(tmpLink, []byte(modifyContent), 0755)
if err != nil {

View File

@ -11,6 +11,8 @@ import (
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
const (
@ -36,7 +38,7 @@ type fileConfigurator struct {
nbNameserverIP string
}
func newFileConfigurator() (hostManager, error) {
func newFileConfigurator() (*fileConfigurator, error) {
fc := &fileConfigurator{}
fc.repair = newRepair(defaultResolvConfPath, fc.updateConfig)
return fc, nil
@ -46,7 +48,7 @@ func (f *fileConfigurator) supportCustomPort() bool {
return false
}
func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
backupFileExist := f.isBackupFileExist()
if !config.RouteAll {
if backupFileExist {
@ -76,15 +78,15 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
f.repair.stopWatchFileChanges()
err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf)
err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf, stateManager)
if err != nil {
return err
}
f.repair.watchFileChanges(nbSearchDomains, f.nbNameserverIP)
f.repair.watchFileChanges(nbSearchDomains, f.nbNameserverIP, stateManager)
return nil
}
func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf) error {
func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf, stateManager *statemanager.Manager) error {
searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains)
nameServers := generateNsList(nbNameserverIP, cfg)
@ -107,7 +109,7 @@ func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP
log.Infof("created a NetBird managed %s file with the DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList)
// create another backup for unclean shutdown detection right after overwriting the original resolv.conf
if err := createUncleanShutdownIndicator(fileDefaultResolvConfBackupLocation, fileManager, nbNameserverIP); err != nil {
if err := createUncleanShutdownIndicator(fileDefaultResolvConfBackupLocation, nbNameserverIP, stateManager); err != nil {
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
}
@ -145,10 +147,6 @@ func (f *fileConfigurator) restore() error {
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err)
}
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
}
return os.RemoveAll(fileDefaultResolvConfBackupLocation)
}
@ -176,7 +174,7 @@ func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Add
return restoreResolvConfFile()
}
log.Info("restoring unclean shutdown: first current nameserver differs from saved nameserver pre-netbird: not restoring")
log.Infof("restoring unclean shutdown: first current nameserver differs from saved nameserver pre-netbird: %s (current) vs %s (stored): not restoring", currentDNSAddress, storedDNSAddress)
return nil
}
@ -192,10 +190,6 @@ func restoreResolvConfFile() error {
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation, err)
}
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown resolv.conf file: %s", err)
}
return nil
}

View File

@ -5,14 +5,14 @@ import (
"net/netip"
"strings"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbdns "github.com/netbirdio/netbird/dns"
)
type hostManager interface {
applyDNSConfig(config HostDNSConfig) error
applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error
restoreHostDNS() error
supportCustomPort() bool
restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error
}
type SystemDNSSettings struct {
@ -35,15 +35,15 @@ type DomainConfig struct {
}
type mockHostConfigurator struct {
applyDNSConfigFunc func(config HostDNSConfig) error
applyDNSConfigFunc func(config HostDNSConfig, stateManager *statemanager.Manager) error
restoreHostDNSFunc func() error
supportCustomPortFunc func() bool
restoreUncleanShutdownDNSFunc func(*netip.Addr) error
}
func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig) error {
func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
if m.applyDNSConfigFunc != nil {
return m.applyDNSConfigFunc(config)
return m.applyDNSConfigFunc(config, stateManager)
}
return fmt.Errorf("method applyDNSSettings is not implemented")
}
@ -62,16 +62,9 @@ func (m *mockHostConfigurator) supportCustomPort() bool {
return false
}
func (m *mockHostConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error {
if m.restoreUncleanShutdownDNSFunc != nil {
return m.restoreUncleanShutdownDNSFunc(storedDNSAddress)
}
return fmt.Errorf("method restoreUncleanShutdownDNS is not implemented")
}
func newNoopHostMocker() hostManager {
return &mockHostConfigurator{
applyDNSConfigFunc: func(config HostDNSConfig) error { return nil },
applyDNSConfigFunc: func(config HostDNSConfig, stateManager *statemanager.Manager) error { return nil },
restoreHostDNSFunc: func() error { return nil },
supportCustomPortFunc: func() bool { return true },
restoreUncleanShutdownDNSFunc: func(*netip.Addr) error { return nil },

View File

@ -1,15 +1,17 @@
package dns
import "net/netip"
import (
"github.com/netbirdio/netbird/client/internal/statemanager"
)
type androidHostManager struct {
}
func newHostManager() (hostManager, error) {
func newHostManager() (*androidHostManager, error) {
return &androidHostManager{}, nil
}
func (a androidHostManager) applyDNSConfig(config HostDNSConfig) error {
func (a androidHostManager) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error {
return nil
}
@ -20,7 +22,3 @@ func (a androidHostManager) restoreHostDNS() error {
func (a androidHostManager) supportCustomPort() bool {
return false
}
func (a androidHostManager) restoreUncleanShutdownDNS(*netip.Addr) error {
return nil
}

View File

@ -8,12 +8,13 @@ import (
"fmt"
"io"
"net"
"net/netip"
"os/exec"
"strconv"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
const (
@ -37,7 +38,7 @@ type systemConfigurator struct {
systemDNSSettings SystemDNSSettings
}
func newHostManager() (hostManager, error) {
func newHostManager() (*systemConfigurator, error) {
return &systemConfigurator{
createdKeys: make(map[string]struct{}),
}, nil
@ -47,12 +48,11 @@ func (s *systemConfigurator) supportCustomPort() bool {
return true
}
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
var err error
// create a file for unclean shutdown detection
if err := createUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to create unclean shutdown file: %s", err)
if err := stateManager.UpdateState(&ShutdownState{}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
var (
@ -123,10 +123,6 @@ func (s *systemConfigurator) restoreHostDNS() error {
}
}
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown file: %s", err)
}
return nil
}
@ -320,7 +316,7 @@ func (s *systemConfigurator) getPrimaryService() (string, string, error) {
return primaryService, router, nil
}
func (s *systemConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
func (s *systemConfigurator) restoreUncleanShutdownDNS() error {
if err := s.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns via scutil: %w", err)
}

View File

@ -3,9 +3,10 @@ package dns
import (
"encoding/json"
"fmt"
"net/netip"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
type iosHostManager struct {
@ -13,13 +14,13 @@ type iosHostManager struct {
config HostDNSConfig
}
func newHostManager(dnsManager IosDnsManager) (hostManager, error) {
func newHostManager(dnsManager IosDnsManager) (*iosHostManager, error) {
return &iosHostManager{
dnsManager: dnsManager,
}, nil
}
func (a iosHostManager) applyDNSConfig(config HostDNSConfig) error {
func (a iosHostManager) applyDNSConfig(config HostDNSConfig, _ *statemanager.Manager) error {
jsonData, err := json.Marshal(config)
if err != nil {
return fmt.Errorf("marshal: %w", err)
@ -37,7 +38,3 @@ func (a iosHostManager) restoreHostDNS() error {
func (a iosHostManager) supportCustomPort() bool {
return false
}
func (a iosHostManager) restoreUncleanShutdownDNS(*netip.Addr) error {
return nil
}

View File

@ -4,9 +4,9 @@ package dns
import (
"bufio"
"errors"
"fmt"
"io"
"net/netip"
"os"
"strings"
@ -21,27 +21,8 @@ const (
resolvConfManager
)
var ErrUnknownOsManagerType = errors.New("unknown os manager type")
type osManagerType int
func newOsManagerType(osManager string) (osManagerType, error) {
switch osManager {
case "netbird":
return fileManager, nil
case "file":
return netbirdManager, nil
case "networkManager":
return networkManager, nil
case "systemd":
return systemdManager, nil
case "resolvconf":
return resolvConfManager, nil
default:
return 0, ErrUnknownOsManagerType
}
}
func (t osManagerType) String() string {
switch t {
case netbirdManager:
@ -59,6 +40,11 @@ func (t osManagerType) String() string {
}
}
type restoreHostManager interface {
hostManager
restoreUncleanShutdownDNS(*netip.Addr) error
}
func newHostManager(wgInterface string) (hostManager, error) {
osManager, err := getOSDNSManagerType()
if err != nil {
@ -69,7 +55,7 @@ func newHostManager(wgInterface string) (hostManager, error) {
return newHostManagerFromType(wgInterface, osManager)
}
func newHostManagerFromType(wgInterface string, osManager osManagerType) (hostManager, error) {
func newHostManagerFromType(wgInterface string, osManager osManagerType) (restoreHostManager, error) {
switch osManager {
case networkManager:
return newNetworkManagerDbusConfigurator(wgInterface)

View File

@ -3,11 +3,12 @@ package dns
import (
"fmt"
"io"
"net/netip"
"strings"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows/registry"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
const (
@ -31,7 +32,7 @@ type registryConfigurator struct {
routingAll bool
}
func newHostManager(wgInterface WGIface) (hostManager, error) {
func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
guid, err := wgInterface.GetInterfaceGUIDString()
if err != nil {
return nil, err
@ -39,7 +40,7 @@ func newHostManager(wgInterface WGIface) (hostManager, error) {
return newHostManagerWithGuid(guid)
}
func newHostManagerWithGuid(guid string) (hostManager, error) {
func newHostManagerWithGuid(guid string) (*registryConfigurator, error) {
return &registryConfigurator{
guid: guid,
}, nil
@ -49,7 +50,7 @@ func (r *registryConfigurator) supportCustomPort() bool {
return false
}
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error {
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
var err error
if config.RouteAll {
err = r.addDNSSetupForAll(config.ServerIP)
@ -65,9 +66,8 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error {
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
}
// create a file for unclean shutdown detection
if err := createUncleanShutdownIndicator(r.guid); err != nil {
log.Errorf("failed to create unclean shutdown file: %s", err)
if err := stateManager.UpdateState(&ShutdownState{Guid: r.guid}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
var (
@ -160,10 +160,6 @@ func (r *registryConfigurator) restoreHostDNS() error {
return fmt.Errorf("remove interface registry key: %w", err)
}
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown file: %s", err)
}
return nil
}
@ -221,7 +217,7 @@ func (r *registryConfigurator) getInterfaceRegistryKey() (registry.Key, error) {
return regKey, nil
}
func (r *registryConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
func (r *registryConfigurator) restoreUncleanShutdownDNS() error {
if err := r.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns via registry: %w", err)
}

View File

@ -16,6 +16,7 @@ import (
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbversion "github.com/netbirdio/netbird/version"
)
@ -53,6 +54,7 @@ var supportedNetworkManagerVersionConstraints = []string{
type networkManagerDbusConfigurator struct {
dbusLinkObject dbus.ObjectPath
routingAll bool
ifaceName string
}
// the types below are based on dbus specification, each field is mapped to a dbus type
@ -77,7 +79,7 @@ func (s networkManagerConnSettings) cleanDeprecatedSettings() {
}
}
func newNetworkManagerDbusConfigurator(wgInterface string) (hostManager, error) {
func newNetworkManagerDbusConfigurator(wgInterface string) (*networkManagerDbusConfigurator, error) {
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
if err != nil {
return nil, fmt.Errorf("get nm dbus: %w", err)
@ -93,6 +95,7 @@ func newNetworkManagerDbusConfigurator(wgInterface string) (hostManager, error)
return &networkManagerDbusConfigurator{
dbusLinkObject: dbus.ObjectPath(s),
ifaceName: wgInterface,
}, nil
}
@ -100,7 +103,7 @@ func (n *networkManagerDbusConfigurator) supportCustomPort() bool {
return false
}
func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
connSettings, configVersion, err := n.getAppliedConnectionSettings()
if err != nil {
return fmt.Errorf("retrieving the applied connection settings, error: %w", err)
@ -151,10 +154,12 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) er
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority)
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList)
// create a backup for unclean shutdown detection before adding domains, as these might end up in the resolv.conf file.
// The file content itself is not important for network-manager restoration
if err := createUncleanShutdownIndicator(defaultResolvConfPath, networkManager, dnsIP.String()); err != nil {
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
state := &ShutdownState{
ManagerType: networkManager,
WgIface: n.ifaceName,
}
if err := stateManager.UpdateState(state); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains)
@ -171,10 +176,6 @@ func (n *networkManagerDbusConfigurator) restoreHostDNS() error {
return fmt.Errorf("delete connection settings: %w", err)
}
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
}
return nil
}

View File

@ -9,6 +9,8 @@ import (
"os/exec"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
const resolvconfCommand = "resolvconf"
@ -22,7 +24,7 @@ type resolvconf struct {
}
// supported "openresolv" only
func newResolvConfConfigurator(wgInterface string) (hostManager, error) {
func newResolvConfConfigurator(wgInterface string) (*resolvconf, error) {
resolvConfEntries, err := parseDefaultResolvConf()
if err != nil {
log.Errorf("could not read original search domains from %s: %s", defaultResolvConfPath, err)
@ -40,7 +42,7 @@ func (r *resolvconf) supportCustomPort() bool {
return false
}
func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error {
func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
var err error
if !config.RouteAll {
err = r.restoreHostDNS()
@ -60,9 +62,12 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error {
append([]string{config.ServerIP}, r.originalNameServers...),
options)
// create a backup for unclean shutdown detection before the resolv.conf is changed
if err := createUncleanShutdownIndicator(defaultResolvConfPath, resolvConfManager, config.ServerIP); err != nil {
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
state := &ShutdownState{
ManagerType: resolvConfManager,
WgIface: r.ifaceName,
}
if err := stateManager.UpdateState(state); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
err = r.applyConfig(buf)
@ -79,11 +84,7 @@ func (r *resolvconf) restoreHostDNS() error {
cmd := exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName)
_, err := cmd.Output()
if err != nil {
return fmt.Errorf("removing resolvconf configuration for %s interface, error: %w", r.ifaceName, err)
}
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
return fmt.Errorf("removing resolvconf configuration for %s interface: %w", r.ifaceName, err)
}
return nil
@ -95,7 +96,7 @@ func (r *resolvconf) applyConfig(content bytes.Buffer) error {
cmd.Stdin = &content
_, err := cmd.Output()
if err != nil {
return fmt.Errorf("applying resolvconf configuration for %s interface, error: %w", r.ifaceName, err)
return fmt.Errorf("applying resolvconf configuration for %s interface: %w", r.ifaceName, err)
}
return nil
}

View File

@ -7,6 +7,7 @@ import (
"runtime"
"strings"
"sync"
"time"
"github.com/miekg/dns"
"github.com/mitchellh/hashstructure/v2"
@ -14,6 +15,7 @@ import (
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbdns "github.com/netbirdio/netbird/dns"
)
@ -63,6 +65,7 @@ type DefaultServer struct {
iosDnsManager IosDnsManager
statusRecorder *peer.Status
stateManager *statemanager.Manager
}
type handlerWithStop interface {
@ -77,12 +80,7 @@ type muxUpdate struct {
}
// NewDefaultServer returns a new dns server
func NewDefaultServer(
ctx context.Context,
wgInterface WGIface,
customAddress string,
statusRecorder *peer.Status,
) (*DefaultServer, error) {
func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string, statusRecorder *peer.Status, stateManager *statemanager.Manager) (*DefaultServer, error) {
var addrPort *netip.AddrPort
if customAddress != "" {
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
@ -99,7 +97,7 @@ func NewDefaultServer(
dnsService = newServiceViaListener(wgInterface, addrPort)
}
return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder), nil
return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager), nil
}
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
@ -112,7 +110,7 @@ func NewDefaultServerPermanentUpstream(
statusRecorder *peer.Status,
) *DefaultServer {
log.Debugf("host dns address list is: %v", hostsDnsList)
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder)
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil)
ds.hostsDNSHolder.set(hostsDnsList)
ds.permanent = true
ds.addHostRootZone()
@ -130,12 +128,12 @@ func NewDefaultServerIos(
iosDnsManager IosDnsManager,
statusRecorder *peer.Status,
) *DefaultServer {
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder)
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil)
ds.iosDnsManager = iosDnsManager
return ds
}
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status) *DefaultServer {
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status, stateManager *statemanager.Manager) *DefaultServer {
ctx, stop := context.WithCancel(ctx)
defaultServer := &DefaultServer{
ctx: ctx,
@ -147,6 +145,7 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
},
wgInterface: wgInterface,
statusRecorder: statusRecorder,
stateManager: stateManager,
hostsDNSHolder: newHostsDNSHolder(),
}
@ -169,6 +168,7 @@ func (s *DefaultServer) Initialize() (err error) {
}
}
s.stateManager.RegisterState(&ShutdownState{})
s.hostManager, err = s.initialize()
if err != nil {
return fmt.Errorf("initialize: %w", err)
@ -191,9 +191,10 @@ func (s *DefaultServer) Stop() {
s.ctxCancel()
if s.hostManager != nil {
err := s.hostManager.restoreHostDNS()
if err != nil {
log.Error(err)
if err := s.hostManager.restoreHostDNS(); err != nil {
log.Error("failed to restore host DNS settings: ", err)
} else if err := s.stateManager.DeleteState(&ShutdownState{}); err != nil {
log.Errorf("failed to delete shutdown dns state: %v", err)
}
}
@ -318,10 +319,17 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
hostUpdate.RouteAll = false
}
if err = s.hostManager.applyDNSConfig(hostUpdate); err != nil {
if err = s.hostManager.applyDNSConfig(hostUpdate, s.stateManager); err != nil {
log.Error(err)
}
// persist dns state right away
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
defer cancel()
if err := s.stateManager.PersistState(ctx); err != nil {
log.Errorf("Failed to persist dns state: %v", err)
}
if s.searchDomainNotifier != nil {
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
}
@ -521,10 +529,17 @@ func (s *DefaultServer) upstreamCallbacks(
}
}
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
}
// persist dns state right away
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
defer cancel()
if err := s.stateManager.PersistState(ctx); err != nil {
l.Errorf("Failed to persist dns state: %v", err)
}
if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 {
s.addHostRootZone()
}
@ -551,7 +566,7 @@ func (s *DefaultServer) upstreamCallbacks(
s.currentConfig.RouteAll = true
s.service.RegisterMux(nbdns.RootZone, handler)
}
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
}

View File

@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/formatter"
@ -267,7 +268,17 @@ func TestUpdateDNSServer(t *testing.T) {
if err != nil {
t.Fatal(err)
}
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
opts := iface.WGIFaceOpts{
IFaceName: fmt.Sprintf("utun230%d", n),
Address: fmt.Sprintf("100.66.100.%d/32", n+1),
WGPort: 33100,
WGPrivKey: privKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgIface, err := iface.NewWGIFace(opts)
if err != nil {
t.Fatal(err)
}
@ -281,7 +292,7 @@ func TestUpdateDNSServer(t *testing.T) {
t.Log(err)
}
}()
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{})
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil)
if err != nil {
t.Fatal(err)
}
@ -345,7 +356,15 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
}
privKey, _ := wgtypes.GeneratePrivateKey()
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
opts := iface.WGIFaceOpts{
IFaceName: "utun2301",
Address: "100.66.100.1/32",
WGPort: 33100,
WGPrivKey: privKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgIface, err := iface.NewWGIFace(opts)
if err != nil {
t.Errorf("build interface wireguard: %v", err)
return
@ -382,7 +401,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
return
}
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{})
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil)
if err != nil {
t.Errorf("create DNS server: %v", err)
return
@ -477,7 +496,7 @@ func TestDNSServerStartStop(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{})
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}, nil)
if err != nil {
t.Fatalf("%v", err)
}
@ -536,6 +555,7 @@ func TestDNSServerStartStop(t *testing.T) {
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
hostManager := &mockHostConfigurator{}
server := DefaultServer{
ctx: context.Background(),
service: NewServiceViaMemory(&mocWGIface{}),
localResolver: &localResolver{
registeredMap: make(registrationMap),
@ -552,7 +572,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
}
var domainsUpdate string
hostManager.applyDNSConfigFunc = func(config HostDNSConfig) error {
hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error {
domains := []string{}
for _, item := range config.Domains {
if item.Disabled {
@ -803,7 +823,17 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
}
privKey, _ := wgtypes.GeneratePrivateKey()
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
opts := iface.WGIFaceOpts{
IFaceName: "utun2301",
Address: "100.66.100.2/24",
WGPort: 33100,
WGPrivKey: privKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgIface, err := iface.NewWGIFace(opts)
if err != nil {
t.Fatalf("build interface wireguard: %v", err)
return nil, err

View File

@ -1,5 +1,5 @@
package dns
func (s *DefaultServer) initialize() (manager hostManager, err error) {
func (s *DefaultServer) initialize() (hostManager, error) {
return newHostManager(s.wgInterface)
}

View File

@ -7,7 +7,7 @@ import (
var errNotImplemented = errors.New("not implemented")
func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) {
func newSystemdDbusConfigurator(string) (restoreHostManager, error) {
return nil, fmt.Errorf("systemd dns management: %w on freebsd", errNotImplemented)
}

View File

@ -15,6 +15,7 @@ import (
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbdns "github.com/netbirdio/netbird/dns"
)
@ -38,6 +39,7 @@ const (
type systemdDbusConfigurator struct {
dbusLinkObject dbus.ObjectPath
routingAll bool
ifaceName string
}
// the types below are based on dbus specification, each field is mapped to a dbus type
@ -55,7 +57,7 @@ type systemdDbusLinkDomainsInput struct {
MatchOnly bool
}
func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) {
func newSystemdDbusConfigurator(wgInterface string) (*systemdDbusConfigurator, error) {
iface, err := net.InterfaceByName(wgInterface)
if err != nil {
return nil, fmt.Errorf("get interface: %w", err)
@ -77,6 +79,7 @@ func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) {
return &systemdDbusConfigurator{
dbusLinkObject: dbus.ObjectPath(s),
ifaceName: wgInterface,
}, nil
}
@ -84,7 +87,7 @@ func (s *systemdDbusConfigurator) supportCustomPort() bool {
return true
}
func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
parsedIP, err := netip.ParseAddr(config.ServerIP)
if err != nil {
return fmt.Errorf("unable to parse ip address, error: %w", err)
@ -135,10 +138,12 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
}
// create a backup for unclean shutdown detection before adding domains, as these might end up in the resolv.conf file.
// The file content itself is not important for systemd restoration
if err := createUncleanShutdownIndicator(defaultResolvConfPath, systemdManager, parsedIP.String()); err != nil {
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
state := &ShutdownState{
ManagerType: systemdManager,
WgIface: s.ifaceName,
}
if err := stateManager.UpdateState(state); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains)
@ -174,10 +179,6 @@ func (s *systemdDbusConfigurator) restoreHostDNS() error {
return fmt.Errorf("unable to revert link configuration, got error: %w", err)
}
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
}
return s.flushCaches()
}

View File

@ -1,5 +0,0 @@
package dns
func CheckUncleanShutdown(string) error {
return nil
}

View File

@ -3,57 +3,25 @@
package dns
import (
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
log "github.com/sirupsen/logrus"
)
const fileUncleanShutdownFileLocation = "/var/lib/netbird/unclean_shutdown_dns"
type ShutdownState struct {
}
func CheckUncleanShutdown(string) error {
if _, err := os.Stat(fileUncleanShutdownFileLocation); err != nil {
if errors.Is(err, fs.ErrNotExist) {
// no file -> clean shutdown
return nil
} else {
return fmt.Errorf("state: %w", err)
}
}
log.Warnf("detected unclean shutdown, file %s exists. Restoring unclean shutdown dns settings.", fileUncleanShutdownFileLocation)
func (s *ShutdownState) Name() string {
return "dns_state"
}
func (s *ShutdownState) Cleanup() error {
manager, err := newHostManager()
if err != nil {
return fmt.Errorf("create host manager: %w", err)
}
if err := manager.restoreUncleanShutdownDNS(nil); err != nil {
return fmt.Errorf("restore unclean shutdown backup: %w", err)
if err := manager.restoreUncleanShutdownDNS(); err != nil {
return fmt.Errorf("restore unclean shutdown dns: %w", err)
}
return nil
}
func createUncleanShutdownIndicator() error {
dir := filepath.Dir(fileUncleanShutdownFileLocation)
if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil {
return fmt.Errorf("create dir %s: %w", dir, err)
}
if err := os.WriteFile(fileUncleanShutdownFileLocation, nil, 0644); err != nil { //nolint:gosec
return fmt.Errorf("create %s: %w", fileUncleanShutdownFileLocation, err)
}
return nil
}
func removeUncleanShutdownIndicator() error {
if err := os.Remove(fileUncleanShutdownFileLocation); err != nil && !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("remove %s: %w", fileUncleanShutdownFileLocation, err)
}
return nil
}

View File

@ -1,5 +0,0 @@
package dns
func CheckUncleanShutdown(string) error {
return nil
}

View File

@ -0,0 +1,14 @@
//go:build ios || android
package dns
type ShutdownState struct {
}
func (s *ShutdownState) Name() string {
return "dns_state"
}
func (s *ShutdownState) Cleanup() error {
return nil
}

View File

@ -3,66 +3,44 @@
package dns
import (
"errors"
"fmt"
"io/fs"
"net/netip"
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
func CheckUncleanShutdown(wgIface string) error {
if _, err := os.Stat(fileUncleanShutdownResolvConfLocation); err != nil {
if errors.Is(err, fs.ErrNotExist) {
// no file -> clean shutdown
return nil
} else {
return fmt.Errorf("state: %w", err)
}
}
type ShutdownState struct {
ManagerType osManagerType
DNSAddress netip.Addr
WgIface string
}
log.Warnf("detected unclean shutdown, file %s exists", fileUncleanShutdownResolvConfLocation)
func (s *ShutdownState) Name() string {
return "dns_state"
}
managerData, err := os.ReadFile(fileUncleanShutdownManagerTypeLocation)
if err != nil {
return fmt.Errorf("read %s: %w", fileUncleanShutdownManagerTypeLocation, err)
}
managerFields := strings.Split(string(managerData), ",")
if len(managerFields) < 2 {
return errors.New("split manager data: insufficient number of fields")
}
osManagerTypeStr, dnsAddressStr := managerFields[0], managerFields[1]
dnsAddress, err := netip.ParseAddr(dnsAddressStr)
if err != nil {
return fmt.Errorf("parse dns address %s failed: %w", dnsAddressStr, err)
}
log.Warnf("restoring unclean shutdown dns settings via previously detected manager: %s", osManagerTypeStr)
// determine os manager type, so we can invoke the respective restore action
osManagerType, err := newOsManagerType(osManagerTypeStr)
if err != nil {
return fmt.Errorf("detect previous host manager: %w", err)
}
manager, err := newHostManagerFromType(wgIface, osManagerType)
func (s *ShutdownState) Cleanup() error {
manager, err := newHostManagerFromType(s.WgIface, s.ManagerType)
if err != nil {
return fmt.Errorf("create previous host manager: %w", err)
}
if err := manager.restoreUncleanShutdownDNS(&dnsAddress); err != nil {
return fmt.Errorf("restore unclean shutdown backup: %w", err)
if err := manager.restoreUncleanShutdownDNS(&s.DNSAddress); err != nil {
return fmt.Errorf("restore unclean shutdown dns: %w", err)
}
return nil
}
func createUncleanShutdownIndicator(sourcePath string, managerType osManagerType, dnsAddress string) error {
// TODO: move file contents to state manager
func createUncleanShutdownIndicator(sourcePath string, dnsAddressStr string, stateManager *statemanager.Manager) error {
dnsAddress, err := netip.ParseAddr(dnsAddressStr)
if err != nil {
return fmt.Errorf("parse dns address %s: %w", dnsAddressStr, err)
}
dir := filepath.Dir(fileUncleanShutdownResolvConfLocation)
if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil {
return fmt.Errorf("create dir %s: %w", dir, err)
@ -72,20 +50,13 @@ func createUncleanShutdownIndicator(sourcePath string, managerType osManagerType
return fmt.Errorf("create %s: %w", sourcePath, err)
}
managerData := fmt.Sprintf("%s,%s", managerType, dnsAddress)
if err := os.WriteFile(fileUncleanShutdownManagerTypeLocation, []byte(managerData), 0644); err != nil { //nolint:gosec
return fmt.Errorf("create %s: %w", fileUncleanShutdownManagerTypeLocation, err)
}
return nil
}
func removeUncleanShutdownIndicator() error {
if err := os.Remove(fileUncleanShutdownResolvConfLocation); err != nil && !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("remove %s: %w", fileUncleanShutdownResolvConfLocation, err)
}
if err := os.Remove(fileUncleanShutdownManagerTypeLocation); err != nil && !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("remove %s: %w", fileUncleanShutdownManagerTypeLocation, err)
state := &ShutdownState{
ManagerType: fileManager,
DNSAddress: dnsAddress,
}
if err := stateManager.UpdateState(state); err != nil {
return fmt.Errorf("update state: %w", err)
}
return nil
}

View File

@ -1,75 +1,26 @@
package dns
import (
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"github.com/sirupsen/logrus"
)
const (
netbirdProgramDataLocation = "Netbird"
fileUncleanShutdownFile = "unclean_shutdown_dns.txt"
)
type ShutdownState struct {
Guid string
}
func CheckUncleanShutdown(string) error {
file := getUncleanShutdownFile()
func (s *ShutdownState) Name() string {
return "dns_state"
}
if _, err := os.Stat(file); err != nil {
if errors.Is(err, fs.ErrNotExist) {
// no file -> clean shutdown
return nil
} else {
return fmt.Errorf("state: %w", err)
}
}
logrus.Warnf("detected unclean shutdown, file %s exists. Restoring unclean shutdown dns settings.", file)
guid, err := os.ReadFile(file)
if err != nil {
return fmt.Errorf("read %s: %w", file, err)
}
manager, err := newHostManagerWithGuid(string(guid))
func (s *ShutdownState) Cleanup() error {
manager, err := newHostManagerWithGuid(s.Guid)
if err != nil {
return fmt.Errorf("create host manager: %w", err)
}
if err := manager.restoreUncleanShutdownDNS(nil); err != nil {
return fmt.Errorf("restore unclean shutdown backup: %w", err)
if err := manager.restoreUncleanShutdownDNS(); err != nil {
return fmt.Errorf("restore unclean shutdown dns: %w", err)
}
return nil
}
func createUncleanShutdownIndicator(guid string) error {
file := getUncleanShutdownFile()
dir := filepath.Dir(file)
if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil {
return fmt.Errorf("create dir %s: %w", dir, err)
}
if err := os.WriteFile(file, []byte(guid), 0600); err != nil {
return fmt.Errorf("create %s: %w", file, err)
}
return nil
}
func removeUncleanShutdownIndicator() error {
file := getUncleanShutdownFile()
if err := os.Remove(file); err != nil && !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("remove %s: %w", file, err)
}
return nil
}
func getUncleanShutdownFile() string {
return filepath.Join(os.Getenv("PROGRAMDATA"), netbirdProgramDataLocation, fileUncleanShutdownFile)
}

View File

@ -23,19 +23,21 @@ import (
"github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/internal/networkmonitor"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/wgproxy"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
@ -141,8 +143,7 @@ type Engine struct {
ctx context.Context
cancel context.CancelFunc
wgInterface iface.IWGIface
wgProxyFactory *wgproxy.Factory
wgInterface iface.IWGIface
udpMux *bind.UniversalUDPMuxDefault
@ -168,6 +169,8 @@ type Engine struct {
checks []*mgmProto.Checks
relayManager *relayClient.Manager
stateManager *statemanager.Manager
srWatcher *guard.SRWatcher
}
// Peer is an instance of the Connection Peer
@ -215,7 +218,7 @@ func NewEngineWithProbes(
probes *ProbeHolder,
checks []*mgmProto.Checks,
) *Engine {
return &Engine{
engine := &Engine{
clientCtx: clientCtx,
clientCancel: clientCancel,
signal: signalClient,
@ -234,6 +237,11 @@ func NewEngineWithProbes(
probes: probes,
checks: checks,
}
if path := statemanager.GetDefaultStatePath(); path != "" {
engine.stateManager = statemanager.New(path)
}
return engine
}
func (e *Engine) Stop() error {
@ -255,7 +263,11 @@ func (e *Engine) Stop() error {
e.stopDNSServer()
if e.routeManager != nil {
e.routeManager.Stop()
e.routeManager.Stop(e.stateManager)
}
if e.srWatcher != nil {
e.srWatcher.Close()
}
err := e.removeAllPeers()
@ -277,6 +289,17 @@ func (e *Engine) Stop() error {
e.close()
log.Infof("stopped Netbird Engine")
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
if err := e.stateManager.Stop(ctx); err != nil {
return fmt.Errorf("failed to stop state manager: %w", err)
}
if err := e.stateManager.PersistState(ctx); err != nil {
log.Errorf("failed to persist state: %v", err)
}
return nil
}
@ -299,9 +322,6 @@ func (e *Engine) Start() error {
}
e.wgInterface = wgIface
userspace := e.wgInterface.IsUserspaceBind()
e.wgProxyFactory = wgproxy.NewFactory(userspace, e.config.WgPort)
if e.config.RosenpassEnabled {
log.Infof("rosenpass is enabled")
if e.config.RosenpassPermissive {
@ -319,6 +339,8 @@ func (e *Engine) Start() error {
}
}
e.stateManager.Start()
initialRoutes, dnsServer, err := e.newDnsServer()
if err != nil {
e.close()
@ -327,7 +349,7 @@ func (e *Engine) Start() error {
e.dnsServer = dnsServer
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes)
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
beforePeerHook, afterPeerHook, err := e.routeManager.Init(e.stateManager)
if err != nil {
log.Errorf("Failed to initialize route manager: %s", err)
} else {
@ -344,7 +366,7 @@ func (e *Engine) Start() error {
return fmt.Errorf("create wg interface: %w", err)
}
e.firewall, err = firewall.NewFirewall(e.ctx, e.wgInterface)
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager)
if err != nil {
log.Errorf("failed creating firewall manager: %s", err)
}
@ -374,6 +396,18 @@ func (e *Engine) Start() error {
return fmt.Errorf("initialize dns server: %w", err)
}
iceCfg := icemaker.Config{
StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
UDPMux: e.udpMux.UDPMuxDefault,
UDPMuxSrflx: e.udpMux,
NATExternalIPs: e.parseNATExternalIPMappings(),
}
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
e.srWatcher.Start()
e.receiveSignalEvents()
e.receiveManagementEvents()
e.receiveProbeEvents()
@ -956,7 +990,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
LocalWgPort: e.config.WgPort,
RosenpassPubKey: e.getRosenpassPubKey(),
RosenpassAddr: e.getRosenpassAddr(),
ICEConfig: peer.ICEConfig{
ICEConfig: icemaker.Config{
StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
@ -966,7 +1000,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
},
}
peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.wgProxyFactory, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager)
peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher)
if err != nil {
return nil, err
}
@ -1117,12 +1151,6 @@ func (e *Engine) parseNATExternalIPMappings() []string {
}
func (e *Engine) close() {
if e.wgProxyFactory != nil {
if err := e.wgProxyFactory.Free(); err != nil {
log.Errorf("failed closing ebpf proxy: %s", err)
}
}
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
if e.wgInterface != nil {
if err := e.wgInterface.Close(); err != nil {
@ -1139,7 +1167,7 @@ func (e *Engine) close() {
}
if e.firewall != nil {
err := e.firewall.Reset()
err := e.firewall.Reset(e.stateManager)
if err != nil {
log.Warnf("failed to reset firewall: %s", err)
}
@ -1167,21 +1195,29 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
log.Errorf("failed to create pion's stdnet: %s", err)
}
var mArgs *device.MobileIFaceArguments
opts := iface.WGIFaceOpts{
IFaceName: e.config.WgIfaceName,
Address: e.config.WgAddr,
WGPort: e.config.WgPort,
WGPrivKey: e.config.WgPrivateKey.String(),
MTU: iface.DefaultMTU,
TransportNet: transportNet,
FilterFn: e.addrViaRoutes,
}
switch runtime.GOOS {
case "android":
mArgs = &device.MobileIFaceArguments{
opts.MobileArgs = &device.MobileIFaceArguments{
TunAdapter: e.mobileDep.TunAdapter,
TunFd: int(e.mobileDep.FileDescriptor),
}
case "ios":
mArgs = &device.MobileIFaceArguments{
opts.MobileArgs = &device.MobileIFaceArguments{
TunFd: int(e.mobileDep.FileDescriptor),
}
default:
}
return iface.NewWGIFace(e.config.WgIfaceName, e.config.WgAddr, e.config.WgPort, e.config.WgPrivateKey.String(), iface.DefaultMTU, transportNet, mArgs, e.addrViaRoutes)
return iface.NewWGIFace(opts)
}
func (e *Engine) wgInterfaceCreate() (err error) {
@ -1222,10 +1258,11 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder)
return nil, dnsServer, nil
default:
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder)
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager)
if err != nil {
return nil, nil, err
}
return nil, dnsServer, nil
}
}

View File

@ -29,6 +29,8 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
@ -258,6 +260,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
}
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
engine.ctx = ctx
engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{})
type testCase struct {
name string
@ -602,7 +605,16 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
if err != nil {
t.Fatal(err)
}
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil)
opts := iface.WGIFaceOpts{
IFaceName: wgIfaceName,
Address: wgAddr,
WGPort: engine.config.WgPort,
WGPrivKey: key.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
engine.wgInterface, err = iface.NewWGIFace(opts)
assert.NoError(t, err, "shouldn't return error")
input := struct {
inputSerial uint64
@ -774,7 +786,15 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
if err != nil {
t.Fatal(err)
}
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, 33100, key.String(), iface.DefaultMTU, newNet, nil, nil)
opts := iface.WGIFaceOpts{
IFaceName: wgIfaceName,
Address: wgAddr,
WGPort: 33100,
WGPrivKey: key.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
engine.wgInterface, err = iface.NewWGIFace(opts)
assert.NoError(t, err, "shouldn't return error")
mockRouteManager := &routemanager.MockManager{

View File

@ -10,15 +10,16 @@ import (
"sync"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/pion/ice/v3"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/wgproxy"
relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route"
nbnet "github.com/netbirdio/netbird/util/net"
@ -32,8 +33,6 @@ const (
connPriorityRelay ConnPriority = 1
connPriorityICETurn ConnPriority = 1
connPriorityICEP2P ConnPriority = 2
reconnectMaxElapsedTime = 30 * time.Minute
)
type WgConfig struct {
@ -63,7 +62,7 @@ type ConnConfig struct {
RosenpassAddr string
// ICEConfig ICE protocol configuration
ICEConfig ICEConfig
ICEConfig icemaker.Config
}
type WorkerCallbacks struct {
@ -81,11 +80,10 @@ type Conn struct {
ctxCancel context.CancelFunc
config ConnConfig
statusRecorder *Status
wgProxyFactory *wgproxy.Factory
signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover
relayManager *relayClient.Manager
allowedIPsIP string
allowedIP net.IP
allowedNet string
handshaker *Handshaker
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
@ -107,17 +105,13 @@ type Conn struct {
wgProxyICE wgproxy.Proxy
wgProxyRelay wgproxy.Proxy
// for reconnection operations
iCEDisconnected chan bool
relayDisconnected chan bool
connMonitor *ConnMonitor
reconnectCh <-chan struct{}
guard *guard.Guard
}
// NewConn creates a new not opened Conn to the remote peer.
// To establish a connection run Conn.Open
func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.Factory, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager) (*Conn, error) {
_, allowedIPsIP, err := net.ParseCIDR(config.WgConfig.AllowedIps)
func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher) (*Conn, error) {
allowedIP, allowedNet, err := net.ParseCIDR(config.WgConfig.AllowedIps)
if err != nil {
log.Errorf("failed to parse allowedIPS: %v", err)
return nil, err
@ -132,26 +126,14 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
ctxCancel: ctxCancel,
config: config,
statusRecorder: statusRecorder,
wgProxyFactory: wgProxyFactory,
signaler: signaler,
iFaceDiscover: iFaceDiscover,
relayManager: relayManager,
allowedIPsIP: allowedIPsIP.String(),
allowedIP: allowedIP,
allowedNet: allowedNet.String(),
statusRelay: NewAtomicConnStatus(),
statusICE: NewAtomicConnStatus(),
iCEDisconnected: make(chan bool, 1),
relayDisconnected: make(chan bool, 1),
}
conn.connMonitor, conn.reconnectCh = NewConnMonitor(
signaler,
iFaceDiscover,
config,
conn.relayDisconnected,
conn.iCEDisconnected,
)
rFns := WorkerRelayCallbacks{
OnConnReady: conn.relayConnectionIsReady,
OnDisconnected: conn.onWorkerRelayStateDisconnected,
@ -162,7 +144,8 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
OnStatusChanged: conn.onWorkerICEStateDisconnected,
}
conn.workerRelay = NewWorkerRelay(connLog, config, relayManager, rFns)
ctrl := isController(config)
conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, relayManager, rFns)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
conn.workerICE, err = NewWorkerICE(ctx, connLog, config, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally, wFns)
@ -177,6 +160,8 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
}
conn.guard = guard.NewGuard(connLog, ctrl, conn.isConnectedOnAllWay, config.Timeout, srWatcher)
go conn.handshaker.Listen()
return conn, nil
@ -187,6 +172,7 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
// be used.
func (conn *Conn) Open() {
conn.log.Debugf("open connection to peer")
conn.mu.Lock()
defer conn.mu.Unlock()
conn.opened = true
@ -203,24 +189,19 @@ func (conn *Conn) Open() {
conn.log.Warnf("error while updating the state err: %v", err)
}
go conn.startHandshakeAndReconnect()
go conn.startHandshakeAndReconnect(conn.ctx)
}
func (conn *Conn) startHandshakeAndReconnect() {
conn.waitInitialRandomSleepTime()
func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) {
conn.waitInitialRandomSleepTime(ctx)
err := conn.handshaker.sendOffer()
if err != nil {
conn.log.Errorf("failed to send initial offer: %v", err)
}
go conn.connMonitor.Start(conn.ctx)
if conn.workerRelay.IsController() {
conn.reconnectLoopWithRetry()
} else {
conn.reconnectLoopForOnDisconnectedEvent()
}
go conn.guard.Start(ctx)
go conn.listenGuardEvent(ctx)
}
// Close closes this peer Conn issuing a close event to the Conn closeCh
@ -319,104 +300,6 @@ func (conn *Conn) GetKey() string {
return conn.config.Key
}
func (conn *Conn) reconnectLoopWithRetry() {
// Give chance to the peer to establish the initial connection.
// With it, we can decrease to send necessary offer
select {
case <-conn.ctx.Done():
return
case <-time.After(3 * time.Second):
}
ticker := conn.prepareExponentTicker()
defer ticker.Stop()
time.Sleep(1 * time.Second)
for {
select {
case t := <-ticker.C:
if t.IsZero() {
// in case if the ticker has been canceled by context then avoid the temporary loop
return
}
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
if conn.statusRelay.Get() == StatusDisconnected || conn.statusICE.Get() == StatusDisconnected {
conn.log.Tracef("connectivity guard timedout, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE)
}
} else {
if conn.statusICE.Get() == StatusDisconnected {
conn.log.Tracef("connectivity guard timedout, ice state: %s", conn.statusICE)
}
}
// checks if there is peer connection is established via relay or ice
if conn.isConnected() {
continue
}
err := conn.handshaker.sendOffer()
if err != nil {
conn.log.Errorf("failed to do handshake: %v", err)
}
case <-conn.reconnectCh:
ticker.Stop()
ticker = conn.prepareExponentTicker()
case <-conn.ctx.Done():
conn.log.Debugf("context is done, stop reconnect loop")
return
}
}
}
func (conn *Conn) prepareExponentTicker() *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond,
RandomizationFactor: 0.1,
Multiplier: 2,
MaxInterval: conn.config.Timeout,
MaxElapsedTime: reconnectMaxElapsedTime,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, conn.ctx)
ticker := backoff.NewTicker(bo)
<-ticker.C // consume the initial tick what is happening right after the ticker has been created
return ticker
}
// reconnectLoopForOnDisconnectedEvent is used when the peer is not a controller and it should reconnect to the peer
// when the connection is lost. It will try to establish a connection only once time if before the connection was established
// It track separately the ice and relay connection status. Just because a lover priority connection reestablished it does not
// mean that to switch to it. We always force to use the higher priority connection.
func (conn *Conn) reconnectLoopForOnDisconnectedEvent() {
for {
select {
case changed := <-conn.relayDisconnected:
if !changed {
continue
}
conn.log.Debugf("Relay state changed, try to send new offer")
case changed := <-conn.iCEDisconnected:
if !changed {
continue
}
conn.log.Debugf("ICE state changed, try to send new offer")
case <-conn.ctx.Done():
conn.log.Debugf("context is done, stop reconnect loop")
return
}
err := conn.handshaker.SendOffer()
if err != nil {
conn.log.Errorf("failed to do handshake: %v", err)
}
}
}
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) {
conn.mu.Lock()
@ -516,7 +399,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
changed := conn.statusICE.Get() != newState && newState != StatusConnecting
conn.statusICE.Set(newState)
conn.notifyReconnectLoopICEDisconnected(changed)
conn.guard.SetICEConnDisconnected(changed)
peerState := State{
PubKey: conn.config.Key,
@ -607,7 +490,7 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
changed := conn.statusRelay.Get() != StatusDisconnected
conn.statusRelay.Set(StatusDisconnected)
conn.notifyReconnectLoopRelayDisconnected(changed)
conn.guard.SetRelayedConnDisconnected(changed)
peerState := State{
PubKey: conn.config.Key,
@ -620,6 +503,20 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
}
}
func (conn *Conn) listenGuardEvent(ctx context.Context) {
for {
select {
case <-conn.guard.Reconnect:
conn.log.Debugf("send offer to peer")
if err := conn.handshaker.SendOffer(); err != nil {
conn.log.Errorf("failed to send offer: %v", err)
}
case <-ctx.Done():
return
}
}
}
func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr) error {
return conn.config.WgConfig.WgInterface.UpdatePeer(
conn.config.WgConfig.RemoteKey,
@ -692,11 +589,11 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd
}
if conn.onConnected != nil {
conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedIPsIP, remoteRosenpassAddr)
conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedNet, remoteRosenpassAddr)
}
}
func (conn *Conn) waitInitialRandomSleepTime() {
func (conn *Conn) waitInitialRandomSleepTime(ctx context.Context) {
minWait := 100
maxWait := 800
duration := time.Duration(rand.Intn(maxWait-minWait)+minWait) * time.Millisecond
@ -705,7 +602,7 @@ func (conn *Conn) waitInitialRandomSleepTime() {
defer timeout.Stop()
select {
case <-conn.ctx.Done():
case <-ctx.Done():
case <-timeout.C:
}
}
@ -734,11 +631,17 @@ func (conn *Conn) evalStatus() ConnStatus {
return StatusDisconnected
}
func (conn *Conn) isConnected() bool {
func (conn *Conn) isConnectedOnAllWay() (connected bool) {
conn.mu.Lock()
defer conn.mu.Unlock()
if conn.statusICE.Get() != StatusConnected && conn.statusICE.Get() != StatusConnecting {
defer func() {
if !connected {
conn.logTraceConnState()
}
}()
if conn.statusICE.Get() == StatusDisconnected {
return false
}
@ -783,8 +686,13 @@ func (conn *Conn) freeUpConnID() {
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
conn.log.Debugf("setup proxied WireGuard connection")
wgProxy := conn.wgProxyFactory.GetProxy()
if err := wgProxy.AddTurnConn(conn.ctx, remoteConn); err != nil {
udpAddr := &net.UDPAddr{
IP: conn.allowedIP,
Port: conn.config.WgConfig.WgListenPort,
}
wgProxy := conn.config.WgConfig.WgInterface.GetProxy()
if err := wgProxy.AddTurnConn(conn.ctx, udpAddr, remoteConn); err != nil {
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
return nil, err
}
@ -803,20 +711,6 @@ func (conn *Conn) removeWgPeer() error {
return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
}
func (conn *Conn) notifyReconnectLoopRelayDisconnected(changed bool) {
select {
case conn.relayDisconnected <- changed:
default:
}
}
func (conn *Conn) notifyReconnectLoopICEDisconnected(changed bool) {
select {
case conn.iCEDisconnected <- changed:
default:
}
}
func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
conn.log.Warnf("Failed to update wg peer configuration: %v", err)
if wgProxy != nil {
@ -829,6 +723,18 @@ func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
}
}
func (conn *Conn) logTraceConnState() {
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
conn.log.Tracef("connectivity guard check, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE)
} else {
conn.log.Tracef("connectivity guard check, ice state: %s", conn.statusICE)
}
}
func isController(config ConnConfig) bool {
return config.LocalKey > config.Key
}
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
return remoteRosenpassPubKey != nil
}

View File

@ -1,212 +0,0 @@
package peer
import (
"context"
"fmt"
"sync"
"time"
"github.com/pion/ice/v3"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
const (
signalerMonitorPeriod = 5 * time.Second
candidatesMonitorPeriod = 5 * time.Minute
candidateGatheringTimeout = 5 * time.Second
)
type ConnMonitor struct {
signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover
config ConnConfig
relayDisconnected chan bool
iCEDisconnected chan bool
reconnectCh chan struct{}
currentCandidates []ice.Candidate
candidatesMu sync.Mutex
}
func NewConnMonitor(signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, config ConnConfig, relayDisconnected, iCEDisconnected chan bool) (*ConnMonitor, <-chan struct{}) {
reconnectCh := make(chan struct{}, 1)
cm := &ConnMonitor{
signaler: signaler,
iFaceDiscover: iFaceDiscover,
config: config,
relayDisconnected: relayDisconnected,
iCEDisconnected: iCEDisconnected,
reconnectCh: reconnectCh,
}
return cm, reconnectCh
}
func (cm *ConnMonitor) Start(ctx context.Context) {
signalerReady := make(chan struct{}, 1)
go cm.monitorSignalerReady(ctx, signalerReady)
localCandidatesChanged := make(chan struct{}, 1)
go cm.monitorLocalCandidatesChanged(ctx, localCandidatesChanged)
for {
select {
case changed := <-cm.relayDisconnected:
if !changed {
continue
}
log.Debugf("Relay state changed, triggering reconnect")
cm.triggerReconnect()
case changed := <-cm.iCEDisconnected:
if !changed {
continue
}
log.Debugf("ICE state changed, triggering reconnect")
cm.triggerReconnect()
case <-signalerReady:
log.Debugf("Signaler became ready, triggering reconnect")
cm.triggerReconnect()
case <-localCandidatesChanged:
log.Debugf("Local candidates changed, triggering reconnect")
cm.triggerReconnect()
case <-ctx.Done():
return
}
}
}
func (cm *ConnMonitor) monitorSignalerReady(ctx context.Context, signalerReady chan<- struct{}) {
if cm.signaler == nil {
return
}
ticker := time.NewTicker(signalerMonitorPeriod)
defer ticker.Stop()
lastReady := true
for {
select {
case <-ticker.C:
currentReady := cm.signaler.Ready()
if !lastReady && currentReady {
select {
case signalerReady <- struct{}{}:
default:
}
}
lastReady = currentReady
case <-ctx.Done():
return
}
}
}
func (cm *ConnMonitor) monitorLocalCandidatesChanged(ctx context.Context, localCandidatesChanged chan<- struct{}) {
ufrag, pwd, err := generateICECredentials()
if err != nil {
log.Warnf("Failed to generate ICE credentials: %v", err)
return
}
ticker := time.NewTicker(candidatesMonitorPeriod)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := cm.handleCandidateTick(ctx, localCandidatesChanged, ufrag, pwd); err != nil {
log.Warnf("Failed to handle candidate tick: %v", err)
}
case <-ctx.Done():
return
}
}
}
func (cm *ConnMonitor) handleCandidateTick(ctx context.Context, localCandidatesChanged chan<- struct{}, ufrag string, pwd string) error {
log.Debugf("Gathering ICE candidates")
transportNet, err := newStdNet(cm.iFaceDiscover, cm.config.ICEConfig.InterfaceBlackList)
if err != nil {
log.Errorf("failed to create pion's stdnet: %s", err)
}
agent, err := newAgent(cm.config, transportNet, candidateTypesP2P(), ufrag, pwd)
if err != nil {
return fmt.Errorf("create ICE agent: %w", err)
}
defer func() {
if err := agent.Close(); err != nil {
log.Warnf("Failed to close ICE agent: %v", err)
}
}()
gatherDone := make(chan struct{})
err = agent.OnCandidate(func(c ice.Candidate) {
log.Tracef("Got candidate: %v", c)
if c == nil {
close(gatherDone)
}
})
if err != nil {
return fmt.Errorf("set ICE candidate handler: %w", err)
}
if err := agent.GatherCandidates(); err != nil {
return fmt.Errorf("gather ICE candidates: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, candidateGatheringTimeout)
defer cancel()
select {
case <-ctx.Done():
return fmt.Errorf("wait for gathering: %w", ctx.Err())
case <-gatherDone:
}
candidates, err := agent.GetLocalCandidates()
if err != nil {
return fmt.Errorf("get local candidates: %w", err)
}
log.Tracef("Got candidates: %v", candidates)
if changed := cm.updateCandidates(candidates); changed {
select {
case localCandidatesChanged <- struct{}{}:
default:
}
}
return nil
}
func (cm *ConnMonitor) updateCandidates(newCandidates []ice.Candidate) bool {
cm.candidatesMu.Lock()
defer cm.candidatesMu.Unlock()
if len(cm.currentCandidates) != len(newCandidates) {
cm.currentCandidates = newCandidates
return true
}
for i, candidate := range cm.currentCandidates {
if candidate.Address() != newCandidates[i].Address() {
cm.currentCandidates = newCandidates
return true
}
}
return false
}
func (cm *ConnMonitor) triggerReconnect() {
select {
case cm.reconnectCh <- struct{}{}:
default:
}
}

View File

@ -10,8 +10,9 @@ import (
"github.com/magiconair/properties/assert"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/peer/guard"
"github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/wgproxy"
"github.com/netbirdio/netbird/util"
)
@ -20,7 +21,7 @@ var connConf = ConnConfig{
LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
Timeout: time.Second,
LocalWgPort: 51820,
ICEConfig: ICEConfig{
ICEConfig: ice.Config{
InterfaceBlackList: nil,
},
}
@ -44,11 +45,8 @@ func TestNewConn_interfaceFilter(t *testing.T) {
}
func TestConn_GetKey(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
conn, err := NewConn(context.Background(), connConf, nil, wgProxyFactory, nil, nil, nil)
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil, swWatcher)
if err != nil {
return
}
@ -59,11 +57,8 @@ func TestConn_GetKey(t *testing.T) {
}
func TestConn_OnRemoteOffer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil)
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher)
if err != nil {
return
}
@ -96,11 +91,8 @@ func TestConn_OnRemoteOffer(t *testing.T) {
}
func TestConn_OnRemoteAnswer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil)
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher)
if err != nil {
return
}
@ -132,11 +124,8 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
wg.Wait()
}
func TestConn_Status(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil)
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher)
if err != nil {
return
}

View File

@ -0,0 +1,194 @@
package guard
import (
"context"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
)
const (
reconnectMaxElapsedTime = 30 * time.Minute
)
type isConnectedFunc func() bool
// Guard is responsible for the reconnection logic.
// It will trigger to send an offer to the peer then has connection issues.
// Watch these events:
// - Relay client reconnected to home server
// - Signal server connection state changed
// - ICE connection disconnected
// - Relayed connection disconnected
// - ICE candidate changes
type Guard struct {
Reconnect chan struct{}
log *log.Entry
isController bool
isConnectedOnAllWay isConnectedFunc
timeout time.Duration
srWatcher *SRWatcher
relayedConnDisconnected chan bool
iCEConnDisconnected chan bool
}
func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
return &Guard{
Reconnect: make(chan struct{}, 1),
log: log,
isController: isController,
isConnectedOnAllWay: isConnectedFn,
timeout: timeout,
srWatcher: srWatcher,
relayedConnDisconnected: make(chan bool, 1),
iCEConnDisconnected: make(chan bool, 1),
}
}
func (g *Guard) Start(ctx context.Context) {
if g.isController {
g.reconnectLoopWithRetry(ctx)
} else {
g.listenForDisconnectEvents(ctx)
}
}
func (g *Guard) SetRelayedConnDisconnected(changed bool) {
select {
case g.relayedConnDisconnected <- changed:
default:
}
}
func (g *Guard) SetICEConnDisconnected(changed bool) {
select {
case g.iCEConnDisconnected <- changed:
default:
}
}
// reconnectLoopWithRetry periodically check (max 30 min) the connection status.
// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported
func (g *Guard) reconnectLoopWithRetry(ctx context.Context) {
waitForInitialConnectionTry(ctx)
srReconnectedChan := g.srWatcher.NewListener()
defer g.srWatcher.RemoveListener(srReconnectedChan)
ticker := g.prepareExponentTicker(ctx)
defer ticker.Stop()
tickerChannel := ticker.C
g.log.Infof("start reconnect loop...")
for {
select {
case t := <-tickerChannel:
if t.IsZero() {
g.log.Infof("retry timed out, stop periodic offer sending")
// after backoff timeout the ticker.C will be closed. We need to a dummy channel to avoid loop
tickerChannel = make(<-chan time.Time)
continue
}
if !g.isConnectedOnAllWay() {
g.triggerOfferSending()
}
case changed := <-g.relayedConnDisconnected:
if !changed {
continue
}
g.log.Debugf("Relay connection changed, reset reconnection ticker")
ticker.Stop()
ticker = g.prepareExponentTicker(ctx)
tickerChannel = ticker.C
case changed := <-g.iCEConnDisconnected:
if !changed {
continue
}
g.log.Debugf("ICE connection changed, reset reconnection ticker")
ticker.Stop()
ticker = g.prepareExponentTicker(ctx)
tickerChannel = ticker.C
case <-srReconnectedChan:
g.log.Debugf("has network changes, reset reconnection ticker")
ticker.Stop()
ticker = g.prepareExponentTicker(ctx)
tickerChannel = ticker.C
case <-ctx.Done():
g.log.Debugf("context is done, stop reconnect loop")
return
}
}
}
// listenForDisconnectEvents is used when the peer is not a controller and it should reconnect to the peer
// when the connection is lost. It will try to establish a connection only once time if before the connection was established
// It track separately the ice and relay connection status. Just because a lower priority connection reestablished it does not
// mean that to switch to it. We always force to use the higher priority connection.
func (g *Guard) listenForDisconnectEvents(ctx context.Context) {
srReconnectedChan := g.srWatcher.NewListener()
defer g.srWatcher.RemoveListener(srReconnectedChan)
g.log.Infof("start listen for reconnect events...")
for {
select {
case changed := <-g.relayedConnDisconnected:
if !changed {
continue
}
g.log.Debugf("Relay connection changed, triggering reconnect")
g.triggerOfferSending()
case changed := <-g.iCEConnDisconnected:
if !changed {
continue
}
g.log.Debugf("ICE state changed, try to send new offer")
g.triggerOfferSending()
case <-srReconnectedChan:
g.triggerOfferSending()
case <-ctx.Done():
g.log.Debugf("context is done, stop reconnect loop")
return
}
}
}
func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond,
RandomizationFactor: 0.1,
Multiplier: 2,
MaxInterval: g.timeout,
MaxElapsedTime: reconnectMaxElapsedTime,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
ticker := backoff.NewTicker(bo)
<-ticker.C // consume the initial tick what is happening right after the ticker has been created
return ticker
}
func (g *Guard) triggerOfferSending() {
select {
case g.Reconnect <- struct{}{}:
default:
}
}
// Give chance to the peer to establish the initial connection.
// With it, we can decrease to send necessary offer
func waitForInitialConnectionTry(ctx context.Context) {
select {
case <-ctx.Done():
return
case <-time.After(3 * time.Second):
}
}

View File

@ -0,0 +1,135 @@
package guard
import (
"context"
"fmt"
"sync"
"time"
"github.com/pion/ice/v3"
log "github.com/sirupsen/logrus"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
const (
candidatesMonitorPeriod = 5 * time.Minute
candidateGatheringTimeout = 5 * time.Second
)
type ICEMonitor struct {
ReconnectCh chan struct{}
iFaceDiscover stdnet.ExternalIFaceDiscover
iceConfig icemaker.Config
currentCandidates []ice.Candidate
candidatesMu sync.Mutex
}
func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config) *ICEMonitor {
cm := &ICEMonitor{
ReconnectCh: make(chan struct{}, 1),
iFaceDiscover: iFaceDiscover,
iceConfig: config,
}
return cm
}
func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) {
ufrag, pwd, err := icemaker.GenerateICECredentials()
if err != nil {
log.Warnf("Failed to generate ICE credentials: %v", err)
return
}
ticker := time.NewTicker(candidatesMonitorPeriod)
defer ticker.Stop()
for {
select {
case <-ticker.C:
changed, err := cm.handleCandidateTick(ctx, ufrag, pwd)
if err != nil {
log.Warnf("Failed to check ICE changes: %v", err)
continue
}
if changed {
onChanged()
}
case <-ctx.Done():
return
}
}
}
func (cm *ICEMonitor) handleCandidateTick(ctx context.Context, ufrag string, pwd string) (bool, error) {
log.Debugf("Gathering ICE candidates")
agent, err := icemaker.NewAgent(cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd)
if err != nil {
return false, fmt.Errorf("create ICE agent: %w", err)
}
defer func() {
if err := agent.Close(); err != nil {
log.Warnf("Failed to close ICE agent: %v", err)
}
}()
gatherDone := make(chan struct{})
err = agent.OnCandidate(func(c ice.Candidate) {
log.Tracef("Got candidate: %v", c)
if c == nil {
close(gatherDone)
}
})
if err != nil {
return false, fmt.Errorf("set ICE candidate handler: %w", err)
}
if err := agent.GatherCandidates(); err != nil {
return false, fmt.Errorf("gather ICE candidates: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, candidateGatheringTimeout)
defer cancel()
select {
case <-ctx.Done():
return false, fmt.Errorf("wait for gathering timed out")
case <-gatherDone:
}
candidates, err := agent.GetLocalCandidates()
if err != nil {
return false, fmt.Errorf("get local candidates: %w", err)
}
log.Tracef("Got candidates: %v", candidates)
return cm.updateCandidates(candidates), nil
}
func (cm *ICEMonitor) updateCandidates(newCandidates []ice.Candidate) bool {
cm.candidatesMu.Lock()
defer cm.candidatesMu.Unlock()
if len(cm.currentCandidates) != len(newCandidates) {
cm.currentCandidates = newCandidates
return true
}
for i, candidate := range cm.currentCandidates {
if candidate.Address() != newCandidates[i].Address() {
cm.currentCandidates = newCandidates
return true
}
}
return false
}
func candidateTypesP2P() []ice.CandidateType {
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive}
}

View File

@ -0,0 +1,119 @@
package guard
import (
"context"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
type chNotifier interface {
SetOnReconnectedListener(func())
Ready() bool
}
type SRWatcher struct {
signalClient chNotifier
relayManager chNotifier
listeners map[chan struct{}]struct{}
mu sync.Mutex
iFaceDiscover stdnet.ExternalIFaceDiscover
iceConfig ice.Config
cancelIceMonitor context.CancelFunc
}
// NewSRWatcher creates a new SRWatcher. This watcher will notify the listeners when the ICE candidates change or the
// Relay connection is reconnected or the Signal client reconnected.
func NewSRWatcher(signalClient chNotifier, relayManager chNotifier, iFaceDiscover stdnet.ExternalIFaceDiscover, iceConfig ice.Config) *SRWatcher {
srw := &SRWatcher{
signalClient: signalClient,
relayManager: relayManager,
iFaceDiscover: iFaceDiscover,
iceConfig: iceConfig,
listeners: make(map[chan struct{}]struct{}),
}
return srw
}
func (w *SRWatcher) Start() {
w.mu.Lock()
defer w.mu.Unlock()
if w.cancelIceMonitor != nil {
return
}
ctx, cancel := context.WithCancel(context.Background())
w.cancelIceMonitor = cancel
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig)
go iceMonitor.Start(ctx, w.onICEChanged)
w.signalClient.SetOnReconnectedListener(w.onReconnected)
w.relayManager.SetOnReconnectedListener(w.onReconnected)
}
func (w *SRWatcher) Close() {
w.mu.Lock()
defer w.mu.Unlock()
if w.cancelIceMonitor == nil {
return
}
w.cancelIceMonitor()
w.signalClient.SetOnReconnectedListener(nil)
w.relayManager.SetOnReconnectedListener(nil)
}
func (w *SRWatcher) NewListener() chan struct{} {
w.mu.Lock()
defer w.mu.Unlock()
listenerChan := make(chan struct{}, 1)
w.listeners[listenerChan] = struct{}{}
return listenerChan
}
func (w *SRWatcher) RemoveListener(listenerChan chan struct{}) {
w.mu.Lock()
defer w.mu.Unlock()
delete(w.listeners, listenerChan)
close(listenerChan)
}
func (w *SRWatcher) onICEChanged() {
if !w.signalClient.Ready() {
return
}
log.Infof("network changes detected by ICE agent")
w.notify()
}
func (w *SRWatcher) onReconnected() {
if !w.signalClient.Ready() {
return
}
if !w.relayManager.Ready() {
return
}
log.Infof("reconnected to Signal or Relay server")
w.notify()
}
func (w *SRWatcher) notify() {
w.mu.Lock()
defer w.mu.Unlock()
for listener := range w.listeners {
select {
case listener <- struct{}{}:
default:
}
}
}

Some files were not shown because too many files have changed in this diff Show More