Bind implementation (#779)

This PR adds supports for the WireGuard userspace implementation
using Bind interface from wireguard-go. 
The newly introduced ICEBind struct implements Bind with UDPMux-based
structs from pion/ice to handle hole punching using ICE.
The core implementation was taken from StdBind of wireguard-go.

The result is a single WireGuard port that is used for host and server reflexive candidates. 
Relay candidates are still handled separately and will be integrated in the following PRs.

ICEBind checks the incoming packets for being STUN or WireGuard ones
and routes them to UDPMux (to handle hole punching) or to WireGuard  respectively.
This commit is contained in:
Misha Bragin 2023-04-13 17:00:01 +02:00 committed by GitHub
parent 0343c5f239
commit 2eeed55c18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
53 changed files with 1992 additions and 408 deletions

View File

@ -6,6 +6,10 @@ on:
- main - main
pull_request: pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs: jobs:
test: test:
runs-on: macos-latest runs-on: macos-latest

View File

@ -6,6 +6,10 @@ on:
- main - main
pull_request: pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs: jobs:
test: test:
strategy: strategy:
@ -66,7 +70,7 @@ jobs:
run: go mod tidy run: go mod tidy
- name: Generate Iface Test bin - name: Generate Iface Test bin
run: go test -c -o iface-testing.bin ./iface/... run: go test -c -o iface-testing.bin ./iface/
- name: Generate RouteManager Test bin - name: Generate RouteManager Test bin
run: go test -c -o routemanager-testing.bin ./client/internal/routemanager/... run: go test -c -o routemanager-testing.bin ./client/internal/routemanager/...

View File

@ -6,47 +6,45 @@ on:
- main - main
pull_request: pull_request:
env:
downloadPath: '${{ github.workspace }}\temp'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs: jobs:
pre:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
- run: bash -x wireguard_nt.sh
working-directory: client
- uses: actions/upload-artifact@v2
with:
name: syso
path: client/*.syso
retention-days: 1
test: test:
needs: pre
runs-on: windows-latest runs-on: windows-latest
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v2 uses: actions/checkout@v3
- name: Install Go - name: Install Go
uses: actions/setup-go@v2 uses: actions/setup-go@v4
id: go
with: with:
go-version: 1.19.x go-version: 1.19.x
- uses: actions/cache@v2 - name: Download wintun
uses: carlosperate/download-file-action@v2
id: download-wintun
with: with:
path: | file-url: https://www.wintun.net/builds/wintun-0.14.1.zip
%LocalAppData%\go-build file-name: wintun.zip
~\go\pkg\mod location: ${{ env.downloadPath }}
~\AppData\Local\go-build sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
- uses: actions/download-artifact@v2 - name: Decompressing wintun files
with: run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
name: syso
path: iface\
- name: Test - run: mv ${{ env.downloadPath }}/wintun/bin/amd64/wintun.dll 'C:\Windows\System32\'
run: go test -tags=load_wgnt_from_rsrc -timeout 5m -p 1 ./...
- run: choco install -y sysinternals
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=C:\Users\runneradmin\go\pkg\mod
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build
- name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 5m -p 1 ./... > test-out.txt 2>&1"
- name: test output
if: ${{ always() }}
run: Get-Content test-out.txt

View File

@ -1,5 +1,8 @@
name: golangci-lint name: golangci-lint
on: [pull_request] on: [pull_request]
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs: jobs:
golangci: golangci:
name: lint name: lint

View File

@ -7,7 +7,9 @@ on:
pull_request: pull_request:
paths: paths:
- "release_files/install.sh" - "release_files/install.sh"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs: jobs:
install-cli-only: install-cli-only:
runs-on: macos-latest runs-on: macos-latest

View File

@ -7,7 +7,9 @@ on:
pull_request: pull_request:
paths: paths:
- "release_files/install.sh" - "release_files/install.sh"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs: jobs:
install-cli-only: install-cli-only:
runs-on: ubuntu-latest runs-on: ubuntu-latest

View File

@ -9,9 +9,13 @@ on:
pull_request: pull_request:
env: env:
SIGN_PIPE_VER: "v0.0.5" SIGN_PIPE_VER: "v0.0.6"
GORELEASER_VER: "v1.14.1" GORELEASER_VER: "v1.14.1"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs: jobs:
release: release:
runs-on: ubuntu-latest runs-on: ubuntu-latest
@ -21,10 +25,6 @@ jobs:
uses: actions/checkout@v2 uses: actions/checkout@v2
with: with:
fetch-depth: 0 # It is required for GoReleaser to work properly fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Generate syso with DLL
run: bash -x wireguard_nt.sh
working-directory: client
- -
name: Set up Go name: Set up Go
uses: actions/setup-go@v2 uses: actions/setup-go@v2
@ -59,6 +59,17 @@ jobs:
password: ${{ secrets.DOCKER_TOKEN }} password: ${{ secrets.DOCKER_TOKEN }}
- name: Install OS build dependencies - name: Install OS build dependencies
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
- name: Install rsrc
run: go install github.com/akavel/rsrc@v0.10.2
- name: Generate windows rsrc amd64
run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_amd64.syso
- name: Generate windows rsrc arm64
run: rsrc -arch arm64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm64.syso
- name: Generate windows rsrc arm
run: rsrc -arch arm -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm.syso
- name: Generate windows rsrc 386
run: rsrc -arch 386 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_386.syso
- -
name: Run GoReleaser name: Run GoReleaser
uses: goreleaser/goreleaser-action@v2 uses: goreleaser/goreleaser-action@v2

View File

@ -6,6 +6,10 @@ on:
- main - main
pull_request: pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs: jobs:
test: test:
runs-on: ubuntu-latest runs-on: ubuntu-latest

View File

@ -26,7 +26,7 @@ type TunAdapter interface {
// IFaceDiscover export internal IFaceDiscover for mobile // IFaceDiscover export internal IFaceDiscover for mobile
type IFaceDiscover interface { type IFaceDiscover interface {
stdnet.IFaceDiscover stdnet.ExternalIFaceDiscover
} }
func init() { func init() {

View File

@ -193,6 +193,7 @@ ExecWait `taskkill /im ${UI_APP_EXE}.exe`
Sleep 3000 Sleep 3000
Delete "$INSTDIR\${UI_APP_EXE}" Delete "$INSTDIR\${UI_APP_EXE}"
Delete "$INSTDIR\${MAIN_APP_EXE}" Delete "$INSTDIR\${MAIN_APP_EXE}"
Delete "$INSTDIR\wintun.dll"
RmDir /r "$INSTDIR" RmDir /r "$INSTDIR"
SetShellVarContext current SetShellVarContext current

View File

@ -27,7 +27,7 @@ const (
) )
var defaultInterfaceBlacklist = []string{iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts", var defaultInterfaceBlacklist = []string{iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
"Tailscale", "tailscale", "docker", "veth", "br-"} "Tailscale", "tailscale", "docker", "veth", "br-", "lo"}
// ConfigInput carries configuration changes to the client // ConfigInput carries configuration changes to the client
type ConfigInput struct { type ConfigInput struct {

View File

@ -23,7 +23,7 @@ import (
) )
// RunClient with main logic. // RunClient with main logic.
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.IFaceDiscover) error { func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) error {
backOff := &backoff.ExponentialBackOff{ backOff := &backoff.ExponentialBackOff{
InitialInterval: time.Second, InitialInterval: time.Second,
RandomizationFactor: 1, RandomizationFactor: 1,
@ -108,7 +108,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
localPeerState := peer.LocalPeerState{ localPeerState := peer.LocalPeerState{
IP: loginResp.GetPeerConfig().GetAddress(), IP: loginResp.GetPeerConfig().GetAddress(),
PubKey: myPrivateKey.PublicKey().String(), PubKey: myPrivateKey.PublicKey().String(),
KernelInterface: iface.WireguardModuleIsLoaded(), KernelInterface: iface.WireGuardModuleIsLoaded(),
FQDN: loginResp.GetPeerConfig().GetFqdn(), FQDN: loginResp.GetPeerConfig().GetFqdn(),
} }
@ -194,7 +194,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
} }
// createEngineConfig converts configuration received from Management Service to EngineConfig // createEngineConfig converts configuration received from Management Service to EngineConfig
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.IFaceDiscover) (*EngineConfig, error) { func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) (*EngineConfig, error) {
engineConf := &EngineConfig{ engineConf := &EngineConfig{
WgIfaceName: config.WgIface, WgIfaceName: config.WgIface,

View File

@ -9,7 +9,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/miekg/dns" "github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
) )
@ -199,7 +202,11 @@ func TestUpdateDNSServer(t *testing.T) {
for n, testCase := range testCases { for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU, nil) newNet, err := stdnet.NewNet(nil)
if err != nil {
t.Fatal(err)
}
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -50,7 +50,7 @@ type EngineConfig struct {
// TunAdapter is option. It is necessary for mobile version. // TunAdapter is option. It is necessary for mobile version.
TunAdapter iface.TunAdapter TunAdapter iface.TunAdapter
IFaceDiscover stdnet.IFaceDiscover IFaceDiscover stdnet.ExternalIFaceDiscover
// WgAddr is a Wireguard local address (Netbird Network IP) // WgAddr is a Wireguard local address (Netbird Network IP)
WgAddr string WgAddr string
@ -166,34 +166,56 @@ func (e *Engine) Stop() error {
return nil return nil
} }
// Start creates a new Wireguard tunnel interface and listens to events from Signal and Management services // Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
// Connections to remote peers are not established here. // Connections to remote peers are not established here.
// However, they will be established once an event with a list of peers to connect to will be received from Management Service // However, they will be established once an event with a list of peers to connect to will be received from Management Service
func (e *Engine) Start() error { func (e *Engine) Start() error {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
wgIfaceName := e.config.WgIfaceName wgIFaceName := e.config.WgIfaceName
wgAddr := e.config.WgAddr wgAddr := e.config.WgAddr
myPrivateKey := e.config.WgPrivateKey myPrivateKey := e.config.WgPrivateKey
var err error var err error
transportNet, err := e.newStdNet()
e.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, e.config.TunAdapter)
if err != nil { if err != nil {
log.Errorf("failed creating wireguard interface instance %s: [%s]", wgIfaceName, err.Error()) log.Errorf("failed to create pion's stdnet: %s", err)
}
e.wgInterface, err = iface.NewWGIFace(wgIFaceName, wgAddr, iface.DefaultMTU, e.config.TunAdapter, transportNet)
if err != nil {
log.Errorf("failed creating wireguard interface instance %s: [%s]", wgIFaceName, err.Error())
return err return err
} }
err = e.wgInterface.Create()
if err != nil {
log.Errorf("failed creating tunnel interface %s: [%s]", wgIFaceName, err.Error())
e.close()
return err
}
err = e.wgInterface.Configure(myPrivateKey.String(), e.config.WgPort)
if err != nil {
log.Errorf("failed configuring Wireguard interface [%s]: %s", wgIFaceName, err.Error())
e.close()
return err
}
if e.wgInterface.IsUserspaceBind() {
iceBind := e.wgInterface.GetBind()
udpMux, err := iceBind.GetICEMux()
if err != nil {
e.close()
return err
}
e.udpMux = udpMux.UDPMuxDefault
e.udpMuxSrflx = udpMux
log.Infof("using userspace bind mode %s", udpMux.LocalAddr().String())
} else {
networkName := "udp" networkName := "udp"
if e.config.DisableIPv6Discovery { if e.config.DisableIPv6Discovery {
networkName = "udp4" networkName = "udp4"
} }
transportNet, err := e.newStdNet()
if err != nil {
log.Warnf("failed to create pion's stdnet: %s", err)
}
e.udpMuxConn, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxPort}) e.udpMuxConn, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxPort})
if err != nil { if err != nil {
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxPort, err.Error()) log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxPort, err.Error())
@ -213,19 +235,6 @@ func (e *Engine) Start() error {
return err return err
} }
e.udpMuxSrflx = ice.NewUniversalUDPMuxDefault(ice.UniversalUDPMuxParams{UDPConn: e.udpMuxConnSrflx, Net: transportNet}) e.udpMuxSrflx = ice.NewUniversalUDPMuxDefault(ice.UniversalUDPMuxParams{UDPConn: e.udpMuxConnSrflx, Net: transportNet})
err = e.wgInterface.Create()
if err != nil {
log.Errorf("failed creating tunnel interface %s: [%s]", wgIfaceName, err.Error())
e.close()
return err
}
err = e.wgInterface.Configure(myPrivateKey.String(), e.config.WgPort)
if err != nil {
log.Errorf("failed configuring Wireguard interface [%s]: %s", wgIfaceName, err.Error())
e.close()
return err
} }
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder) e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder)
@ -496,7 +505,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
e.statusRecorder.UpdateLocalPeerState(peer.LocalPeerState{ e.statusRecorder.UpdateLocalPeerState(peer.LocalPeerState{
IP: e.config.WgAddr, IP: e.config.WgAddr,
PubKey: e.config.WgPrivateKey.PublicKey().String(), PubKey: e.config.WgPrivateKey.PublicKey().String(),
KernelInterface: iface.WireguardModuleIsLoaded(), KernelInterface: iface.WireGuardModuleIsLoaded(),
FQDN: conf.GetFqdn(), FQDN: conf.GetFqdn(),
}) })
@ -822,6 +831,7 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er
ProxyConfig: proxyConfig, ProxyConfig: proxyConfig,
LocalWgPort: e.config.WgPort, LocalWgPort: e.config.WgPort,
NATExternalIPs: e.parseNATExternalIPMappings(), NATExternalIPs: e.parseNATExternalIPMappings(),
UserspaceBind: e.wgInterface.IsUserspaceBind(),
} }
peerConn, err := peer.NewConn(config, e.statusRecorder, e.config.TunAdapter, e.config.IFaceDiscover) peerConn, err := peer.NewConn(config, e.statusRecorder, e.config.TunAdapter, e.config.IFaceDiscover)
@ -1006,12 +1016,6 @@ func (e *Engine) close() {
} }
} }
if e.udpMuxSrflx != nil {
if err := e.udpMuxSrflx.Close(); err != nil {
log.Debugf("close server reflexive udp mux: %v", err)
}
}
if e.udpMuxConn != nil { if e.udpMuxConn != nil {
if err := e.udpMuxConn.Close(); err != nil { if err := e.udpMuxConn.Close(); err != nil {
log.Debugf("close udp mux connection: %v", err) log.Debugf("close udp mux connection: %v", err)

View File

@ -3,9 +3,9 @@
package internal package internal
import ( import (
"github.com/pion/transport/v2/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
) )
func (e *Engine) newStdNet() (*stdnet.Net, error) { func (e *Engine) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNet() return stdnet.NewNet(e.config.IFaceBlackList)
} }

View File

@ -3,5 +3,5 @@ package internal
import "github.com/netbirdio/netbird/client/internal/stdnet" import "github.com/netbirdio/netbird/client/internal/stdnet"
func (e *Engine) newStdNet() (*stdnet.Net, error) { func (e *Engine) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNet(e.config.IFaceDiscover) return stdnet.NewNetWithDiscover(e.config.IFaceDiscover, e.config.IFaceBlackList)
} }

View File

@ -3,6 +3,8 @@ package internal
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/netbirdio/netbird/iface/bind"
"github.com/pion/transport/v2/stdnet"
"net" "net"
"net/netip" "net/netip"
"os" "os"
@ -207,11 +209,23 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}, peer.NewRecorder("https://mgm")) }, peer.NewRecorder("https://mgm"))
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU, nil) newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU, nil, newNet)
if err != nil {
t.Fatal(err)
}
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder) engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder)
engine.dnsServer = &dns.MockServer{ engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
} }
conn, err := net.ListenUDP("udp4", nil)
if err != nil {
t.Fatal(err)
}
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
type testCase struct { type testCase struct {
name string name string
@ -549,7 +563,11 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}, peer.NewRecorder("https://mgm")) }, peer.NewRecorder("https://mgm"))
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil) newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, newNet)
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
input := struct { input := struct {
inputSerial uint64 inputSerial uint64
@ -714,7 +732,11 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}, peer.NewRecorder("https://mgm")) }, peer.NewRecorder("https://mgm"))
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil) newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, newNet)
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
mockRouteManager := &routemanager.MockManager{ mockRouteManager := &routemanager.MockManager{

View File

@ -10,7 +10,6 @@ import (
"github.com/pion/ice/v2" "github.com/pion/ice/v2"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl"
"github.com/netbirdio/netbird/client/internal/proxy" "github.com/netbirdio/netbird/client/internal/proxy"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
@ -46,6 +45,9 @@ type ConnConfig struct {
LocalWgPort int LocalWgPort int
NATExternalIPs []string NATExternalIPs []string
// UsesBind indicates whether the WireGuard interface is userspace and uses bind.ICEBind
UserspaceBind bool
} }
// OfferAnswer represents a session establishment offer or answer // OfferAnswer represents a session establishment offer or answer
@ -95,7 +97,7 @@ type Conn struct {
meta meta meta meta
adapter iface.TunAdapter adapter iface.TunAdapter
iFaceDiscover stdnet.IFaceDiscover iFaceDiscover stdnet.ExternalIFaceDiscover
} }
// meta holds meta information about a connection // meta holds meta information about a connection
@ -121,7 +123,7 @@ func (conn *Conn) UpdateConf(conf ConnConfig) {
// NewConn creates a new not opened Conn to the remote peer. // NewConn creates a new not opened Conn to the remote peer.
// To establish a connection run Conn.Open // To establish a connection run Conn.Open
func NewConn(config ConnConfig, statusRecorder *Status, adapter iface.TunAdapter, iFaceDiscover stdnet.IFaceDiscover) (*Conn, error) { func NewConn(config ConnConfig, statusRecorder *Status, adapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) (*Conn, error) {
return &Conn{ return &Conn{
config: config, config: config,
mu: sync.Mutex{}, mu: sync.Mutex{},
@ -136,32 +138,6 @@ func NewConn(config ConnConfig, statusRecorder *Status, adapter iface.TunAdapter
}, nil }, nil
} }
// interfaceFilter is a function passed to ICE Agent to filter out not allowed interfaces
// to avoid building tunnel over them
func interfaceFilter(blackList []string) func(string) bool {
return func(iFace string) bool {
for _, s := range blackList {
if strings.HasPrefix(iFace, s) {
log.Debugf("ignoring interface %s - it is not allowed", iFace)
return false
}
}
// look for unlisted WireGuard interfaces
wg, err := wgctrl.New()
if err != nil {
log.Debugf("trying to create a wgctrl client failed with: %v", err)
return true
}
defer func() {
_ = wg.Close()
}()
_, err = wg.Device(iFace)
return err != nil
}
}
func (conn *Conn) reCreateAgent() error { func (conn *Conn) reCreateAgent() error {
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
@ -171,7 +147,7 @@ func (conn *Conn) reCreateAgent() error {
var err error var err error
transportNet, err := conn.newStdNet() transportNet, err := conn.newStdNet()
if err != nil { if err != nil {
log.Warnf("failed to create pion's stdnet: %s", err) log.Errorf("failed to create pion's stdnet: %s", err)
} }
agentConfig := &ice.AgentConfig{ agentConfig := &ice.AgentConfig{
MulticastDNSMode: ice.MulticastDNSModeDisabled, MulticastDNSMode: ice.MulticastDNSModeDisabled,
@ -179,7 +155,7 @@ func (conn *Conn) reCreateAgent() error {
Urls: conn.config.StunTurn, Urls: conn.config.StunTurn,
CandidateTypes: []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay}, CandidateTypes: []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay},
FailedTimeout: &failedTimeout, FailedTimeout: &failedTimeout,
InterfaceFilter: interfaceFilter(conn.config.InterfaceBlackList), InterfaceFilter: stdnet.InterfaceFilter(conn.config.InterfaceBlackList),
UDPMux: conn.config.UDPMux, UDPMux: conn.config.UDPMux,
UDPMuxSrflx: conn.config.UDPMuxSrflx, UDPMuxSrflx: conn.config.UDPMuxSrflx,
NAT1To1IPs: conn.config.NATExternalIPs, NAT1To1IPs: conn.config.NATExternalIPs,
@ -319,7 +295,7 @@ func (conn *Conn) Open() error {
return err return err
} }
if conn.proxy.Type() == proxy.TypeNoProxy { if conn.proxy.Type() == proxy.TypeDirectNoProxy {
host, _, _ := net.SplitHostPort(remoteConn.LocalAddr().String()) host, _, _ := net.SplitHostPort(remoteConn.LocalAddr().String())
rhost, _, _ := net.SplitHostPort(remoteConn.RemoteAddr().String()) rhost, _, _ := net.SplitHostPort(remoteConn.RemoteAddr().String())
// direct Wireguard connection // direct Wireguard connection
@ -341,29 +317,62 @@ func (conn *Conn) Open() error {
// useProxy determines whether a direct connection (without a go proxy) is possible // useProxy determines whether a direct connection (without a go proxy) is possible
// //
// There are 2 cases: // There are 3 cases:
// //
// * When neither candidate is from hard nat and one of the peers has a public IP // * When neither candidate is from hard nat and one of the peers has a public IP
// //
// * both peers are in the same private network // * both peers are in the same private network
// //
// * Local peer uses userspace interface with bind.ICEBind and is not relayed
//
// Please note, that this check happens when peers were already able to ping each other using ICE layer. // Please note, that this check happens when peers were already able to ping each other using ICE layer.
func shouldUseProxy(pair *ice.CandidatePair) bool { func shouldUseProxy(pair *ice.CandidatePair, userspaceBind bool) bool {
if !isRelayCandidate(pair.Local) && userspaceBind {
log.Debugf("shouldn't use proxy because using Bind and the connection is not relayed")
return false
}
if !isHardNATCandidate(pair.Local) && isHostCandidateWithPublicIP(pair.Remote) { if !isHardNATCandidate(pair.Local) && isHostCandidateWithPublicIP(pair.Remote) {
log.Debugf("shouldn't use proxy because the local peer is not behind a hard NAT and the remote one has a public IP")
return false return false
} }
if !isHardNATCandidate(pair.Remote) && isHostCandidateWithPublicIP(pair.Local) { if !isHardNATCandidate(pair.Remote) && isHostCandidateWithPublicIP(pair.Local) {
log.Debugf("shouldn't use proxy because the remote peer is not behind a hard NAT and the local one has a public IP")
return false return false
} }
if isHostCandidateWithPrivateIP(pair.Local) && isHostCandidateWithPrivateIP(pair.Remote) { if isHostCandidateWithPrivateIP(pair.Local) && isHostCandidateWithPrivateIP(pair.Remote) && isSameNetworkPrefix(pair) {
log.Debugf("shouldn't use proxy because peers are in the same private /16 network")
return false
}
if (isPeerReflexiveCandidateWithPrivateIP(pair.Local) && isHostCandidateWithPrivateIP(pair.Remote) ||
isHostCandidateWithPrivateIP(pair.Local) && isPeerReflexiveCandidateWithPrivateIP(pair.Remote)) && isSameNetworkPrefix(pair) {
log.Debugf("shouldn't use proxy because peers are in the same private /16 network and one peer is peer reflexive")
return false return false
} }
return true return true
} }
func isSameNetworkPrefix(pair *ice.CandidatePair) bool {
localIP := net.ParseIP(pair.Local.Address())
remoteIP := net.ParseIP(pair.Remote.Address())
if localIP == nil || remoteIP == nil {
return false
}
// only consider /16 networks
mask := net.IPMask{255, 255, 0, 0}
return localIP.Mask(mask).Equal(remoteIP.Mask(mask))
}
func isRelayCandidate(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypeRelay
}
func isHardNATCandidate(candidate ice.Candidate) bool { func isHardNATCandidate(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypeRelay || candidate.Type() == ice.CandidateTypePeerReflexive return candidate.Type() == ice.CandidateTypeRelay || candidate.Type() == ice.CandidateTypePeerReflexive
} }
@ -376,9 +385,13 @@ func isHostCandidateWithPrivateIP(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypeHost && !isPublicIP(candidate.Address()) return candidate.Type() == ice.CandidateTypeHost && !isPublicIP(candidate.Address())
} }
func isPeerReflexiveCandidateWithPrivateIP(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypePeerReflexive && !isPublicIP(candidate.Address())
}
func isPublicIP(address string) bool { func isPublicIP(address string) bool {
ip := net.ParseIP(address) ip := net.ParseIP(address)
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsPrivate() { if ip == nil || ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsPrivate() {
return false return false
} }
return true return true
@ -412,7 +425,7 @@ func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error {
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay { if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
peerState.Relayed = true peerState.Relayed = true
} }
peerState.Direct = p.Type() == proxy.TypeNoProxy peerState.Direct = p.Type() == proxy.TypeDirectNoProxy || p.Type() == proxy.TypeNoProxy
err = conn.statusRecorder.UpdatePeerState(peerState) err = conn.statusRecorder.UpdatePeerState(peerState)
if err != nil { if err != nil {
@ -423,8 +436,7 @@ func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error {
} }
func (conn *Conn) getProxyWithMessageExchange(pair *ice.CandidatePair, remoteWgPort int) proxy.Proxy { func (conn *Conn) getProxyWithMessageExchange(pair *ice.CandidatePair, remoteWgPort int) proxy.Proxy {
useProxy := shouldUseProxy(pair, conn.config.UserspaceBind)
useProxy := shouldUseProxy(pair)
localDirectMode := !useProxy localDirectMode := !useProxy
remoteDirectMode := localDirectMode remoteDirectMode := localDirectMode
@ -434,13 +446,16 @@ func (conn *Conn) getProxyWithMessageExchange(pair *ice.CandidatePair, remoteWgP
remoteDirectMode = conn.receiveRemoteDirectMode() remoteDirectMode = conn.receiveRemoteDirectMode()
} }
if conn.config.UserspaceBind && localDirectMode {
return proxy.NewNoProxy(conn.config.ProxyConfig)
}
if localDirectMode && remoteDirectMode { if localDirectMode && remoteDirectMode {
log.Debugf("using WireGuard direct mode with peer %s", conn.config.Key) return proxy.NewDirectNoProxy(conn.config.ProxyConfig, remoteWgPort)
return proxy.NewNoProxy(conn.config.ProxyConfig, remoteWgPort)
} }
log.Debugf("falling back to local proxy mode with peer %s", conn.config.Key) log.Debugf("falling back to local proxy mode with peer %s", conn.config.Key)
return proxy.NewWireguardProxy(conn.config.ProxyConfig) return proxy.NewWireGuardProxy(conn.config.ProxyConfig)
} }
func (conn *Conn) sendLocalDirectMode(localMode bool) { func (conn *Conn) sendLocalDirectMode(localMode bool) {

View File

@ -5,6 +5,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/magiconair/properties/assert" "github.com/magiconair/properties/assert"
"github.com/pion/ice/v2" "github.com/pion/ice/v2"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
@ -28,7 +30,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
ignore := []string{iface.WgInterfaceDefault, "tun0", "zt", "ZeroTier", "utun", "wg", "ts", ignore := []string{iface.WgInterfaceDefault, "tun0", "zt", "ZeroTier", "utun", "wg", "ts",
"Tailscale", "tailscale"} "Tailscale", "tailscale"}
filter := interfaceFilter(ignore) filter := stdnet.InterfaceFilter(ignore)
for _, s := range ignore { for _, s := range ignore {
assert.Equal(t, filter(s), false) assert.Equal(t, filter(s), false)
@ -208,6 +210,7 @@ func TestConn_ShouldUseProxy(t *testing.T) {
return ice.CandidateTypeHost return ice.CandidateTypeHost
}, },
} }
srflxCandidate := &mockICECandidate{ srflxCandidate := &mockICECandidate{
AddressFunc: func() string { AddressFunc: func() string {
return "1.1.1.1" return "1.1.1.1"
@ -320,11 +323,47 @@ func TestConn_ShouldUseProxy(t *testing.T) {
}, },
expected: false, expected: false,
}, },
{
name: "Don't Use Proxy When Both Candidates are in private network and one is peer reflexive",
candatePair: &ice.CandidatePair{
Local: &mockICECandidate{AddressFunc: func() string {
return "10.16.102.168"
},
TypeFunc: func() ice.CandidateType {
return ice.CandidateTypeHost
}},
Remote: &mockICECandidate{AddressFunc: func() string {
return "10.16.101.96"
},
TypeFunc: func() ice.CandidateType {
return ice.CandidateTypePeerReflexive
}},
},
expected: false,
},
{
name: "Should Use Proxy When Both Candidates are in private network and both are peer reflexive",
candatePair: &ice.CandidatePair{
Local: &mockICECandidate{AddressFunc: func() string {
return "10.16.102.168"
},
TypeFunc: func() ice.CandidateType {
return ice.CandidateTypePeerReflexive
}},
Remote: &mockICECandidate{AddressFunc: func() string {
return "10.16.101.96"
},
TypeFunc: func() ice.CandidateType {
return ice.CandidateTypePeerReflexive
}},
},
expected: true,
},
} }
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
result := shouldUseProxy(testCase.candatePair) result := shouldUseProxy(testCase.candatePair, false)
if result != testCase.expected { if result != testCase.expected {
t.Errorf("got a different result. Expected %t Got %t", testCase.expected, result) t.Errorf("got a different result. Expected %t Got %t", testCase.expected, result)
} }
@ -365,7 +404,7 @@ func TestGetProxyWithMessageExchange(t *testing.T) {
}, },
inputDirectModeSupport: true, inputDirectModeSupport: true,
inputRemoteModeMessage: true, inputRemoteModeMessage: true,
expected: proxy.TypeWireguard, expected: proxy.TypeWireGuard,
}, },
{ {
name: "Should Result In Using Wireguard Proxy When Remote Eval Is Use Proxy", name: "Should Result In Using Wireguard Proxy When Remote Eval Is Use Proxy",
@ -375,7 +414,7 @@ func TestGetProxyWithMessageExchange(t *testing.T) {
}, },
inputDirectModeSupport: true, inputDirectModeSupport: true,
inputRemoteModeMessage: false, inputRemoteModeMessage: false,
expected: proxy.TypeWireguard, expected: proxy.TypeWireGuard,
}, },
{ {
name: "Should Result In Using Wireguard Proxy When Remote Direct Mode Support Is False And Local Eval Is Use Proxy", name: "Should Result In Using Wireguard Proxy When Remote Direct Mode Support Is False And Local Eval Is Use Proxy",
@ -385,7 +424,7 @@ func TestGetProxyWithMessageExchange(t *testing.T) {
}, },
inputDirectModeSupport: false, inputDirectModeSupport: false,
inputRemoteModeMessage: false, inputRemoteModeMessage: false,
expected: proxy.TypeWireguard, expected: proxy.TypeWireGuard,
}, },
{ {
name: "Should Result In Using Direct When Remote Direct Mode Support Is False And Local Eval Is No Use Proxy", name: "Should Result In Using Direct When Remote Direct Mode Support Is False And Local Eval Is No Use Proxy",
@ -395,7 +434,7 @@ func TestGetProxyWithMessageExchange(t *testing.T) {
}, },
inputDirectModeSupport: false, inputDirectModeSupport: false,
inputRemoteModeMessage: false, inputRemoteModeMessage: false,
expected: proxy.TypeNoProxy, expected: proxy.TypeDirectNoProxy,
}, },
{ {
name: "Should Result In Using Direct When Local And Remote Eval Is No Proxy", name: "Should Result In Using Direct When Local And Remote Eval Is No Proxy",
@ -405,7 +444,7 @@ func TestGetProxyWithMessageExchange(t *testing.T) {
}, },
inputDirectModeSupport: true, inputDirectModeSupport: true,
inputRemoteModeMessage: true, inputRemoteModeMessage: true,
expected: proxy.TypeNoProxy, expected: proxy.TypeDirectNoProxy,
}, },
} }
for _, testCase := range testCases { for _, testCase := range testCases {

View File

@ -3,9 +3,9 @@
package peer package peer
import ( import (
"github.com/pion/transport/v2/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
) )
func (conn *Conn) newStdNet() (*stdnet.Net, error) { func (conn *Conn) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNet() return stdnet.NewNet(conn.config.InterfaceBlackList)
} }

View File

@ -3,5 +3,5 @@ package peer
import "github.com/netbirdio/netbird/client/internal/stdnet" import "github.com/netbirdio/netbird/client/internal/stdnet"
func (conn *Conn) newStdNet() (*stdnet.Net, error) { func (conn *Conn) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNet(conn.iFaceDiscover) return stdnet.NewNetWithDiscover(conn.iFaceDiscover, conn.config.InterfaceBlackList)
} }

View File

@ -0,0 +1,57 @@
package proxy
import (
log "github.com/sirupsen/logrus"
"net"
)
// DirectNoProxy is used when there is no need for a proxy between ICE and WireGuard.
// This is possible in either of these cases:
// - peers are in the same local network
// - one of the peers has a public static IP (host)
// DirectNoProxy will just update remote peer with a remote host and fixed WireGuard port (r.g. 51820).
// In order DirectNoProxy to work, WireGuard port has to be fixed for the time being.
type DirectNoProxy struct {
config Config
// RemoteWgListenPort is a WireGuard port of a remote peer.
// It is used instead of the hardcoded 51820 port.
RemoteWgListenPort int
}
// NewDirectNoProxy creates a new DirectNoProxy with a provided config and remote peer's WireGuard listen port
func NewDirectNoProxy(config Config, remoteWgPort int) *DirectNoProxy {
return &DirectNoProxy{config: config, RemoteWgListenPort: remoteWgPort}
}
// Close removes peer from the WireGuard interface
func (p *DirectNoProxy) Close() error {
err := p.config.WgInterface.RemovePeer(p.config.RemoteKey)
if err != nil {
return err
}
return nil
}
// Start just updates WireGuard peer with the remote IP and default WireGuard port
func (p *DirectNoProxy) Start(remoteConn net.Conn) error {
log.Debugf("using DirectNoProxy while connecting to peer %s", p.config.RemoteKey)
addr, err := net.ResolveUDPAddr("udp", remoteConn.RemoteAddr().String())
if err != nil {
return err
}
addr.Port = p.RemoteWgListenPort
err = p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
addr, p.config.PreSharedKey)
if err != nil {
return err
}
return nil
}
// Type returns the type of this proxy
func (p *DirectNoProxy) Type() Type {
return TypeDirectNoProxy
}

View File

@ -5,24 +5,18 @@ import (
"net" "net"
) )
// NoProxy is used when there is no need for a proxy between ICE and Wireguard. // NoProxy is used just to configure WireGuard without any local proxy in between.
// This is possible in either of these cases: // Used when the WireGuard interface is userspace and uses bind.ICEBind
// - peers are in the same local network
// - one of the peers has a public static IP (host)
// NoProxy will just update remote peer with a remote host and fixed Wireguard port (r.g. 51820).
// In order NoProxy to work, Wireguard port has to be fixed for the time being.
type NoProxy struct { type NoProxy struct {
config Config config Config
// RemoteWgListenPort is a WireGuard port of a remote peer.
// It is used instead of the hardcoded 51820 port.
RemoteWgListenPort int
} }
// NewNoProxy creates a new NoProxy with a provided config and remote peer's WireGuard listen port // NewNoProxy creates a new NoProxy with a provided config
func NewNoProxy(config Config, remoteWgPort int) *NoProxy { func NewNoProxy(config Config) *NoProxy {
return &NoProxy{config: config, RemoteWgListenPort: remoteWgPort} return &NoProxy{config: config}
} }
// Close removes peer from the WireGuard interface
func (p *NoProxy) Close() error { func (p *NoProxy) Close() error {
err := p.config.WgInterface.RemovePeer(p.config.RemoteKey) err := p.config.WgInterface.RemovePeer(p.config.RemoteKey)
if err != nil { if err != nil {
@ -31,23 +25,16 @@ func (p *NoProxy) Close() error {
return nil return nil
} }
// Start just updates Wireguard peer with the remote IP and default Wireguard port // Start just updates WireGuard peer with the remote address
func (p *NoProxy) Start(remoteConn net.Conn) error { func (p *NoProxy) Start(remoteConn net.Conn) error {
log.Debugf("using NoProxy while connecting to peer %s", p.config.RemoteKey) log.Debugf("using NoProxy to connect to peer %s at %s", p.config.RemoteKey, remoteConn.RemoteAddr().String())
addr, err := net.ResolveUDPAddr("udp", remoteConn.RemoteAddr().String()) addr, err := net.ResolveUDPAddr("udp", remoteConn.RemoteAddr().String())
if err != nil { if err != nil {
return err return err
} }
addr.Port = p.RemoteWgListenPort return p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
err = p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
addr, p.config.PreSharedKey) addr, p.config.PreSharedKey)
if err != nil {
return err
}
return nil
} }
func (p *NoProxy) Type() Type { func (p *NoProxy) Type() Type {

View File

@ -13,9 +13,10 @@ const DefaultWgKeepAlive = 25 * time.Second
type Type string type Type string
const ( const (
TypeNoProxy Type = "NoProxy" TypeDirectNoProxy Type = "DirectNoProxy"
TypeWireguard Type = "Wireguard" TypeWireGuard Type = "WireGuard"
TypeDummy Type = "Dummy" TypeDummy Type = "Dummy"
TypeNoProxy Type = "NoProxy"
) )
type Config struct { type Config struct {

View File

@ -6,8 +6,8 @@ import (
"net" "net"
) )
// WireguardProxy proxies // WireGuardProxy proxies
type WireguardProxy struct { type WireGuardProxy struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
@ -17,13 +17,13 @@ type WireguardProxy struct {
localConn net.Conn localConn net.Conn
} }
func NewWireguardProxy(config Config) *WireguardProxy { func NewWireGuardProxy(config Config) *WireGuardProxy {
p := &WireguardProxy{config: config} p := &WireGuardProxy{config: config}
p.ctx, p.cancel = context.WithCancel(context.Background()) p.ctx, p.cancel = context.WithCancel(context.Background())
return p return p
} }
func (p *WireguardProxy) updateEndpoint() error { func (p *WireGuardProxy) updateEndpoint() error {
udpAddr, err := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String()) udpAddr, err := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String())
if err != nil { if err != nil {
return err return err
@ -38,7 +38,7 @@ func (p *WireguardProxy) updateEndpoint() error {
return nil return nil
} }
func (p *WireguardProxy) Start(remoteConn net.Conn) error { func (p *WireGuardProxy) Start(remoteConn net.Conn) error {
p.remoteConn = remoteConn p.remoteConn = remoteConn
var err error var err error
@ -60,7 +60,7 @@ func (p *WireguardProxy) Start(remoteConn net.Conn) error {
return nil return nil
} }
func (p *WireguardProxy) Close() error { func (p *WireGuardProxy) Close() error {
p.cancel() p.cancel()
if c := p.localConn; c != nil { if c := p.localConn; c != nil {
err := p.localConn.Close() err := p.localConn.Close()
@ -77,7 +77,7 @@ func (p *WireguardProxy) Close() error {
// proxyToRemote proxies everything from Wireguard to the RemoteKey peer // proxyToRemote proxies everything from Wireguard to the RemoteKey peer
// blocks // blocks
func (p *WireguardProxy) proxyToRemote() { func (p *WireGuardProxy) proxyToRemote() {
buf := make([]byte, 1500) buf := make([]byte, 1500)
for { for {
@ -101,7 +101,7 @@ func (p *WireguardProxy) proxyToRemote() {
// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard // proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
// blocks // blocks
func (p *WireguardProxy) proxyToLocal() { func (p *WireGuardProxy) proxyToLocal() {
buf := make([]byte, 1500) buf := make([]byte, 1500)
for { for {
@ -123,6 +123,6 @@ func (p *WireguardProxy) proxyToLocal() {
} }
} }
func (p *WireguardProxy) Type() Type { func (p *WireGuardProxy) Type() Type {
return TypeWireguard return TypeWireGuard
} }

View File

@ -3,6 +3,7 @@ package routemanager
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/pion/transport/v2/stdnet"
"net/netip" "net/netip"
"runtime" "runtime"
"testing" "testing"
@ -391,7 +392,11 @@ func TestManagerUpdateRoutes(t *testing.T) {
for n, testCase := range testCases { for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU, nil) newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU, nil, newNet)
require.NoError(t, err, "should create testing WGIface interface") require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close() defer wgInterface.Close()

View File

@ -3,6 +3,7 @@ package routemanager
import ( import (
"fmt" "fmt"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/pion/transport/v2/stdnet"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"net" "net"
"net/netip" "net/netip"
@ -32,7 +33,11 @@ func TestAddRemoveRoutes(t *testing.T) {
for n, testCase := range testCases { for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU, nil) newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU, nil, newNet)
require.NoError(t, err, "should create testing WGIface interface") require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close() defer wgInterface.Close()

View File

@ -0,0 +1,14 @@
package stdnet
import "github.com/pion/transport/v2"
// ExternalIFaceDiscover provide an option for external services (mobile)
// to collect network interface information
type ExternalIFaceDiscover interface {
// IFaces return with the description of the interfaces
IFaces() (string, error)
}
type iFaceDiscover interface {
iFaces() ([]*transport.Interface, error)
}

View File

@ -0,0 +1,95 @@
package stdnet
import (
"fmt"
"net"
"strings"
"github.com/pion/transport/v2"
log "github.com/sirupsen/logrus"
)
type mobileIFaceDiscover struct {
externalDiscover ExternalIFaceDiscover
}
func newMobileIFaceDiscover(externalDiscover ExternalIFaceDiscover) *mobileIFaceDiscover {
return &mobileIFaceDiscover{
externalDiscover: externalDiscover,
}
}
func (m *mobileIFaceDiscover) iFaces() ([]*transport.Interface, error) {
ifacesString, err := m.externalDiscover.IFaces()
if err != nil {
return nil, err
}
interfaces := m.parseInterfacesString(ifacesString)
return interfaces, nil
}
func (m *mobileIFaceDiscover) parseInterfacesString(interfaces string) []*transport.Interface {
ifs := []*transport.Interface{}
for _, iface := range strings.Split(interfaces, "\n") {
if strings.TrimSpace(iface) == "" {
continue
}
fields := strings.Split(iface, "|")
if len(fields) != 2 {
log.Warnf("parseInterfacesString: unable to split %q", iface)
continue
}
var name string
var index, mtu int
var up, broadcast, loopback, pointToPoint, multicast bool
_, err := fmt.Sscanf(fields[0], "%s %d %d %t %t %t %t %t",
&name, &index, &mtu, &up, &broadcast, &loopback, &pointToPoint, &multicast)
if err != nil {
log.Warnf("parseInterfacesString: unable to parse %q: %v", iface, err)
continue
}
newIf := net.Interface{
Name: name,
Index: index,
MTU: mtu,
}
if up {
newIf.Flags |= net.FlagUp
}
if broadcast {
newIf.Flags |= net.FlagBroadcast
}
if loopback {
newIf.Flags |= net.FlagLoopback
}
if pointToPoint {
newIf.Flags |= net.FlagPointToPoint
}
if multicast {
newIf.Flags |= net.FlagMulticast
}
ifc := transport.NewInterface(newIf)
addrs := strings.Trim(fields[1], " \n")
foundAddress := false
for _, addr := range strings.Split(addrs, " ") {
ip, ipNet, err := net.ParseCIDR(addr)
if err != nil {
log.Warnf("%s", err)
continue
}
ipNet.IP = ip
ifc.AddAddress(ipNet)
foundAddress = true
}
if foundAddress {
ifs = append(ifs, ifc)
}
}
return ifs
}

View File

@ -35,7 +35,9 @@ func Test_parseInterfacesString(t *testing.T) {
d.multicast, d.multicast,
d.addr) d.addr)
} }
nets := parseInterfacesString(exampleString)
d := mobileIFaceDiscover{}
nets := d.parseInterfacesString(exampleString)
if len(nets) == 0 { if len(nets) == 0 {
t.Fatalf("failed to parse interfaces") t.Fatalf("failed to parse interfaces")
} }

View File

@ -0,0 +1,36 @@
package stdnet
import (
"net"
"github.com/pion/transport/v2"
)
type pionDiscover struct {
}
func (d pionDiscover) iFaces() ([]*transport.Interface, error) {
ifs := []*transport.Interface{}
oifs, err := net.Interfaces()
if err != nil {
return nil, err
}
for _, oif := range oifs {
ifc := transport.NewInterface(oif)
addrs, err := oif.Addrs()
if err != nil {
return nil, err
}
for _, addr := range addrs {
ifc.AddAddress(addr)
}
ifs = append(ifs, ifc)
}
return ifs, nil
}

View File

@ -0,0 +1,40 @@
package stdnet
import (
"strings"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl"
)
// InterfaceFilter is a function passed to ICE Agent to filter out not allowed interfaces
// to avoid building tunnel over them.
func InterfaceFilter(disallowList []string) func(string) bool {
return func(iFace string) bool {
if strings.HasPrefix(iFace, "lo") {
// hardcoded loopback check to support already installed agents
return false
}
for _, s := range disallowList {
if strings.HasPrefix(iFace, s) {
log.Debugf("ignoring interface %s - it is not allowed", iFace)
return false
}
}
// look for unlisted WireGuard interfaces
wg, err := wgctrl.New()
if err != nil {
log.Debugf("trying to create a wgctrl client failed with: %v", err)
return true
}
defer func() {
_ = wg.Close()
}()
_, err = wg.Device(iFace)
return err != nil
}
}

View File

@ -1,8 +0,0 @@
package stdnet
// IFaceDiscover provide an option for external services (mobile)
// to collect network interface information
type IFaceDiscover interface {
// IFaces return with the description of the interfaces
IFaces() (string, error)
}

View File

@ -5,12 +5,9 @@ package stdnet
import ( import (
"fmt" "fmt"
"net"
"strings"
"github.com/pion/transport/v2" "github.com/pion/transport/v2"
"github.com/pion/transport/v2/stdnet" "github.com/pion/transport/v2/stdnet"
log "github.com/sirupsen/logrus"
) )
// Net is an implementation of the net.Net interface // Net is an implementation of the net.Net interface
@ -18,24 +15,40 @@ import (
type Net struct { type Net struct {
stdnet.Net stdnet.Net
interfaces []*transport.Interface interfaces []*transport.Interface
iFaceDiscover iFaceDiscover
// interfaceFilter should return true if the given interfaceName is allowed
interfaceFilter func(interfaceName string) bool
}
// NewNetWithDiscover creates a new StdNet instance.
func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) {
n := &Net{
iFaceDiscover: newMobileIFaceDiscover(iFaceDiscover),
interfaceFilter: InterfaceFilter(disallowList),
}
return n, n.UpdateInterfaces()
} }
// NewNet creates a new StdNet instance. // NewNet creates a new StdNet instance.
func NewNet(iFaceDiscover IFaceDiscover) (*Net, error) { func NewNet(disallowList []string) (*Net, error) {
n := &Net{} n := &Net{
iFaceDiscover: pionDiscover{},
return n, n.UpdateInterfaces(iFaceDiscover) interfaceFilter: InterfaceFilter(disallowList),
}
return n, n.UpdateInterfaces()
} }
// UpdateInterfaces updates the internal list of network interfaces // UpdateInterfaces updates the internal list of network interfaces
// and associated addresses. // and associated addresses filtering them by name.
func (n *Net) UpdateInterfaces(iFaceDiscover IFaceDiscover) error { // The interfaces are discovered by an external iFaceDiscover function or by a default discoverer if the external one
ifacesString, err := iFaceDiscover.IFaces() // wasn't specified.
func (n *Net) UpdateInterfaces() (err error) {
allIfaces, err := n.iFaceDiscover.iFaces()
if err != nil { if err != nil {
return err return err
} }
n.interfaces = parseInterfacesString(ifacesString) n.interfaces = n.filterInterfaces(allIfaces)
return err return nil
} }
// Interfaces returns a slice of interfaces which are available on the // Interfaces returns a slice of interfaces which are available on the
@ -70,68 +83,15 @@ func (n *Net) InterfaceByName(name string) (*transport.Interface, error) {
return nil, fmt.Errorf("%w: %s", transport.ErrInterfaceNotFound, name) return nil, fmt.Errorf("%w: %s", transport.ErrInterfaceNotFound, name)
} }
func parseInterfacesString(interfaces string) []*transport.Interface { func (n *Net) filterInterfaces(interfaces []*transport.Interface) []*transport.Interface {
ifs := []*transport.Interface{} if n.interfaceFilter == nil {
return interfaces
for _, iface := range strings.Split(interfaces, "\n") {
if strings.TrimSpace(iface) == "" {
continue
} }
result := []*transport.Interface{}
fields := strings.Split(iface, "|") for _, iface := range interfaces {
if len(fields) != 2 { if n.interfaceFilter(iface.Name) {
log.Warnf("parseInterfacesString: unable to split %q", iface) result = append(result, iface)
continue
}
var name string
var index, mtu int
var up, broadcast, loopback, pointToPoint, multicast bool
_, err := fmt.Sscanf(fields[0], "%s %d %d %t %t %t %t %t",
&name, &index, &mtu, &up, &broadcast, &loopback, &pointToPoint, &multicast)
if err != nil {
log.Warnf("parseInterfacesString: unable to parse %q: %v", iface, err)
continue
}
newIf := net.Interface{
Name: name,
Index: index,
MTU: mtu,
}
if up {
newIf.Flags |= net.FlagUp
}
if broadcast {
newIf.Flags |= net.FlagBroadcast
}
if loopback {
newIf.Flags |= net.FlagLoopback
}
if pointToPoint {
newIf.Flags |= net.FlagPointToPoint
}
if multicast {
newIf.Flags |= net.FlagMulticast
}
ifc := transport.NewInterface(newIf)
addrs := strings.Trim(fields[1], " \n")
foundAddress := false
for _, addr := range strings.Split(addrs, " ") {
ip, ipNet, err := net.ParseCIDR(addr)
if err != nil {
log.Warnf("%s", err)
continue
}
ipNet.IP = ip
ifc.AddAddress(ipNet)
foundAddress = true
}
if foundAddress {
ifs = append(ifs, ifc)
} }
} }
return ifs return result
} }

View File

@ -6,4 +6,4 @@
#define EXPAND(x) STRINGIZE(x) #define EXPAND(x) STRINGIZE(x)
CREATEPROCESS_MANIFEST_RESOURCE_ID RT_MANIFEST manifest.xml CREATEPROCESS_MANIFEST_RESOURCE_ID RT_MANIFEST manifest.xml
7 ICON ui/netbird.ico 7 ICON ui/netbird.ico
wireguard.dll RCDATA wireguard.dll wintun.dll RCDATA wintun.dll

View File

@ -1,27 +0,0 @@
#!/bin/bash
ldir=$PWD
tmp_dir_path=$ldir/.distfiles
winnt=wireguard-nt.zip
download_file_path=$tmp_dir_path/$winnt
download_url=https://download.wireguard.com/wireguard-nt/wireguard-nt-0.10.1.zip
download_sha=772c0b1463d8d2212716f43f06f4594d880dea4f735165bd68e388fc41b81605
function resources_windows(){
cmd=$1
arch=$2
out=$3
docker run -i --rm -v $PWD:$PWD -w $PWD mstorsjo/llvm-mingw:latest $cmd -O coff -c 65001 -I $tmp_dir_path/wireguard-nt/bin/$arch -i resources.rc -o $out
}
mkdir -p $tmp_dir_path
curl -L#o $download_file_path.unverified $download_url
echo "$download_sha $download_file_path.unverified" | sha256sum -c
mv $download_file_path.unverified $download_file_path
mkdir -p .deps
unzip $download_file_path -d $tmp_dir_path
resources_windows i686-w64-mingw32-windres x86 resources_windows_386.syso
resources_windows aarch64-w64-mingw32-windres arm64 resources_windows_arm64.syso
resources_windows x86_64-w64-mingw32-windres amd64 resources_windows_amd64.syso

7
go.mod
View File

@ -19,7 +19,7 @@ require (
github.com/vishvananda/netlink v1.1.0 github.com/vishvananda/netlink v1.1.0
golang.org/x/crypto v0.7.0 golang.org/x/crypto v0.7.0
golang.org/x/sys v0.6.0 golang.org/x/sys v0.6.0
golang.zx2c4.com/wireguard v0.0.0-20211209221555-9c9e7e272434 golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de
golang.zx2c4.com/wireguard/windows v0.5.1 golang.zx2c4.com/wireguard/windows v0.5.1
google.golang.org/grpc v1.52.3 google.golang.org/grpc v1.52.3
@ -48,6 +48,8 @@ require (
github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/open-policy-agent/opa v0.49.0 github.com/open-policy-agent/opa v0.49.0
github.com/patrickmn/go-cache v2.1.0+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pion/logging v0.2.2
github.com/pion/stun v0.4.0
github.com/pion/transport/v2 v2.0.2 github.com/pion/transport/v2 v2.0.2
github.com/prometheus/client_golang v1.14.0 github.com/prometheus/client_golang v1.14.0
github.com/rs/xid v1.3.0 github.com/rs/xid v1.3.0
@ -103,10 +105,8 @@ require (
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect
github.com/pegasus-kv/thrift v0.13.0 // indirect github.com/pegasus-kv/thrift v0.13.0 // indirect
github.com/pion/dtls/v2 v2.2.6 // indirect github.com/pion/dtls/v2 v2.2.6 // indirect
github.com/pion/logging v0.2.2 // indirect
github.com/pion/mdns v0.0.7 // indirect github.com/pion/mdns v0.0.7 // indirect
github.com/pion/randutil v0.1.0 // indirect github.com/pion/randutil v0.1.0 // indirect
github.com/pion/stun v0.4.0 // indirect
github.com/pion/turn/v2 v2.1.0 // indirect github.com/pion/turn/v2 v2.1.0 // indirect
github.com/pion/udp/v2 v2.0.1 // indirect github.com/pion/udp/v2 v2.0.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
@ -131,7 +131,6 @@ require (
golang.org/x/mod v0.8.0 // indirect golang.org/x/mod v0.8.0 // indirect
golang.org/x/text v0.8.0 // indirect golang.org/x/text v0.8.0 // indirect
golang.org/x/tools v0.6.0 // indirect golang.org/x/tools v0.6.0 // indirect
golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d // indirect
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 // indirect golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 // indirect
google.golang.org/genproto v0.0.0-20221118155620-16455021b5e6 // indirect google.golang.org/genproto v0.0.0-20221118155620-16455021b5e6 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect

5
go.sum
View File

@ -881,13 +881,12 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d h1:9+v0G0naRhLPOJEeJOL6NuXTtAHHwmkyZlgQJ0XcQ8I=
golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg= golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg=
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY= golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY=
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard v0.0.0-20211129173154-2dd424e2d808/go.mod h1:TjUWrnD5ATh7bFvmm/ALEJZQ4ivKbETb6pmyj1vUoNI= golang.zx2c4.com/wireguard v0.0.0-20211129173154-2dd424e2d808/go.mod h1:TjUWrnD5ATh7bFvmm/ALEJZQ4ivKbETb6pmyj1vUoNI=
golang.zx2c4.com/wireguard v0.0.0-20211209221555-9c9e7e272434 h1:3zl8RkJNQ8wfPRomwv/6DBbH2Ut6dgMaWTxM0ZunWnE= golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675 h1:/J/RVnr7ng4fWPRH3xa4WtBJ1Jp+Auu4YNLmGiPv5QU=
golang.zx2c4.com/wireguard v0.0.0-20211209221555-9c9e7e272434/go.mod h1:TjUWrnD5ATh7bFvmm/ALEJZQ4ivKbETb6pmyj1vUoNI= golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675/go.mod h1:whfbyDBt09xhCYQWtO2+3UVjlaq6/9hDZrjg2ZE6SyA=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de h1:qDZ+lyO5jC9RNJ7ANJA0GWXk3pSn0Fu5SlcAIlgw+6w= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de h1:qDZ+lyO5jC9RNJ7ANJA0GWXk3pSn0Fu5SlcAIlgw+6w=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de/go.mod h1:Q2XNgour4QSkFj0BWCkVlW0HWJwQgNMsMahpSlI0Eno= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de/go.mod h1:Q2XNgour4QSkFj0BWCkVlW0HWJwQgNMsMahpSlI0Eno=
golang.zx2c4.com/wireguard/windows v0.5.1 h1:OnYw96PF+CsIMrqWo5QP3Q59q5hY1rFErk/yN3cS+JQ= golang.zx2c4.com/wireguard/windows v0.5.1 h1:OnYw96PF+CsIMrqWo5QP3Q59q5hY1rFErk/yN3cS+JQ=

208
iface/bind/bind.go Normal file
View File

@ -0,0 +1,208 @@
package bind
import (
"errors"
"fmt"
"net"
"net/netip"
"sync"
"syscall"
"github.com/pion/stun"
"github.com/pion/transport/v2"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/conn"
)
// ICEBind is the userspace implementation of WireGuard's conn.Bind interface using ice.UDPMux of the pion/ice library
type ICEBind struct {
// below fields, initialized on open
ipv4 net.PacketConn
udpMux *UniversalUDPMuxDefault
// below are fields initialized on creation
transportNet transport.Net
mu sync.Mutex
}
// NewICEBind create a new instance of ICEBind with a given transportNet function.
// The transportNet can be nil.
func NewICEBind(transportNet transport.Net) *ICEBind {
return &ICEBind{
transportNet: transportNet,
mu: sync.Mutex{},
}
}
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (b *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
b.mu.Lock()
defer b.mu.Unlock()
if b.udpMux == nil {
return nil, fmt.Errorf("ICEBind has not been initialized yet")
}
return b.udpMux, nil
}
// Open creates a WireGuard socket and an instance of UDPMux that is used to glue up ICE and WireGuard for hole punching
func (b *ICEBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
b.mu.Lock()
defer b.mu.Unlock()
if b.ipv4 != nil {
return nil, 0, conn.ErrBindAlreadyOpen
}
var err error
b.ipv4, _, err = listenNet("udp4", int(uport))
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err
}
b.udpMux = NewUniversalUDPMuxDefault(UniversalUDPMuxParams{UDPConn: b.ipv4, Net: b.transportNet})
portAddr, err := netip.ParseAddrPort(b.ipv4.LocalAddr().String())
if err != nil {
return nil, 0, err
}
log.Infof("opened ICEBind on %s", b.ipv4.LocalAddr().String())
return []conn.ReceiveFunc{
b.makeReceiveIPv4(b.ipv4),
},
portAddr.Port(), nil
}
func listenNet(network string, port int) (*net.UDPConn, int, error) {
c, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
if err != nil {
return nil, 0, err
}
lAddr := c.LocalAddr()
uAddr, err := net.ResolveUDPAddr(
lAddr.Network(),
lAddr.String(),
)
if err != nil {
return nil, 0, err
}
return c, uAddr.Port, nil
}
func parseSTUNMessage(raw []byte) (*stun.Message, error) {
msg := &stun.Message{
Raw: raw,
}
if err := msg.Decode(); err != nil {
return nil, err
}
return msg, nil
}
func (b *ICEBind) makeReceiveIPv4(c net.PacketConn) conn.ReceiveFunc {
return func(buff []byte) (int, conn.Endpoint, error) {
n, endpoint, err := c.ReadFrom(buff)
if err != nil {
return 0, nil, err
}
e, err := netip.ParseAddrPort(endpoint.String())
if err != nil {
return 0, nil, err
}
if !stun.IsMessage(buff) {
// WireGuard traffic
return n, (conn.StdNetEndpoint)(netip.AddrPortFrom(e.Addr(), e.Port())), nil
}
msg, err := parseSTUNMessage(buff[:n])
if err != nil {
return 0, nil, err
}
err = b.udpMux.HandleSTUNMessage(msg, endpoint)
if err != nil {
log.Warnf("failed to handle packet")
}
// discard packets because they are STUN related
return 0, nil, nil //todo proper return
}
}
// Close closes the WireGuard socket and UDPMux
func (b *ICEBind) Close() error {
b.mu.Lock()
defer b.mu.Unlock()
var err1, err2 error
if b.ipv4 != nil {
c := b.ipv4
b.ipv4 = nil
err1 = c.Close()
}
if b.udpMux != nil {
m := b.udpMux
b.udpMux = nil
err2 = m.Close()
}
if err1 != nil {
return err1
}
return err2
}
// SetMark sets the mark for each packet sent through this Bind.
// This mark is passed to the kernel as the socket option SO_MARK.
func (b *ICEBind) SetMark(mark uint32) error {
return nil
}
// Send bytes to the remote endpoint (peer)
func (b *ICEBind) Send(buff []byte, endpoint conn.Endpoint) error {
nend, ok := endpoint.(conn.StdNetEndpoint)
if !ok {
return conn.ErrWrongEndpointType
}
addrPort := netip.AddrPort(nend)
_, err := b.ipv4.WriteTo(buff, &net.UDPAddr{
IP: addrPort.Addr().AsSlice(),
Port: int(addrPort.Port()),
Zone: addrPort.Addr().Zone(),
})
return err
}
// ParseEndpoint creates a new endpoint from a string.
func (b *ICEBind) ParseEndpoint(s string) (ep conn.Endpoint, err error) {
e, err := netip.ParseAddrPort(s)
return asEndpoint(e), err
}
// endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint.
// This exists to reduce allocations: Putting a netip.AddrPort in an Endpoint allocates,
// but Endpoints are immutable, so we can re-use them.
var endpointPool = sync.Pool{
New: func() any {
return make(map[netip.AddrPort]conn.Endpoint)
},
}
// asEndpoint returns an Endpoint containing ap.
func asEndpoint(ap netip.AddrPort) conn.Endpoint {
m := endpointPool.Get().(map[netip.AddrPort]conn.Endpoint)
defer endpointPool.Put(m)
e, ok := m[ap]
if !ok {
e = conn.Endpoint(conn.StdNetEndpoint(ap))
m[ap] = e
}
return e
}

445
iface/bind/udp_mux.go Normal file
View File

@ -0,0 +1,445 @@
package bind
import (
"fmt"
"io"
"net"
"strings"
"sync"
"github.com/pion/ice/v2"
"github.com/pion/stun"
"github.com/pion/transport/v2/stdnet"
log "github.com/sirupsen/logrus"
"github.com/pion/logging"
"github.com/pion/transport/v2"
)
/*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements
*/
const receiveMTU = 8192
// UDPMuxDefault is an implementation of the interface
type UDPMuxDefault struct {
params UDPMuxParams
closedChan chan struct{}
closeOnce sync.Once
// connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType
connsIPv4, connsIPv6 map[string]*udpMuxedConn
addressMapMu sync.RWMutex
addressMap map[string][]*udpMuxedConn
// buffer pool to recycle buffers for net.UDPAddr encodes/decodes
pool *sync.Pool
mu sync.Mutex
// for UDP connection listen at unspecified address
localAddrsForUnspecified []net.Addr
}
const maxAddrSize = 512
// UDPMuxParams are parameters for UDPMux.
type UDPMuxParams struct {
Logger logging.LeveledLogger
UDPConn net.PacketConn
// Required for gathering local addresses
// in case a un UDPConn is passed which does not
// bind to a specific local address.
Net transport.Net
InterfaceFilter func(interfaceName string) bool
}
func localInterfaces(n transport.Net, interfaceFilter func(string) bool, ipFilter func(net.IP) bool, networkTypes []ice.NetworkType, includeLoopback bool) ([]net.IP, error) { //nolint:gocognit
ips := []net.IP{}
ifaces, err := n.Interfaces()
if err != nil {
return ips, err
}
var IPv4Requested, IPv6Requested bool
for _, typ := range networkTypes {
if typ.IsIPv4() {
IPv4Requested = true
}
if typ.IsIPv6() {
IPv6Requested = true
}
}
for _, iface := range ifaces {
if iface.Flags&net.FlagUp == 0 {
continue // interface down
}
if (iface.Flags&net.FlagLoopback != 0) && !includeLoopback {
continue // loopback interface
}
if interfaceFilter != nil && !interfaceFilter(iface.Name) {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, addr := range addrs {
var ip net.IP
switch addr := addr.(type) {
case *net.IPNet:
ip = addr.IP
case *net.IPAddr:
ip = addr.IP
}
if ip == nil || (ip.IsLoopback() && !includeLoopback) {
continue
}
if ipv4 := ip.To4(); ipv4 == nil {
if !IPv6Requested {
continue
} else if !isSupportedIPv6(ip) {
continue
}
} else if !IPv4Requested {
continue
}
if ipFilter != nil && !ipFilter(ip) {
continue
}
ips = append(ips, ip)
}
}
return ips, nil
}
// The conditions of invalidation written below are defined in
// https://tools.ietf.org/html/rfc8445#section-5.1.1.1
func isSupportedIPv6(ip net.IP) bool {
if len(ip) != net.IPv6len ||
isZeros(ip[0:12]) || // !(IPv4-compatible IPv6)
ip[0] == 0xfe && ip[1]&0xc0 == 0xc0 || // !(IPv6 site-local unicast)
ip.IsLinkLocalUnicast() ||
ip.IsLinkLocalMulticast() {
return false
}
return true
}
func isZeros(ip net.IP) bool {
for i := 0; i < len(ip); i++ {
if ip[i] != 0 {
return false
}
}
return true
}
// NewUDPMuxDefault creates an implementation of UDPMux
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
if params.Logger == nil {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
}
var localAddrsForUnspecified []net.Addr
if addr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", params.UDPConn.LocalAddr())
} else if ok && addr.IP.IsUnspecified() {
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
// it will break the applications that are already using unspecified UDP connection
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
var networks []ice.NetworkType
switch {
case addr.IP.To4() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
case addr.IP.To16() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
default:
params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr())
}
if len(networks) > 0 {
if params.Net == nil {
var err error
if params.Net, err = stdnet.NewNet(); err != nil {
params.Logger.Errorf("failed to get create network: %v", err)
}
}
ips, err := localInterfaces(params.Net, params.InterfaceFilter, nil, networks, true)
if err == nil {
for _, ip := range ips {
localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})
}
} else {
params.Logger.Errorf("failed to get local interfaces for unspecified addr: %v", err)
}
}
}
return &UDPMuxDefault{
addressMap: map[string][]*udpMuxedConn{},
params: params,
connsIPv4: make(map[string]*udpMuxedConn),
connsIPv6: make(map[string]*udpMuxedConn),
closedChan: make(chan struct{}, 1),
pool: &sync.Pool{
New: func() interface{} {
// big enough buffer to fit both packet and address
return newBufferHolder(receiveMTU + maxAddrSize)
},
},
localAddrsForUnspecified: localAddrsForUnspecified,
}
}
// LocalAddr returns the listening address of this UDPMuxDefault
func (m *UDPMuxDefault) LocalAddr() net.Addr {
return m.params.UDPConn.LocalAddr()
}
// GetListenAddresses returns the list of addresses that this mux is listening on
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
if len(m.localAddrsForUnspecified) > 0 {
return m.localAddrsForUnspecified
}
return []net.Addr{m.LocalAddr()}
}
// GetConn returns a PacketConn given the connection's ufrag and network address
// creates the connection if an existing one can't be found
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
var isIPv6 bool
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {
isIPv6 = true
}
m.mu.Lock()
defer m.mu.Unlock()
if m.IsClosed() {
return nil, io.ErrClosedPipe
}
if conn, ok := m.getConn(ufrag, isIPv6); ok {
return conn, nil
}
c := m.createMuxedConn(ufrag)
go func() {
<-c.CloseChannel()
m.RemoveConnByUfrag(ufrag)
}()
if isIPv6 {
m.connsIPv6[ufrag] = c
} else {
m.connsIPv4[ufrag] = c
}
return c, nil
}
// RemoveConnByUfrag stops and removes the muxed packet connection
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
removedConns := make([]*udpMuxedConn, 0, 2)
// Keep lock section small to avoid deadlock with conn lock
m.mu.Lock()
if c, ok := m.connsIPv4[ufrag]; ok {
delete(m.connsIPv4, ufrag)
removedConns = append(removedConns, c)
}
if c, ok := m.connsIPv6[ufrag]; ok {
delete(m.connsIPv6, ufrag)
removedConns = append(removedConns, c)
}
m.mu.Unlock()
if len(removedConns) == 0 {
// No need to lock if no connection was found
return
}
m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
for _, c := range removedConns {
addresses := c.getAddresses()
for _, addr := range addresses {
if connList, ok := m.addressMap[addr]; ok {
var newList []*udpMuxedConn
for _, conn := range connList {
if conn.params.Key != ufrag {
newList = append(newList, conn)
}
}
m.addressMap[addr] = newList
}
}
}
}
// IsClosed returns true if the mux had been closed
func (m *UDPMuxDefault) IsClosed() bool {
select {
case <-m.closedChan:
return true
default:
return false
}
}
// Close the mux, no further connections could be created
func (m *UDPMuxDefault) Close() error {
var err error
m.closeOnce.Do(func() {
m.mu.Lock()
defer m.mu.Unlock()
for _, c := range m.connsIPv4 {
_ = c.Close()
}
for _, c := range m.connsIPv6 {
_ = c.Close()
}
m.connsIPv4 = make(map[string]*udpMuxedConn)
m.connsIPv6 = make(map[string]*udpMuxedConn)
close(m.closedChan)
_ = m.params.UDPConn.Close()
})
return err
}
func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
return m.params.UDPConn.WriteTo(buf, rAddr)
}
func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) {
if m.IsClosed() {
return
}
m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
existing, ok := m.addressMap[addr]
if !ok {
existing = []*udpMuxedConn{}
}
existing = append(existing, conn)
m.addressMap[addr] = existing
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
}
func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn {
c := newUDPMuxedConn(&udpMuxedConnParams{
Mux: m,
Key: key,
AddrPool: m.pool,
LocalAddr: m.LocalAddr(),
Logger: m.params.Logger,
})
return c
}
// HandleSTUNMessage handles STUN packets and forwards them to underlying pion/ice library
func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
remoteAddr, ok := addr.(*net.UDPAddr)
if !ok {
return fmt.Errorf("underlying PacketConn did not return a UDPAddr")
}
// If we have already seen this address dispatch to the appropriate destination
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
// We will then forward STUN packets to each of these connections.
m.addressMapMu.Lock()
var destinationConnList []*udpMuxedConn
if storedConns, ok := m.addressMap[addr.String()]; ok {
destinationConnList = append(destinationConnList, storedConns...)
}
m.addressMapMu.Unlock()
var isIPv6 bool
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {
isIPv6 = true
}
// This block is needed to discover Peer Reflexive Candidates for which we don't know the Endpoint upfront.
// However, we can take a username attribute from the STUN message which contains ufrag.
// We can use ufrag to identify the destination conn to route packet to.
attr, stunAttrErr := msg.Get(stun.AttrUsername)
if stunAttrErr == nil {
ufrag := strings.Split(string(attr), ":")[0]
m.mu.Lock()
destinationConn := m.connsIPv4[ufrag]
if isIPv6 {
destinationConn = m.connsIPv6[ufrag]
}
if destinationConn != nil {
exists := false
for _, conn := range destinationConnList {
if conn.params.Key == destinationConn.params.Key {
exists = true
break
}
}
if !exists {
destinationConnList = append(destinationConnList, destinationConn)
}
}
m.mu.Unlock()
}
// Forward STUN packets to each destination connections even thought the STUN packet might not belong there.
// It will be discarded by the further ICE candidate logic if so.
for _, conn := range destinationConnList {
if err := conn.writePacket(msg.Raw, remoteAddr); err != nil {
log.Errorf("could not write packet: %v", err)
}
}
return nil
}
func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
if isIPv6 {
val, ok = m.connsIPv6[ufrag]
} else {
val, ok = m.connsIPv4[ufrag]
}
return
}
type bufferHolder struct {
buf []byte
}
func newBufferHolder(size int) *bufferHolder {
return &bufferHolder{
buf: make([]byte, size),
}
}

View File

@ -0,0 +1,254 @@
package bind
/*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements.
*/
import (
"fmt"
"net"
"time"
log "github.com/sirupsen/logrus"
"github.com/pion/logging"
"github.com/pion/stun"
"github.com/pion/transport/v2"
)
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn
// It then passes packets to the UDPMux that does the actual connection muxing.
type UniversalUDPMuxDefault struct {
*UDPMuxDefault
params UniversalUDPMuxParams
// since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents
// stun.XORMappedAddress indexed by the STUN server addr
xorMappedMap map[string]*xorMapped
}
// UniversalUDPMuxParams are parameters for UniversalUDPMux server reflexive.
type UniversalUDPMuxParams struct {
Logger logging.LeveledLogger
UDPConn net.PacketConn
XORMappedAddrCacheTTL time.Duration
Net transport.Net
}
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault {
if params.Logger == nil {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
}
if params.XORMappedAddrCacheTTL == 0 {
params.XORMappedAddrCacheTTL = time.Second * 25
}
m := &UniversalUDPMuxDefault{
params: params,
xorMappedMap: make(map[string]*xorMapped),
}
// wrap UDP connection, process server reflexive messages
// before they are passed to the UDPMux connection handler (connWorker)
m.params.UDPConn = &udpConn{
PacketConn: params.UDPConn,
mux: m,
logger: params.Logger,
}
// embed UDPMux
udpMuxParams := UDPMuxParams{
Logger: params.Logger,
UDPConn: m.params.UDPConn,
Net: m.params.Net,
}
m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams)
return m
}
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
type udpConn struct {
net.PacketConn
mux *UniversalUDPMuxDefault
logger logging.LeveledLogger
}
// GetListenAddresses returns the listen addr of this UDP
func (m *UniversalUDPMuxDefault) GetListenAddresses() []net.Addr {
return []net.Addr{m.LocalAddr()}
}
// GetRelayedAddr creates relayed connection to the given TURN service and returns the relayed addr.
// Not implemented yet.
func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error) {
return nil, fmt.Errorf("not implemented yet")
}
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers
// and return a unique connection per server.
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) {
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr)
}
// HandleSTUNMessage discovers STUN packets that carry a XOR mapped address from a STUN server.
// All other STUN packets will be forwarded to the UDPMux
func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
// message about this err will be logged in the UDPMux
return nil
}
if m.isXORMappedResponse(msg, udpAddr.String()) {
err := m.handleXORMappedResponse(udpAddr, msg)
if err != nil {
log.Debugf("%s: %v", fmt.Errorf("failed to get XOR-MAPPED-ADDRESS response"), err)
return nil
}
return nil
}
return m.UDPMuxDefault.HandleSTUNMessage(msg, addr)
}
// isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server.
func (m *UniversalUDPMuxDefault) isXORMappedResponse(msg *stun.Message, stunAddr string) bool {
m.mu.Lock()
defer m.mu.Unlock()
// check first if it is a STUN server address because remote peer can also send similar messages but as a BindingSuccess
_, ok := m.xorMappedMap[stunAddr]
_, err := msg.Get(stun.AttrXORMappedAddress)
return err == nil && ok
}
// handleXORMappedResponse parses response from the STUN server, extracts XORMappedAddress attribute
// and set the mapped address for the server
func (m *UniversalUDPMuxDefault) handleXORMappedResponse(stunAddr *net.UDPAddr, msg *stun.Message) error {
m.mu.Lock()
defer m.mu.Unlock()
mappedAddr, ok := m.xorMappedMap[stunAddr.String()]
if !ok {
return fmt.Errorf("no XOR address mapping")
}
var addr stun.XORMappedAddress
if err := addr.GetFrom(msg); err != nil {
return err
}
m.xorMappedMap[stunAddr.String()] = mappedAddr
mappedAddr.SetAddr(&addr)
return nil
}
// GetXORMappedAddr returns *stun.XORMappedAddress if already present for a given STUN server.
// Makes a STUN binding request to discover mapped address otherwise.
// Blocks until the stun.XORMappedAddress has been discovered or deadline.
// Method is safe for concurrent use.
func (m *UniversalUDPMuxDefault) GetXORMappedAddr(serverAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error) {
m.mu.Lock()
mappedAddr, ok := m.xorMappedMap[serverAddr.String()]
// if we already have a mapping for this STUN server (address already received)
// and if it is not too old we return it without making a new request to STUN server
if ok {
if mappedAddr.expired() {
mappedAddr.closeWaiters()
delete(m.xorMappedMap, serverAddr.String())
ok = false
} else if mappedAddr.pending() {
ok = false
}
}
m.mu.Unlock()
if ok {
return mappedAddr.addr, nil
}
// otherwise, make a STUN request to discover the address
// or wait for already sent request to complete
waitAddrReceived, err := m.sendStun(serverAddr)
if err != nil {
return nil, fmt.Errorf("%s: %s", "failed to send STUN packet", err)
}
// block until response was handled by the connWorker routine and XORMappedAddress was updated
select {
case <-waitAddrReceived:
// when channel closed, addr was obtained
m.mu.Lock()
mappedAddr := *m.xorMappedMap[serverAddr.String()]
m.mu.Unlock()
if mappedAddr.addr == nil {
return nil, fmt.Errorf("no XOR address mapping")
}
return mappedAddr.addr, nil
case <-time.After(deadline):
return nil, fmt.Errorf("timeout while waiting for XORMappedAddr")
}
}
// sendStun sends a STUN request via UDP conn.
//
// The returned channel is closed when the STUN response has been received.
// Method is safe for concurrent use.
func (m *UniversalUDPMuxDefault) sendStun(serverAddr net.Addr) (chan struct{}, error) {
m.mu.Lock()
defer m.mu.Unlock()
// if record present in the map, we already sent a STUN request,
// just wait when waitAddrReceived will be closed
addrMap, ok := m.xorMappedMap[serverAddr.String()]
if !ok {
addrMap = &xorMapped{
expiresAt: time.Now().Add(m.params.XORMappedAddrCacheTTL),
waitAddrReceived: make(chan struct{}),
}
m.xorMappedMap[serverAddr.String()] = addrMap
}
req, err := stun.Build(stun.BindingRequest, stun.TransactionID)
if err != nil {
return nil, err
}
if _, err = m.params.UDPConn.WriteTo(req.Raw, serverAddr); err != nil {
return nil, err
}
return addrMap.waitAddrReceived, nil
}
type xorMapped struct {
addr *stun.XORMappedAddress
waitAddrReceived chan struct{}
expiresAt time.Time
}
func (a *xorMapped) closeWaiters() {
select {
case <-a.waitAddrReceived:
// notify was close, ok, that means we received duplicate response
// just exit
break
default:
// notify tha twe have a new addr
close(a.waitAddrReceived)
}
}
func (a *xorMapped) pending() bool {
return a.addr == nil
}
func (a *xorMapped) expired() bool {
return a.expiresAt.Before(time.Now())
}
func (a *xorMapped) SetAddr(addr *stun.XORMappedAddress) {
a.addr = addr
a.closeWaiters()
}

View File

@ -0,0 +1,233 @@
package bind
/*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements
*/
import (
"encoding/binary"
"io"
"net"
"sync"
"time"
"github.com/pion/logging"
"github.com/pion/transport/v2/packetio"
)
type udpMuxedConnParams struct {
Mux *UDPMuxDefault
AddrPool *sync.Pool
Key string
LocalAddr net.Addr
Logger logging.LeveledLogger
}
// udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag
type udpMuxedConn struct {
params *udpMuxedConnParams
// remote addresses that we have sent to on this conn
addresses []string
// channel holding incoming packets
buf *packetio.Buffer
closedChan chan struct{}
closeOnce sync.Once
mu sync.Mutex
}
func newUDPMuxedConn(params *udpMuxedConnParams) *udpMuxedConn {
p := &udpMuxedConn{
params: params,
buf: packetio.NewBuffer(),
closedChan: make(chan struct{}),
}
return p
}
func (c *udpMuxedConn) ReadFrom(b []byte) (n int, rAddr net.Addr, err error) {
buf := c.params.AddrPool.Get().(*bufferHolder) //nolint:forcetypeassert
defer c.params.AddrPool.Put(buf)
// read address
total, err := c.buf.Read(buf.buf)
if err != nil {
return 0, nil, err
}
dataLen := int(binary.LittleEndian.Uint16(buf.buf[:2]))
if dataLen > total || dataLen > len(b) {
return 0, nil, io.ErrShortBuffer
}
// read data and then address
offset := 2
copy(b, buf.buf[offset:offset+dataLen])
offset += dataLen
// read address len & decode address
addrLen := int(binary.LittleEndian.Uint16(buf.buf[offset : offset+2]))
offset += 2
if rAddr, err = decodeUDPAddr(buf.buf[offset : offset+addrLen]); err != nil {
return 0, nil, err
}
return dataLen, rAddr, nil
}
func (c *udpMuxedConn) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
if c.isClosed() {
return 0, io.ErrClosedPipe
}
// each time we write to a new address, we'll register it with the mux
addr := rAddr.String()
if !c.containsAddress(addr) {
c.addAddress(addr)
}
return c.params.Mux.writeTo(buf, rAddr)
}
func (c *udpMuxedConn) LocalAddr() net.Addr {
return c.params.LocalAddr
}
func (c *udpMuxedConn) SetDeadline(tm time.Time) error {
return nil
}
func (c *udpMuxedConn) SetReadDeadline(tm time.Time) error {
return nil
}
func (c *udpMuxedConn) SetWriteDeadline(tm time.Time) error {
return nil
}
func (c *udpMuxedConn) CloseChannel() <-chan struct{} {
return c.closedChan
}
func (c *udpMuxedConn) Close() error {
var err error
c.closeOnce.Do(func() {
err = c.buf.Close()
close(c.closedChan)
})
return err
}
func (c *udpMuxedConn) isClosed() bool {
select {
case <-c.closedChan:
return true
default:
return false
}
}
func (c *udpMuxedConn) getAddresses() []string {
c.mu.Lock()
defer c.mu.Unlock()
addresses := make([]string, len(c.addresses))
copy(addresses, c.addresses)
return addresses
}
func (c *udpMuxedConn) addAddress(addr string) {
c.mu.Lock()
c.addresses = append(c.addresses, addr)
c.mu.Unlock()
// map it on mux
c.params.Mux.registerConnForAddress(c, addr)
}
func (c *udpMuxedConn) containsAddress(addr string) bool {
c.mu.Lock()
defer c.mu.Unlock()
for _, a := range c.addresses {
if addr == a {
return true
}
}
return false
}
func (c *udpMuxedConn) writePacket(data []byte, addr *net.UDPAddr) error {
// write two packets, address and data
buf := c.params.AddrPool.Get().(*bufferHolder) //nolint:forcetypeassert
defer c.params.AddrPool.Put(buf)
// format of buffer | data len | data bytes | addr len | addr bytes |
if len(buf.buf) < len(data)+maxAddrSize {
return io.ErrShortBuffer
}
// data len
binary.LittleEndian.PutUint16(buf.buf, uint16(len(data)))
offset := 2
// data
copy(buf.buf[offset:], data)
offset += len(data)
// write address first, leaving room for its length
n, err := encodeUDPAddr(addr, buf.buf[offset+2:])
if err != nil {
return err
}
total := offset + n + 2
// address len
binary.LittleEndian.PutUint16(buf.buf[offset:], uint16(n))
if _, err := c.buf.Write(buf.buf[:total]); err != nil {
return err
}
return nil
}
func encodeUDPAddr(addr *net.UDPAddr, buf []byte) (int, error) {
ipData, err := addr.IP.MarshalText()
if err != nil {
return 0, err
}
total := 2 + len(ipData) + 2 + len(addr.Zone)
if total > len(buf) {
return 0, io.ErrShortBuffer
}
binary.LittleEndian.PutUint16(buf, uint16(len(ipData)))
offset := 2
n := copy(buf[offset:], ipData)
offset += n
binary.LittleEndian.PutUint16(buf[offset:], uint16(addr.Port))
offset += 2
copy(buf[offset:], addr.Zone)
return total, nil
}
func decodeUDPAddr(buf []byte) (*net.UDPAddr, error) {
addr := net.UDPAddr{}
offset := 0
ipLen := int(binary.LittleEndian.Uint16(buf[:2]))
offset += 2
// basic bounds checking
if ipLen+offset > len(buf) {
return nil, io.ErrShortBuffer
}
if err := addr.IP.UnmarshalText(buf[offset : offset+ipLen]); err != nil {
return nil, err
}
offset += ipLen
addr.Port = int(binary.LittleEndian.Uint16(buf[offset : offset+2]))
offset += 2
zone := make([]byte, len(buf[offset:]))
copy(zone, buf[offset:])
addr.Zone = string(zone)
return &addr, nil
}

View File

@ -5,6 +5,8 @@ import (
"sync" "sync"
"time" "time"
"github.com/netbirdio/netbird/iface/bind"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
@ -19,6 +21,17 @@ type WGIface struct {
tun *tunDevice tun *tunDevice
configurer wGConfigurer configurer wGConfigurer
mu sync.Mutex mu sync.Mutex
userspaceBind bool
}
// IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind
func (w *WGIface) IsUserspaceBind() bool {
return w.userspaceBind
}
// GetBind returns a userspace implementation of WireGuard Bind interface
func (w *WGIface) GetBind() *bind.ICEBind {
return w.tun.iceBind
} }
// Create creates a new Wireguard interface, sets a given IP and brings it up. // Create creates a new Wireguard interface, sets a given IP and brings it up.
@ -26,7 +39,7 @@ type WGIface struct {
func (w *WGIface) Create() error { func (w *WGIface) Create() error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
log.Debugf("create Wireguard interface %s", w.tun.DeviceName()) log.Debugf("create WireGuard interface %s", w.tun.DeviceName())
return w.tun.Create() return w.tun.Create()
} }

View File

@ -1,22 +1,28 @@
package iface package iface
import "sync" import (
"sync"
// NewWGIFace Creates a new Wireguard interface instance "github.com/pion/transport/v2"
func NewWGIFace(ifaceName string, address string, mtu int, tunAdapter TunAdapter) (*WGIface, error) { )
wgIface := &WGIface{
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, mtu int, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) {
wgIFace := &WGIface{
mu: sync.Mutex{}, mu: sync.Mutex{},
} }
wgAddress, err := parseWGAddress(address) wgAddress, err := parseWGAddress(address)
if err != nil { if err != nil {
return wgIface, err return wgIFace, err
} }
tun := newTunDevice(wgAddress, mtu, tunAdapter) tun := newTunDevice(wgAddress, mtu, tunAdapter, transportNet)
wgIface.tun = tun wgIFace.tun = tun
wgIface.configurer = newWGConfigurer(tun) wgIFace.configurer = newWGConfigurer(tun)
return wgIface, nil wgIFace.userspaceBind = !WireGuardModuleIsLoaded()
return wgIFace, nil
} }

View File

@ -2,21 +2,26 @@
package iface package iface
import "sync" import (
"sync"
// NewWGIFace Creates a new Wireguard interface instance "github.com/pion/transport/v2"
func NewWGIFace(ifaceName string, address string, mtu int, tunAdapter TunAdapter) (*WGIface, error) { )
wgIface := &WGIface{
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, mtu int, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) {
wgIFace := &WGIface{
mu: sync.Mutex{}, mu: sync.Mutex{},
} }
wgAddress, err := parseWGAddress(address) wgAddress, err := parseWGAddress(address)
if err != nil { if err != nil {
return wgIface, err return wgIFace, err
} }
wgIface.tun = newTunDevice(ifaceName, wgAddress, mtu) wgIFace.tun = newTunDevice(iFaceName, wgAddress, mtu, transportNet)
wgIface.configurer = newWGConfigurer(ifaceName) wgIFace.configurer = newWGConfigurer(iFaceName)
return wgIface, nil wgIFace.userspaceBind = !WireGuardModuleIsLoaded()
return wgIFace, nil
} }

View File

@ -2,13 +2,15 @@ package iface
import ( import (
"fmt" "fmt"
"net"
"testing"
"time"
"github.com/pion/transport/v2/stdnet"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"net"
"testing"
"time"
) )
// keep darwin compability // keep darwin compability
@ -32,7 +34,12 @@ func init() {
func TestWGIface_UpdateAddr(t *testing.T) { func TestWGIface_UpdateAddr(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4) ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
addr := "100.64.0.1/8" addr := "100.64.0.1/8"
iface, err := NewWGIFace(ifaceName, addr, DefaultMTU, nil) newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, addr, DefaultMTU, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -92,7 +99,11 @@ func getIfaceAddrs(ifaceName string) ([]net.Addr, error) {
func Test_CreateInterface(t *testing.T) { func Test_CreateInterface(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1) ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1)
wgIP := "10.99.99.1/32" wgIP := "10.99.99.1/32"
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil) newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -121,7 +132,11 @@ func Test_CreateInterface(t *testing.T) {
func Test_Close(t *testing.T) { func Test_Close(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2) ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
wgIP := "10.99.99.2/32" wgIP := "10.99.99.2/32"
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil) newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -149,7 +164,11 @@ func Test_Close(t *testing.T) {
func Test_ConfigureInterface(t *testing.T) { func Test_ConfigureInterface(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3) ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3)
wgIP := "10.99.99.5/30" wgIP := "10.99.99.5/30"
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil) newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -196,7 +215,11 @@ func Test_ConfigureInterface(t *testing.T) {
func Test_UpdatePeer(t *testing.T) { func Test_UpdatePeer(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4) ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
wgIP := "10.99.99.9/30" wgIP := "10.99.99.9/30"
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil) newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -255,7 +278,11 @@ func Test_UpdatePeer(t *testing.T) {
func Test_RemovePeer(t *testing.T) { func Test_RemovePeer(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4) ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
wgIP := "10.99.99.13/30" wgIP := "10.99.99.13/30"
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil) newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -304,8 +331,11 @@ func Test_ConnectPeers(t *testing.T) {
peer2Key, _ := wgtypes.GeneratePrivateKey() peer2Key, _ := wgtypes.GeneratePrivateKey()
keepAlive := 1 * time.Second keepAlive := 1 * time.Second
newNet, err := stdnet.NewNet()
iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, DefaultMTU, nil) if err != nil {
t.Fatal(err)
}
iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, DefaultMTU, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -322,7 +352,11 @@ func Test_ConnectPeers(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, DefaultMTU, nil) newNet, err = stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, DefaultMTU, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -3,7 +3,7 @@
package iface package iface
// WireguardModuleIsLoaded check if we can load wireguard mod (linux only) // WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only)
func WireguardModuleIsLoaded() bool { func WireGuardModuleIsLoaded() bool {
return false return false
} }

View File

@ -7,9 +7,6 @@ import (
"bufio" "bufio"
"errors" "errors"
"fmt" "fmt"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
"io" "io"
"io/fs" "io/fs"
"math" "math"
@ -17,6 +14,10 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"syscall" "syscall"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
) )
// Holds logic to check existence of kernel modules used by wireguard interfaces // Holds logic to check existence of kernel modules used by wireguard interfaces
@ -33,6 +34,7 @@ const (
loading loading
live live
inuse inuse
envDisableWireGuardKernel = "NB_WG_KERNEL_DISABLED"
) )
type module struct { type module struct {
@ -81,9 +83,15 @@ func tunModuleIsLoaded() bool {
return tunLoaded return tunLoaded
} }
// WireguardModuleIsLoaded check if we can load wireguard mod (linux only) // WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only)
func WireguardModuleIsLoaded() bool { func WireGuardModuleIsLoaded() bool {
if canCreateFakeWireguardInterface() {
if os.Getenv(envDisableWireGuardKernel) == "true" {
log.Debugf("WireGuard kernel module disabled because the %s env is set to true", envDisableWireGuardKernel)
return false
}
if canCreateFakeWireGuardInterface() {
return true return true
} }
@ -96,7 +104,7 @@ func WireguardModuleIsLoaded() bool {
return loaded return loaded
} }
func canCreateFakeWireguardInterface() bool { func canCreateFakeWireGuardInterface() bool {
link := newWGLink("mustnotexist") link := newWGLink("mustnotexist")
// We willingly try to create a device with an invalid // We willingly try to create a device with an invalid

View File

@ -3,13 +3,14 @@ package iface
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"testing" "testing"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
) )
func TestGetModuleDependencies(t *testing.T) { func TestGetModuleDependencies(t *testing.T) {

View File

@ -3,9 +3,12 @@ package iface
import ( import (
"net" "net"
"github.com/pion/transport/v2"
"github.com/netbirdio/netbird/iface/bind"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc" "golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
@ -20,13 +23,15 @@ type tunDevice struct {
name string name string
device *device.Device device *device.Device
uapi net.Listener uapi net.Listener
iceBind *bind.ICEBind
} }
func newTunDevice(address WGAddress, mtu int, tunAdapter TunAdapter) *tunDevice { func newTunDevice(address WGAddress, mtu int, tunAdapter TunAdapter, transportNet transport.Net) *tunDevice {
return &tunDevice{ return &tunDevice{
address: address, address: address,
mtu: mtu, mtu: mtu,
tunAdapter: tunAdapter, tunAdapter: tunAdapter,
iceBind: bind.NewICEBind(transportNet),
} }
} }
@ -46,7 +51,7 @@ func (t *tunDevice) Create() error {
t.name = name t.name = name
log.Debugf("attaching to interface %v", name) log.Debugf("attaching to interface %v", name)
t.device = device.NewDevice(tunDevice, conn.NewStdNetBind(), device.NewLogger(device.LogLevelSilent, "[wiretrustee] ")) t.device = device.NewDevice(tunDevice, t.iceBind, device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
t.device.DisableSomeRoamingForBrokenMobileSemantics() t.device.DisableSomeRoamingForBrokenMobileSemantics()
log.Debugf("create uapi") log.Debugf("create uapi")

View File

@ -11,7 +11,7 @@ import (
) )
func (c *tunDevice) Create() error { func (c *tunDevice) Create() error {
if WireguardModuleIsLoaded() { if WireGuardModuleIsLoaded() {
log.Info("using kernel WireGuard") log.Info("using kernel WireGuard")
return c.createWithKernel() return c.createWithKernel()
} }
@ -30,7 +30,7 @@ func (c *tunDevice) Create() error {
} }
// createWithKernel Creates a new Wireguard interface using kernel Wireguard module. // createWithKernel Creates a new WireGuard interface using kernel WireGuard module.
// Works for Linux and offers much better network performance // Works for Linux and offers much better network performance
func (c *tunDevice) createWithKernel() error { func (c *tunDevice) createWithKernel() error {

View File

@ -6,10 +6,13 @@ import (
"net" "net"
"os" "os"
log "github.com/sirupsen/logrus" "github.com/pion/transport/v2"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc" "golang.zx2c4.com/wireguard/ipc"
"github.com/netbirdio/netbird/iface/bind"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
) )
@ -18,13 +21,18 @@ type tunDevice struct {
address WGAddress address WGAddress
mtu int mtu int
netInterface NetInterface netInterface NetInterface
iceBind *bind.ICEBind
uapi net.Listener
close chan struct{}
} }
func newTunDevice(name string, address WGAddress, mtu int) *tunDevice { func newTunDevice(name string, address WGAddress, mtu int, transportNet transport.Net) *tunDevice {
return &tunDevice{ return &tunDevice{
name: name, name: name,
address: address, address: address,
mtu: mtu, mtu: mtu,
iceBind: bind.NewICEBind(transportNet),
close: make(chan struct{}),
} }
} }
@ -42,23 +50,38 @@ func (c *tunDevice) DeviceName() string {
} }
func (c *tunDevice) Close() error { func (c *tunDevice) Close() error {
if c.netInterface == nil {
return nil select {
case c.close <- struct{}{}:
default:
} }
err := c.netInterface.Close()
if err != nil { var err1, err2, err3 error
return err if c.netInterface != nil {
err1 = c.netInterface.Close()
}
if c.uapi != nil {
err2 = c.uapi.Close()
} }
sockPath := "/var/run/wireguard/" + c.name + ".sock" sockPath := "/var/run/wireguard/" + c.name + ".sock"
if _, statErr := os.Stat(sockPath); statErr == nil { if _, statErr := os.Stat(sockPath); statErr == nil {
statErr = os.Remove(sockPath) statErr = os.Remove(sockPath)
if statErr != nil { if statErr != nil {
return statErr err3 = statErr
} }
} }
return nil if err1 != nil {
return err1
}
if err2 != nil {
return err2
}
return err3
} }
// createWithUserspace Creates a new Wireguard interface, using wireguard-go userspace implementation // createWithUserspace Creates a new Wireguard interface, using wireguard-go userspace implementation
@ -69,26 +92,36 @@ func (c *tunDevice) createWithUserspace() (NetInterface, error) {
} }
// We need to create a wireguard-go device and listen to configuration requests // We need to create a wireguard-go device and listen to configuration requests
tunDevice := device.NewDevice(tunIface, conn.NewDefaultBind(), device.NewLogger(device.LogLevelSilent, "[wiretrustee] ")) tunDev := device.NewDevice(tunIface, c.iceBind, device.NewLogger(device.LogLevelSilent, "[netbird] "))
err = tunDevice.Up() err = tunDev.Up()
if err != nil { if err != nil {
return tunIface, err _ = tunIface.Close()
return nil, err
} }
// todo: after this line in case of error close the tunSock c.uapi, err = c.getUAPI(c.name)
uapi, err := c.getUAPI(c.name)
if err != nil { if err != nil {
return tunIface, err _ = tunIface.Close()
return nil, err
} }
go func() { go func() {
for { for {
uapiConn, uapiErr := uapi.Accept() select {
case <-c.close:
log.Debugf("exit uapi.Accept()")
return
default:
}
uapiConn, uapiErr := c.uapi.Accept()
if uapiErr != nil { if uapiErr != nil {
log.Traceln("uapi Accept failed with error: ", uapiErr) log.Traceln("uapi Accept failed with error: ", uapiErr)
continue continue
} }
go tunDevice.IpcHandle(uapiConn) go func() {
tunDev.IpcHandle(uapiConn)
log.Debugf("exit tunDevice.IpcHandle")
}()
} }
}() }()

View File

@ -4,24 +4,39 @@ import (
"fmt" "fmt"
"net" "net"
"github.com/pion/transport/v2"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/windows/driver" "golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"github.com/netbirdio/netbird/iface/bind"
) )
type tunDevice struct { type tunDevice struct {
name string name string
address WGAddress address WGAddress
netInterface NetInterface netInterface NetInterface
iceBind *bind.ICEBind
mtu int
uapi net.Listener
close chan struct{}
} }
func newTunDevice(name string, address WGAddress, mtu int) *tunDevice { func newTunDevice(name string, address WGAddress, mtu int, transportNet transport.Net) *tunDevice {
return &tunDevice{name: name, address: address} return &tunDevice{
name: name,
address: address,
mtu: mtu,
iceBind: bind.NewICEBind(transportNet),
close: make(chan struct{}),
}
} }
func (c *tunDevice) Create() error { func (c *tunDevice) Create() error {
var err error var err error
c.netInterface, err = c.createAdapter() c.netInterface, err = c.createWithUserspace()
if err != nil { if err != nil {
return err return err
} }
@ -29,6 +44,51 @@ func (c *tunDevice) Create() error {
return c.assignAddr() return c.assignAddr()
} }
// createWithUserspace Creates a new Wireguard interface, using wireguard-go userspace implementation
func (c *tunDevice) createWithUserspace() (NetInterface, error) {
tunIface, err := tun.CreateTUN(c.name, c.mtu)
if err != nil {
return nil, err
}
// We need to create a wireguard-go device and listen to configuration requests
tunDev := device.NewDevice(tunIface, c.iceBind, device.NewLogger(device.LogLevelSilent, "[netbird] "))
err = tunDev.Up()
if err != nil {
_ = tunIface.Close()
return nil, err
}
c.uapi, err = c.getUAPI(c.name)
if err != nil {
_ = tunIface.Close()
return nil, err
}
go func() {
for {
select {
case <-c.close:
log.Debugf("exit uapi.Accept()")
return
default:
}
uapiConn, uapiErr := c.uapi.Accept()
if uapiErr != nil {
log.Traceln("uapi Accept failed with error: ", uapiErr)
continue
}
go func() {
tunDev.IpcHandle(uapiConn)
log.Debugf("exit tunDevice.IpcHandle")
}()
}
}()
log.Debugln("UAPI listener started")
return tunIface, nil
}
func (c *tunDevice) UpdateAddr(address WGAddress) error { func (c *tunDevice) UpdateAddr(address WGAddress) error {
c.address = address c.address = address
return c.assignAddr() return c.assignAddr()
@ -43,19 +103,33 @@ func (c *tunDevice) DeviceName() string {
} }
func (c *tunDevice) Close() error { func (c *tunDevice) Close() error {
if c.netInterface == nil { select {
return nil case c.close <- struct{}{}:
default:
} }
return c.netInterface.Close() var err1, err2 error
if c.netInterface != nil {
err1 = c.netInterface.Close()
}
if c.uapi != nil {
err2 = c.uapi.Close()
}
if err1 != nil {
return err1
}
return err2
} }
func (c *tunDevice) getInterfaceGUIDString() (string, error) { func (c *tunDevice) getInterfaceGUIDString() (string, error) {
if c.netInterface == nil { if c.netInterface == nil {
return "", fmt.Errorf("interface has not been initialized yet") return "", fmt.Errorf("interface has not been initialized yet")
} }
windowsDevice := c.netInterface.(*driver.Adapter) windowsDevice := c.netInterface.(*tun.NativeTun)
luid := windowsDevice.LUID() luid := winipcfg.LUID(windowsDevice.LUID())
guid, err := luid.GUID() guid, err := luid.GUID()
if err != nil { if err != nil {
return "", err return "", err
@ -63,31 +137,15 @@ func (c *tunDevice) getInterfaceGUIDString() (string, error) {
return guid.String(), nil return guid.String(), nil
} }
func (c *tunDevice) createAdapter() (NetInterface, error) {
WintunStaticRequestedGUID, _ := windows.GenerateGUID()
adapter, err := driver.CreateAdapter(c.name, "WireGuard", &WintunStaticRequestedGUID)
if err != nil {
err = fmt.Errorf("error creating adapter: %w", err)
return nil, err
}
err = adapter.SetAdapterState(driver.AdapterStateUp)
if err != nil {
return adapter, err
}
state, _ := adapter.LUID().GUID()
log.Debugln("device guid: ", state.String())
return adapter, nil
}
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided // assignAddr Adds IP address to the tunnel interface and network route based on the range provided
func (c *tunDevice) assignAddr() error { func (c *tunDevice) assignAddr() error {
luid := c.netInterface.(*driver.Adapter).LUID() tunDev := c.netInterface.(*tun.NativeTun)
luid := winipcfg.LUID(tunDev.LUID())
log.Debugf("adding address %s to interface: %s", c.address.IP, c.name) log.Debugf("adding address %s to interface: %s", c.address.IP, c.name)
err := luid.SetIPAddresses([]net.IPNet{{c.address.IP, c.address.Network.Mask}}) return luid.SetIPAddresses([]net.IPNet{{c.address.IP, c.address.Network.Mask}})
if err != nil { }
return err
} // getUAPI returns a Listener
func (c *tunDevice) getUAPI(iface string) (net.Listener, error) {
return nil return ipc.UAPIListen(iface)
} }