diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 67d411df5..5f67a6ece 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -12,6 +12,7 @@ import ( "github.com/golang/mock/gomock" log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/internal/stdnet" @@ -250,11 +251,12 @@ func TestUpdateDNSServer(t *testing.T) { for n, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { + privKey, _ := wgtypes.GenerateKey() newNet, err := stdnet.NewNet(nil) if err != nil { t.Fatal(err) } - wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU, nil, newNet) + wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil) if err != nil { t.Fatal(err) } @@ -331,7 +333,8 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { return } - wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", iface.DefaultMTU, nil, newNet) + privKey, _ := wgtypes.GeneratePrivateKey() + wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil) if err != nil { t.Errorf("build interface wireguard: %v", err) return @@ -782,7 +785,8 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) { return nil, err } - wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", iface.DefaultMTU, nil, newNet) + privKey, _ := wgtypes.GeneratePrivateKey() + wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil) if err != nil { t.Fatalf("build interface wireguard: %v", err) return nil, err diff --git a/client/internal/engine.go b/client/internal/engine.go index d811ad48c..3c421aabd 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -3,7 +3,6 @@ package internal import ( "context" "fmt" - "io" "math/rand" "net" "net/netip" @@ -32,7 +31,6 @@ import ( mgm "github.com/netbirdio/netbird/management/client" mgmProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/route" - "github.com/netbirdio/netbird/sharedsock" signal "github.com/netbirdio/netbird/signal/client" sProto "github.com/netbirdio/netbird/signal/proto" "github.com/netbirdio/netbird/util" @@ -107,8 +105,7 @@ type Engine struct { wgInterface *iface.WGIface wgProxyFactory *wgproxy.Factory - udpMux *bind.UniversalUDPMuxDefault - udpMuxConn io.Closer + udpMux *bind.UniversalUDPMuxDefault // networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service networkSerial uint64 @@ -181,66 +178,26 @@ func (e *Engine) Start() error { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() - wgIFaceName := e.config.WgIfaceName - wgAddr := e.config.WgAddr - myPrivateKey := e.config.WgPrivateKey - var err error - transportNet, err := e.newStdNet() + wgIface, err := e.newWgIface() if err != nil { - log.Errorf("failed to create pion's stdnet: %s", err) - } - - e.wgInterface, err = iface.NewWGIFace(wgIFaceName, wgAddr, iface.DefaultMTU, e.mobileDep.TunAdapter, transportNet) - if err != nil { - log.Errorf("failed creating wireguard interface instance %s: [%s]", wgIFaceName, err.Error()) + log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err.Error()) return err } + e.wgInterface = wgIface - var routes []*route.Route - - switch runtime.GOOS { - case "android": - var dnsConfig *nbdns.Config - routes, dnsConfig, err = e.readInitialSettings() - if err != nil { - return err - } - if e.dnsServer == nil { - e.dnsServer = dns.NewDefaultServerPermanentUpstream(e.ctx, e.wgInterface, e.mobileDep.HostDNSAddresses, *dnsConfig, e.mobileDep.NetworkChangeListener) - go e.mobileDep.DnsReadyListener.OnReady() - } - case "ios": - if e.dnsServer == nil { - e.dnsServer = dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager) - } - default: - if e.dnsServer == nil { - e.dnsServer, err = dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress) - if err != nil { - e.close() - return err - } - } + initialRoutes, dnsServer, err := e.newDnsServer() + if err != nil { + e.close() + return err } + e.dnsServer = dnsServer - e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, routes) + e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes) e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) - switch runtime.GOOS { - case "android": - err = e.wgInterface.CreateOnAndroid(iface.MobileIFaceArguments{ - Routes: e.routeManager.InitialRouteRange(), - Dns: e.dnsServer.DnsIP(), - SearchDomains: e.dnsServer.SearchDomains(), - }) - case "ios": - e.mobileDep.NetworkChangeListener.SetInterfaceIP(wgAddr) - err = e.wgInterface.CreateOniOS(e.mobileDep.FileDescriptor) - default: - err = e.wgInterface.Create() - } + err = e.wgInterfaceCreate() if err != nil { - log.Errorf("failed creating tunnel interface %s: [%s]", wgIFaceName, err.Error()) + log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error()) e.close() return err } @@ -258,33 +215,13 @@ func (e *Engine) Start() error { } } - err = e.wgInterface.Configure(myPrivateKey.String(), e.config.WgPort) + e.udpMux, err = e.wgInterface.Up() if err != nil { - log.Errorf("failed configuring Wireguard interface [%s]: %s", wgIFaceName, err.Error()) + log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error()) e.close() return err } - if e.wgInterface.IsUserspaceBind() { - iceBind := e.wgInterface.GetBind() - udpMux, err := iceBind.GetICEMux() - if err != nil { - e.close() - return err - } - e.udpMux = udpMux - log.Infof("using userspace bind mode %s", udpMux.LocalAddr().String()) - } else { - rawSock, err := sharedsock.Listen(e.config.WgPort, sharedsock.NewIncomingSTUNFilter()) - if err != nil { - return err - } - mux := bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: rawSock, Net: transportNet}) - go mux.ReadFromConn(e.ctx) - e.udpMuxConn = rawSock - e.udpMux = mux - } - if e.firewall != nil { e.acl = acl.NewDefaultManager(e.firewall) } @@ -1042,18 +979,6 @@ func (e *Engine) close() { } } - if e.udpMux != nil { - if err := e.udpMux.Close(); err != nil { - log.Debugf("close udp mux: %v", err) - } - } - - if e.udpMuxConn != nil { - if err := e.udpMuxConn.Close(); err != nil { - log.Debugf("close udp mux connection: %v", err) - } - } - if !isNil(e.sshServer) { err := e.sshServer.Stop() if err != nil { @@ -1087,6 +1012,68 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) { return routes, &dnsCfg, nil } +func (e *Engine) newWgIface() (*iface.WGIface, error) { + transportNet, err := e.newStdNet() + if err != nil { + log.Errorf("failed to create pion's stdnet: %s", err) + } + + var mArgs *iface.MobileIFaceArguments + switch runtime.GOOS { + case "android": + mArgs = &iface.MobileIFaceArguments{ + TunAdapter: e.mobileDep.TunAdapter, + TunFd: int(e.mobileDep.FileDescriptor), + } + case "ios": + mArgs = &iface.MobileIFaceArguments{ + TunFd: int(e.mobileDep.FileDescriptor), + } + default: + } + + return iface.NewWGIFace(e.config.WgIfaceName, e.config.WgAddr, e.config.WgPort, e.config.WgPrivateKey.String(), iface.DefaultMTU, transportNet, mArgs) +} + +func (e *Engine) wgInterfaceCreate() (err error) { + switch runtime.GOOS { + case "android": + err = e.wgInterface.CreateOnAndroid(e.routeManager.InitialRouteRange(), e.dnsServer.DnsIP(), e.dnsServer.SearchDomains()) + case "ios": + e.mobileDep.NetworkChangeListener.SetInterfaceIP(e.config.WgAddr) + err = e.wgInterface.Create() + default: + err = e.wgInterface.Create() + } + return err +} + +func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) { + // due to tests where we are using a mocked version of the DNS server + if e.dnsServer != nil { + return nil, e.dnsServer, nil + } + switch runtime.GOOS { + case "android": + routes, dnsConfig, err := e.readInitialSettings() + if err != nil { + return nil, nil, err + } + dnsServer := dns.NewDefaultServerPermanentUpstream(e.ctx, e.wgInterface, e.mobileDep.HostDNSAddresses, *dnsConfig, e.mobileDep.NetworkChangeListener) + go e.mobileDep.DnsReadyListener.OnReady() + return routes, dnsServer, nil + case "ios": + dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager) + return nil, dnsServer, nil + default: + dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress) + if err != nil { + return nil, nil, err + } + return nil, dnsServer, nil + } +} + func findIPFromInterfaceName(ifaceName string) (net.IP, error) { iface, err := net.InterfaceByName(ifaceName) if err != nil { diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 2de9b29f0..5dfc171a6 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -213,7 +213,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { if err != nil { t.Fatal(err) } - engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU, nil, newNet) + engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil) if err != nil { t.Fatal(err) } @@ -567,7 +567,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { if err != nil { t.Fatal(err) } - engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, newNet) + engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil) assert.NoError(t, err, "shouldn't return error") input := struct { inputSerial uint64 @@ -736,7 +736,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { if err != nil { t.Fatal(err) } - engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, newNet) + engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, 33100, key.String(), iface.DefaultMTU, newNet, nil) assert.NoError(t, err, "shouldn't return error") mockRouteManager := &routemanager.MockManager{ diff --git a/client/internal/mobile_dependency.go b/client/internal/mobile_dependency.go index 1a2a4c2b2..2355c67c3 100644 --- a/client/internal/mobile_dependency.go +++ b/client/internal/mobile_dependency.go @@ -9,11 +9,14 @@ import ( // MobileDependency collect all dependencies for mobile platform type MobileDependency struct { + // Android only TunAdapter iface.TunAdapter IFaceDiscover stdnet.ExternalIFaceDiscover NetworkChangeListener listener.NetworkChangeListener HostDNSAddresses []string DnsReadyListener dns.ReadyListener - DnsManager dns.IosDnsManager - FileDescriptor int32 + + // iOS only + DnsManager dns.IosDnsManager + FileDescriptor int32 } diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 1aa58c16b..2e5cf6649 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/pion/transport/v3/stdnet" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/stretchr/testify/require" @@ -399,12 +400,12 @@ func TestManagerUpdateRoutes(t *testing.T) { for n, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - + peerPrivateKey, _ := wgtypes.GeneratePrivateKey() newNet, err := stdnet.NewNet() if err != nil { t.Fatal(err) } - wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU, nil, newNet) + wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) require.NoError(t, err, "should create testing WGIface interface") defer wgInterface.Close() diff --git a/client/internal/routemanager/systemops_nonandroid_test.go b/client/internal/routemanager/systemops_nonandroid_test.go index f43a88eec..6f32d9634 100644 --- a/client/internal/routemanager/systemops_nonandroid_test.go +++ b/client/internal/routemanager/systemops_nonandroid_test.go @@ -14,6 +14,7 @@ import ( "github.com/pion/transport/v3/stdnet" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/iface" ) @@ -41,11 +42,12 @@ func TestAddRemoveRoutes(t *testing.T) { for n, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { + peerPrivateKey, _ := wgtypes.GeneratePrivateKey() newNet, err := stdnet.NewNet() if err != nil { t.Fatal(err) } - wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU, nil, newNet) + wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) require.NoError(t, err, "should create testing WGIface interface") defer wgInterface.Close() @@ -175,11 +177,12 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { log.SetOutput(os.Stderr) }() t.Run(testCase.name, func(t *testing.T) { + peerPrivateKey, _ := wgtypes.GeneratePrivateKey() newNet, err := stdnet.NewNet() if err != nil { t.Fatal(err) } - wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU, nil, newNet) + wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) require.NoError(t, err, "should create testing WGIface interface") defer wgInterface.Close() diff --git a/go.mod b/go.mod index 0ac35b490..1c0cfc0c0 100644 --- a/go.mod +++ b/go.mod @@ -57,12 +57,12 @@ require ( github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pion/logging v0.2.2 github.com/pion/stun/v2 v2.0.0 - github.com/pion/transport/v2 v2.2.1 github.com/pion/transport/v3 v3.0.1 github.com/prometheus/client_golang v1.14.0 github.com/rs/xid v1.3.0 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 github.com/stretchr/testify v1.8.4 + github.com/things-go/go-socks5 v0.0.4 github.com/yusufpapurcu/wmi v1.2.3 go.opentelemetry.io/otel v1.11.1 go.opentelemetry.io/otel/exporters/prometheus v0.33.0 @@ -110,6 +110,7 @@ require ( github.com/go-stack/stack v1.8.0 // indirect github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect + github.com/google/btree v1.0.1 // indirect github.com/google/s2a-go v0.1.4 // indirect github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect github.com/googleapis/gax-go/v2 v2.10.0 // indirect @@ -128,6 +129,7 @@ require ( github.com/pion/dtls/v2 v2.2.7 // indirect github.com/pion/mdns v0.0.9 // indirect github.com/pion/randutil v0.1.0 // indirect + github.com/pion/transport/v2 v2.2.1 // indirect github.com/pion/turn/v3 v3.0.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.3.0 // indirect @@ -154,6 +156,7 @@ require ( gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect + gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0 // indirect honnef.co/go/tools v0.2.2 // indirect k8s.io/apimachinery v0.23.5 // indirect ) diff --git a/go.sum b/go.sum index 62ae3ed4a..26322b1ad 100644 --- a/go.sum +++ b/go.sum @@ -665,6 +665,8 @@ github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXl github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww= +github.com/things-go/go-socks5 v0.0.4 h1:jMQjIc+qhD4z9cITOMnBiwo9dDmpGuXmBlkRFrl/qD0= +github.com/things-go/go-socks5 v0.0.4/go.mod h1:sh4K6WHrmHZpjxLTCHyYtXYH8OUuD+yZun41NomR1IQ= github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/tv42/httpunix v0.0.0-20191220191345-2ba4b9c3382c/go.mod h1:hzIxponao9Kjc7aWznkXaL4U4TWaDSs8zcsY4Ka08nM= github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc= diff --git a/iface/iface.go b/iface/iface.go index 55891d047..0e6da2547 100644 --- a/iface/iface.go +++ b/iface/iface.go @@ -6,10 +6,10 @@ import ( "sync" "time" - "github.com/netbirdio/netbird/iface/bind" - log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/iface/bind" ) const ( @@ -19,11 +19,12 @@ const ( // WGIface represents a interface instance type WGIface struct { - tun *tunDevice - configurer wGConfigurer - mu sync.Mutex + tun wgTunDevice userspaceBind bool - filter PacketFilter + mu sync.Mutex + + configurer wgConfigurer + filter PacketFilter } // IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind @@ -31,11 +32,6 @@ func (w *WGIface) IsUserspaceBind() bool { return w.userspaceBind } -// GetBind returns a userspace implementation of WireGuard Bind interface -func (w *WGIface) GetBind() *bind.ICEBind { - return w.tun.iceBind -} - // Name returns the interface name func (w *WGIface) Name() string { return w.tun.DeviceName() @@ -46,13 +42,13 @@ func (w *WGIface) Address() WGAddress { return w.tun.WgAddress() } -// Configure configures a Wireguard interface +// Up configures a Wireguard interface // The interface must exist before calling this method (e.g. call interface.Create() before) -func (w *WGIface) Configure(privateKey string, port int) error { +func (w *WGIface) Up() (*bind.UniversalUDPMuxDefault, error) { w.mu.Lock() defer w.mu.Unlock() - log.Debugf("configuring Wireguard interface %s", w.tun.DeviceName()) - return w.configurer.configureInterface(privateKey, port) + + return w.tun.Up() } // UpdateAddr updates address of the interface @@ -74,7 +70,7 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.D w.mu.Lock() defer w.mu.Unlock() - log.Debugf("updating interface %s peer %s, endpoint %s ", w.tun.DeviceName(), peerKey, endpoint) + log.Debugf("updating interface %s peer %s, endpoint %s", w.tun.DeviceName(), peerKey, endpoint) return w.configurer.updatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) } @@ -117,14 +113,14 @@ func (w *WGIface) SetFilter(filter PacketFilter) error { w.mu.Lock() defer w.mu.Unlock() - if w.tun.wrapper == nil { + if w.tun.Wrapper() == nil { return fmt.Errorf("userspace packet filtering not handled on this device") } w.filter = filter - w.filter.SetNetwork(w.tun.address.Network) + w.filter.SetNetwork(w.tun.WgAddress().Network) - w.tun.wrapper.SetFilter(filter) + w.tun.Wrapper().SetFilter(filter) return nil } @@ -141,5 +137,5 @@ func (w *WGIface) GetDevice() *DeviceWrapper { w.mu.Lock() defer w.mu.Unlock() - return w.tun.wrapper + return w.tun.Wrapper() } diff --git a/iface/iface_android.go b/iface/iface_android.go index 4803abfe3..d1876e495 100644 --- a/iface/iface_android.go +++ b/iface/iface_android.go @@ -2,47 +2,39 @@ package iface import ( "fmt" - "sync" "github.com/pion/transport/v3" ) // NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(ifaceName string, address string, mtu int, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) { - wgIFace := &WGIface{ - mu: sync.Mutex{}, - } - +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments) (*WGIface, error) { wgAddress, err := parseWGAddress(address) if err != nil { - return wgIFace, err + return nil, err } - tun := newTunDevice(wgAddress, mtu, tunAdapter, transportNet) - wgIFace.tun = tun - - wgIFace.configurer = newWGConfigurer(tun) - - wgIFace.userspaceBind = !WireGuardModuleIsLoaded() - + wgIFace := &WGIface{ + tun: newTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter), + userspaceBind: true, + } return wgIFace, nil } // CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up. // Will reuse an existing one. -func (w *WGIface) CreateOnAndroid(mIFaceArgs MobileIFaceArguments) error { +func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error { w.mu.Lock() defer w.mu.Unlock() - return w.tun.Create(mIFaceArgs) -} -// CreateOniOS creates a new Wireguard interface, sets a given IP and brings it up. -// Will reuse an existing one. -func (w *WGIface) CreateOniOS(tunFd int32) error { - return fmt.Errorf("this function has not implemented on mobile") + cfgr, err := w.tun.Create(routes, dns, searchDomains) + if err != nil { + return err + } + w.configurer = cfgr + return nil } // Create this function make sense on mobile only func (w *WGIface) Create() error { - return fmt.Errorf("this function has not implemented on mobile") + return fmt.Errorf("this function has not implemented on this platform") } diff --git a/iface/iface_create.go b/iface/iface_create.go new file mode 100644 index 000000000..86c3f320f --- /dev/null +++ b/iface/iface_create.go @@ -0,0 +1,20 @@ +//go:build !android +// +build !android + +package iface + +// Create creates a new Wireguard interface, sets a given IP and brings it up. +// Will reuse an existing one. +// this function is different on Android +func (w *WGIface) Create() error { + w.mu.Lock() + defer w.mu.Unlock() + + cfgr, err := w.tun.Create() + if err != nil { + return err + } + + w.configurer = cfgr + return nil +} diff --git a/iface/iface_darwin.go b/iface/iface_darwin.go new file mode 100644 index 000000000..4d62c6af6 --- /dev/null +++ b/iface/iface_darwin.go @@ -0,0 +1,38 @@ +//go:build !ios +// +build !ios + +package iface + +import ( + "fmt" + + "github.com/pion/transport/v3" + + "github.com/netbirdio/netbird/iface/netstack" +) + +// NewWGIFace Creates a new WireGuard interface instance +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments) (*WGIface, error) { + wgAddress, err := parseWGAddress(address) + if err != nil { + return nil, err + } + + wgIFace := &WGIface{ + userspaceBind: true, + } + + if netstack.IsEnabled() { + wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr()) + return wgIFace, nil + } + + wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet) + + return wgIFace, nil +} + +// CreateOnAndroid this function make sense on mobile only +func (w *WGIface) CreateOnAndroid([]string, string, []string) error { + return fmt.Errorf("this function has not implemented on this platform") +} diff --git a/iface/iface_ios.go b/iface/iface_ios.go index dd68d7792..b22e1a6a4 100644 --- a/iface/iface_ios.go +++ b/iface/iface_ios.go @@ -5,47 +5,25 @@ package iface import ( "fmt" - "sync" - "github.com/pion/transport/v2" + "github.com/pion/transport/v3" ) // NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(ifaceName string, address string, mtu int, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) { - wgIFace := &WGIface{ - mu: sync.Mutex{}, - } - +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments) (*WGIface, error) { wgAddress, err := parseWGAddress(address) if err != nil { - return wgIFace, err + return nil, err + } + wgIFace := &WGIface{ + tun: newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, transportNet, args.TunFd), + userspaceBind: true, } - - tun := newTunDevice(ifaceName, wgAddress, mtu, tunAdapter, transportNet) - wgIFace.tun = tun - - wgIFace.configurer = newWGConfigurer(tun) - - wgIFace.userspaceBind = !WireGuardModuleIsLoaded() - return wgIFace, nil } -// CreateOniOS creates a new Wireguard interface, sets a given IP and brings it up. -// Will reuse an existing one. -func (w *WGIface) CreateOniOS(tunFd int32) error { - w.mu.Lock() - defer w.mu.Unlock() - return w.tun.Create(tunFd) -} - // CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up. // Will reuse an existing one. -func (w *WGIface) CreateOnAndroid(mIFaceArgs MobileIFaceArguments) error { - return fmt.Errorf("this function has not implemented on mobile") -} - -// Create this function make sense on mobile only -func (w *WGIface) Create() error { - return fmt.Errorf("this function has not implemented on mobile") +func (w *WGIface) CreateOnAndroid([]string, string, []string) error { + return fmt.Errorf("this function has not implemented on this platform") } diff --git a/iface/iface_linux.go b/iface/iface_linux.go new file mode 100644 index 000000000..73606a25c --- /dev/null +++ b/iface/iface_linux.go @@ -0,0 +1,48 @@ +//go:build !android +// +build !android + +package iface + +import ( + "fmt" + + "github.com/pion/transport/v3" + + "github.com/netbirdio/netbird/iface/netstack" +) + +// NewWGIFace Creates a new WireGuard interface instance +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments) (*WGIface, error) { + wgAddress, err := parseWGAddress(address) + if err != nil { + return nil, err + } + + wgIFace := &WGIface{} + + // move the kernel/usp/netstack preference evaluation to upper layer + if netstack.IsEnabled() { + wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr()) + wgIFace.userspaceBind = true + return wgIFace, nil + } + + if WireGuardModuleIsLoaded() { + wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet) + wgIFace.userspaceBind = false + return wgIFace, nil + } + + if !tunModuleIsLoaded() { + return nil, fmt.Errorf("couldn't check or load tun module") + } + wgIFace.tun = newTunUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet) + wgIFace.userspaceBind = true + return wgIFace, nil + +} + +// CreateOnAndroid this function make sense on mobile only +func (w *WGIface) CreateOnAndroid([]string, string, []string) error { + return fmt.Errorf("this function has not implemented on this platform") +} diff --git a/iface/iface_nonandroid.go b/iface/iface_nonandroid.go deleted file mode 100644 index f6862f590..000000000 --- a/iface/iface_nonandroid.go +++ /dev/null @@ -1,47 +0,0 @@ -//go:build !android && !ios -// +build !android,!ios - -package iface - -import ( - "fmt" - "sync" - - "github.com/pion/transport/v3" -) - -// NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, mtu int, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) { - wgIFace := &WGIface{ - mu: sync.Mutex{}, - } - - wgAddress, err := parseWGAddress(address) - if err != nil { - return wgIFace, err - } - - wgIFace.tun = newTunDevice(iFaceName, wgAddress, mtu, transportNet) - - wgIFace.configurer = newWGConfigurer(iFaceName) - wgIFace.userspaceBind = !WireGuardModuleIsLoaded() - return wgIFace, nil -} - -// CreateOnAndroid this function make sense on mobile only -func (w *WGIface) CreateOnAndroid(mIFaceArgs MobileIFaceArguments) error { - return fmt.Errorf("this function has not implemented on non mobile") -} - -// CreateOniOS this function make sense on mobile only -func (w *WGIface) CreateOniOS(tunFd int32) error { - return fmt.Errorf("this function has not implemented on non mobile") -} - -// Create creates a new Wireguard interface, sets a given IP and brings it up. -// Will reuse an existing one. -func (w *WGIface) Create() error { - w.mu.Lock() - defer w.mu.Unlock() - return w.tun.Create() -} diff --git a/iface/iface_test.go b/iface/iface_test.go index 7debbe4fc..3fc250637 100644 --- a/iface/iface_test.go +++ b/iface/iface_test.go @@ -34,12 +34,13 @@ func init() { func TestWGIface_UpdateAddr(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4) addr := "100.64.0.1/8" + wgPort := 33100 newNet, err := stdnet.NewNet() if err != nil { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, addr, DefaultMTU, nil, newNet) + iface, err := NewWGIFace(ifaceName, addr, wgPort, key, DefaultMTU, newNet, nil) if err != nil { t.Fatal(err) } @@ -52,12 +53,10 @@ func TestWGIface_UpdateAddr(t *testing.T) { if err != nil { t.Error(err) } + }() - port, err := getListenPortByName(ifaceName) - if err != nil { - t.Fatal(err) - } - err = iface.Configure(key, port) + + _, err = iface.Up() if err != nil { t.Fatal(err) } @@ -103,7 +102,7 @@ func Test_CreateInterface(t *testing.T) { if err != nil { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet) + iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil) if err != nil { t.Fatal(err) } @@ -132,11 +131,13 @@ func Test_CreateInterface(t *testing.T) { func Test_Close(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2) wgIP := "10.99.99.2/32" + wgPort := 33100 newNet, err := stdnet.NewNet() if err != nil { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet) + + iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil) if err != nil { t.Fatal(err) } @@ -164,11 +165,12 @@ func Test_Close(t *testing.T) { func Test_ConfigureInterface(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3) wgIP := "10.99.99.5/30" + wgPort := 33100 newNet, err := stdnet.NewNet() if err != nil { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet) + iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil) if err != nil { t.Fatal(err) } @@ -183,11 +185,7 @@ func Test_ConfigureInterface(t *testing.T) { } }() - port, err := getListenPortByName(ifaceName) - if err != nil { - t.Fatal(err) - } - err = iface.Configure(key, port) + _, err = iface.Up() if err != nil { t.Fatal(err) } @@ -219,7 +217,8 @@ func Test_UpdatePeer(t *testing.T) { if err != nil { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet) + + iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil) if err != nil { t.Fatal(err) } @@ -233,11 +232,8 @@ func Test_UpdatePeer(t *testing.T) { t.Error(err) } }() - port, err := getListenPortByName(ifaceName) - if err != nil { - t.Fatal(err) - } - err = iface.Configure(key, port) + + _, err = iface.Up() if err != nil { t.Fatal(err) } @@ -251,7 +247,7 @@ func Test_UpdatePeer(t *testing.T) { if err != nil { t.Fatal(err) } - peer, err := iface.configurer.getPeer(ifaceName, peerPubKey) + peer, err := getPeer(ifaceName, peerPubKey) if err != nil { t.Fatal(err) } @@ -282,7 +278,8 @@ func Test_RemovePeer(t *testing.T) { if err != nil { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet) + + iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil) if err != nil { t.Fatal(err) } @@ -296,11 +293,8 @@ func Test_RemovePeer(t *testing.T) { t.Error(err) } }() - port, err := getListenPortByName(ifaceName) - if err != nil { - t.Fatal(err) - } - err = iface.Configure(key, port) + + _, err = iface.Up() if err != nil { t.Fatal(err) } @@ -315,7 +309,8 @@ func Test_RemovePeer(t *testing.T) { if err != nil { t.Fatal(err) } - _, err = iface.configurer.getPeer(ifaceName, peerPubKey) + + _, err = getPeer(ifaceName, peerPubKey) if err.Error() != "peer not found" { t.Fatal(err) } @@ -325,17 +320,20 @@ func Test_ConnectPeers(t *testing.T) { peer1ifaceName := fmt.Sprintf("utun%d", WgIntNumber+400) peer1wgIP := "10.99.99.17/30" peer1Key, _ := wgtypes.GeneratePrivateKey() + peer1wgPort := 33100 peer2ifaceName := "utun500" peer2wgIP := "10.99.99.18/30" peer2Key, _ := wgtypes.GeneratePrivateKey() + peer2wgPort := 33200 keepAlive := 1 * time.Second newNet, err := stdnet.NewNet() if err != nil { t.Fatal(err) } - iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, DefaultMTU, nil, newNet) + + iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, peer1wgPort, peer1Key.String(), DefaultMTU, newNet, nil) if err != nil { t.Fatal(err) } @@ -343,11 +341,13 @@ func Test_ConnectPeers(t *testing.T) { if err != nil { t.Fatal(err) } - peer1Port, err := getListenPortByName(peer1ifaceName) + + _, err = iface1.Up() if err != nil { t.Fatal(err) } - peer1endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", peer1Port)) + + peer1endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", peer1wgPort)) if err != nil { t.Fatal(err) } @@ -356,7 +356,7 @@ func Test_ConnectPeers(t *testing.T) { if err != nil { t.Fatal(err) } - iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, DefaultMTU, nil, newNet) + iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, peer2wgPort, peer2Key.String(), DefaultMTU, newNet, nil) if err != nil { t.Fatal(err) } @@ -364,11 +364,13 @@ func Test_ConnectPeers(t *testing.T) { if err != nil { t.Fatal(err) } - peer2Port, err := getListenPortByName(peer2ifaceName) + + _, err = iface2.Up() if err != nil { t.Fatal(err) } - peer2endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", peer2Port)) + + peer2endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", peer2wgPort)) if err != nil { t.Fatal(err) } @@ -383,15 +385,6 @@ func Test_ConnectPeers(t *testing.T) { } }() - err = iface1.Configure(peer1Key.String(), peer1Port) - if err != nil { - t.Fatal(err) - } - err = iface2.Configure(peer2Key.String(), peer2Port) - if err != nil { - t.Fatal(err) - } - err = iface1.UpdatePeer(peer2Key.PublicKey().String(), peer2wgIP, keepAlive, peer2endpoint, nil) if err != nil { t.Fatal(err) @@ -403,13 +396,15 @@ func Test_ConnectPeers(t *testing.T) { // todo: investigate why in some tests execution we need 30s timeout := 30 * time.Second timeoutChannel := time.After(timeout) + for { select { case <-timeoutChannel: t.Fatalf("waiting for peer handshake timeout after %s", timeout.String()) default: } - peer, gpErr := iface1.configurer.getPeer(peer1ifaceName, peer2Key.PublicKey().String()) + + peer, gpErr := getPeer(peer1ifaceName, peer2Key.PublicKey().String()) if gpErr != nil { t.Fatal(gpErr) } @@ -421,17 +416,26 @@ func Test_ConnectPeers(t *testing.T) { } -func getListenPortByName(name string) (int, error) { +func getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) { wg, err := wgctrl.New() if err != nil { - return 0, err + return wgtypes.Peer{}, err } - defer wg.Close() + defer func() { + err = wg.Close() + if err != nil { + log.Errorf("got error while closing wgctl: %v", err) + } + }() - d, err := wg.Device(name) + wgDevice, err := wg.Device(ifaceName) if err != nil { - return 0, err + return wgtypes.Peer{}, err } - - return d.ListenPort, nil + for _, peer := range wgDevice.Peers { + if peer.PublicKey.String() == peerPubKey { + return peer, nil + } + } + return wgtypes.Peer{}, fmt.Errorf("peer not found") } diff --git a/iface/iface_windows.go b/iface/iface_windows.go index a67df296c..d3a16a52f 100644 --- a/iface/iface_windows.go +++ b/iface/iface_windows.go @@ -1,6 +1,39 @@ package iface +import ( + "fmt" + + "github.com/pion/transport/v3" + + "github.com/netbirdio/netbird/iface/netstack" +) + +// NewWGIFace Creates a new WireGuard interface instance +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments) (*WGIface, error) { + wgAddress, err := parseWGAddress(address) + if err != nil { + return nil, err + } + + wgIFace := &WGIface{ + userspaceBind: true, + } + + if netstack.IsEnabled() { + wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr()) + return wgIFace, nil + } + + wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet) + return wgIFace, nil +} + +// CreateOnAndroid this function make sense on mobile only +func (w *WGIface) CreateOnAndroid([]string, string, []string) error { + return fmt.Errorf("this function has not implemented on non mobile") +} + // GetInterfaceGUIDString returns an interface GUID. This is useful on Windows only func (w *WGIface) GetInterfaceGUIDString() (string, error) { - return w.tun.getInterfaceGUIDString() + return w.tun.(*tunDevice).getInterfaceGUIDString() } diff --git a/iface/ipc_parser_mobile.go b/iface/ipc_parser_mobile.go deleted file mode 100644 index 7d4af8139..000000000 --- a/iface/ipc_parser_mobile.go +++ /dev/null @@ -1,63 +0,0 @@ -//go:build android || ios -// +build android ios - -package iface - -import ( - "encoding/hex" - "fmt" - "strings" - - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" -) - -func toWgUserspaceString(wgCfg wgtypes.Config) string { - var sb strings.Builder - if wgCfg.PrivateKey != nil { - hexKey := hex.EncodeToString(wgCfg.PrivateKey[:]) - sb.WriteString(fmt.Sprintf("private_key=%s\n", hexKey)) - } - - if wgCfg.ListenPort != nil { - sb.WriteString(fmt.Sprintf("listen_port=%d\n", *wgCfg.ListenPort)) - } - - if wgCfg.ReplacePeers { - sb.WriteString("replace_peers=true\n") - } - - if wgCfg.FirewallMark != nil { - sb.WriteString(fmt.Sprintf("fwmark=%d\n", *wgCfg.FirewallMark)) - } - - for _, p := range wgCfg.Peers { - hexKey := hex.EncodeToString(p.PublicKey[:]) - sb.WriteString(fmt.Sprintf("public_key=%s\n", hexKey)) - - if p.PresharedKey != nil { - preSharedHexKey := hex.EncodeToString(p.PresharedKey[:]) - sb.WriteString(fmt.Sprintf("preshared_key=%s\n", preSharedHexKey)) - } - - if p.Remove { - sb.WriteString("remove=true") - } - - if p.ReplaceAllowedIPs { - sb.WriteString("replace_allowed_ips=true\n") - } - - for _, aip := range p.AllowedIPs { - sb.WriteString(fmt.Sprintf("allowed_ip=%s\n", aip.String())) - } - - if p.Endpoint != nil { - sb.WriteString(fmt.Sprintf("endpoint=%s\n", p.Endpoint.String())) - } - - if p.PersistentKeepaliveInterval != nil { - sb.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", int(p.PersistentKeepaliveInterval.Seconds()))) - } - } - return sb.String() -} diff --git a/iface/netstack/dialer.go b/iface/netstack/dialer.go new file mode 100644 index 000000000..a9ed6d45b --- /dev/null +++ b/iface/netstack/dialer.go @@ -0,0 +1,32 @@ +package netstack + +import ( + "context" + "net" + + log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/tun/netstack" +) + +type Dialer interface { + Dial(ctx context.Context, network, addr string) (net.Conn, error) +} + +type NSDialer struct { + net *netstack.Net +} + +func NewNSDialer(net *netstack.Net) *NSDialer { + return &NSDialer{ + net: net, + } +} + +func (d *NSDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) { + log.Debugf("dialing %s %s", network, addr) + conn, err := d.net.Dial(network, addr) + if err != nil { + log.Debugf("failed to deal connection: %s", err) + } + return conn, err +} diff --git a/iface/netstack/env.go b/iface/netstack/env.go new file mode 100644 index 000000000..c77e39fe0 --- /dev/null +++ b/iface/netstack/env.go @@ -0,0 +1,33 @@ +package netstack + +import ( + "fmt" + "os" + "strconv" + + log "github.com/sirupsen/logrus" +) + +// IsEnabled todo: move these function to cmd layer +func IsEnabled() bool { + return os.Getenv("NB_USE_NETSTACK_MODE") == "true" +} + +func ListenAddr() string { + sPort := os.Getenv("NB_SOCKS5_LISTENER_PORT") + port, err := strconv.Atoi(sPort) + if err != nil { + log.Warnf("invalid socks5 listener port, unable to convert it to int, falling back to default: %d", DefaultSocks5Port) + return listenAddr(DefaultSocks5Port) + } + if port < 1 || port > 65535 { + log.Warnf("invalid socks5 listener port, it should be in the range 1-65535, falling back to default: %d", DefaultSocks5Port) + return listenAddr(DefaultSocks5Port) + } + + return listenAddr(port) +} + +func listenAddr(port int) string { + return fmt.Sprintf("0.0.0.0:%d", port) +} diff --git a/iface/netstack/proxy.go b/iface/netstack/proxy.go new file mode 100644 index 000000000..a2120c642 --- /dev/null +++ b/iface/netstack/proxy.go @@ -0,0 +1,65 @@ +package netstack + +import ( + "net" + + "github.com/things-go/go-socks5" + + log "github.com/sirupsen/logrus" +) + +const ( + DefaultSocks5Port = 1080 +) + +// Proxy todo close server +type Proxy struct { + server *socks5.Server + + listener net.Listener + closed bool +} + +func NewSocks5(dialer Dialer) (*Proxy, error) { + server := socks5.NewServer( + socks5.WithDial(dialer.Dial), + ) + + return &Proxy{ + server: server, + }, nil +} + +func (s *Proxy) ListenAndServe(addr string) error { + listener, err := net.Listen("tcp", addr) + if err != nil { + log.Errorf("failed to create listener for socks5 proxy: %s", err) + return err + } + s.listener = listener + + for { + conn, err := listener.Accept() + if err != nil { + if s.closed { + return nil + } + return err + } + + go func() { + if err := s.server.ServeConn(conn); err != nil { + log.Errorf("failed to serve a connection: %s", err) + } + }() + } +} + +func (s *Proxy) Close() error { + if s.listener == nil { + return nil + } + + s.closed = true + return s.listener.Close() +} diff --git a/iface/netstack/tun.go b/iface/netstack/tun.go new file mode 100644 index 000000000..8c7c3a8ff --- /dev/null +++ b/iface/netstack/tun.go @@ -0,0 +1,74 @@ +package netstack + +import ( + "net/netip" + + log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/tun" + "golang.zx2c4.com/wireguard/tun/netstack" +) + +type NetStackTun struct { + address string + mtu int + listenAddress string + + proxy *Proxy + tundev tun.Device +} + +func NewNetStackTun(listenAddress string, address string, mtu int) *NetStackTun { + return &NetStackTun{ + address: address, + mtu: mtu, + listenAddress: listenAddress, + } +} + +func (t *NetStackTun) Create() (tun.Device, error) { + nsTunDev, tunNet, err := netstack.CreateNetTUN( + []netip.Addr{netip.MustParseAddr(t.address)}, + []netip.Addr{}, + t.mtu) + if err != nil { + return nil, err + } + t.tundev = nsTunDev + + dialer := NewNSDialer(tunNet) + t.proxy, err = NewSocks5(dialer) + if err != nil { + _ = t.tundev.Close() + return nil, err + } + + go func() { + err := t.proxy.ListenAndServe(t.listenAddress) + if err != nil { + log.Errorf("error in socks5 proxy serving: %s", err) + } + }() + + return nsTunDev, nil +} + +func (t *NetStackTun) Close() error { + var err error + if t.proxy != nil { + pErr := t.proxy.Close() + if pErr != nil { + log.Errorf("failed to close socks5 proxy: %s", pErr) + err = pErr + } + } + + if t.tundev != nil { + dErr := t.tundev.Close() + if dErr != nil { + log.Errorf("failed to close netstack tun device: %s", dErr) + err = dErr + } + } + + return err +} diff --git a/iface/tun.go b/iface/tun.go index ec8af4c32..b3c0f9d80 100644 --- a/iface/tun.go +++ b/iface/tun.go @@ -1,12 +1,18 @@ +//go:build !android +// +build !android + package iface -type MobileIFaceArguments struct { - Routes []string - Dns string - SearchDomains []string -} +import ( + "github.com/netbirdio/netbird/iface/bind" +) -// NetInterface represents a generic network tunnel interface -type NetInterface interface { +type wgTunDevice interface { + Create() (wgConfigurer, error) + Up() (*bind.UniversalUDPMuxDefault, error) + UpdateAddr(address WGAddress) error + WgAddress() WGAddress + DeviceName() string Close() error + Wrapper() *DeviceWrapper // todo eliminate this function } diff --git a/iface/tun_android.go b/iface/tun_android.go index 3600001ba..834b2cb42 100644 --- a/iface/tun_android.go +++ b/iface/tun_android.go @@ -15,42 +15,50 @@ import ( "github.com/netbirdio/netbird/iface/bind" ) -type tunDevice struct { +// ignore the wgTunDevice interface on Android because the creation of the tun device is different on this platform +type wgTunDevice struct { address WGAddress + port int + key string mtu int - tunAdapter TunAdapter iceBind *bind.ICEBind + tunAdapter TunAdapter - fd int - name string - device *device.Device - wrapper *DeviceWrapper + name string + device *device.Device + wrapper *DeviceWrapper + udpMux *bind.UniversalUDPMuxDefault + configurer wgConfigurer } -func newTunDevice(address WGAddress, mtu int, tunAdapter TunAdapter, transportNet transport.Net) *tunDevice { - return &tunDevice{ +func newTunDevice(address WGAddress, port int, key string, mtu int, transportNet transport.Net, tunAdapter TunAdapter) wgTunDevice { + return wgTunDevice{ address: address, + port: port, + key: key, mtu: mtu, - tunAdapter: tunAdapter, iceBind: bind.NewICEBind(transportNet), + tunAdapter: tunAdapter, } } -func (t *tunDevice) Create(mIFaceArgs MobileIFaceArguments) error { +func (t *wgTunDevice) Create(routes []string, dns string, searchDomains []string) (wgConfigurer, error) { log.Info("create tun interface") - var err error - routesString := t.routesToString(mIFaceArgs.Routes) - searchDomainsToString := t.searchDomainsToString(mIFaceArgs.SearchDomains) - t.fd, err = t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, mIFaceArgs.Dns, searchDomainsToString, routesString) + + routesString := routesToString(routes) + searchDomainsToString := searchDomainsToString(searchDomains) + + fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString) if err != nil { log.Errorf("failed to create Android interface: %s", err) - return err + return nil, err } - tunDevice, name, err := tun.CreateUnmonitoredTUNFromFD(t.fd) + tunDevice, name, err := tun.CreateUnmonitoredTUNFromFD(fd) if err != nil { - unix.Close(t.fd) - return err + _ = unix.Close(fd) + log.Errorf("failed to create Android interface: %s", err) + return nil, err } t.name = name t.wrapper = newDeviceWrapper(tunDevice) @@ -61,44 +69,72 @@ func (t *tunDevice) Create(mIFaceArgs MobileIFaceArguments) error { // this helps with support for the older NetBird clients that had a hardcoded direct mode // t.device.DisableSomeRoamingForBrokenMobileSemantics() - err = t.device.Up() + t.configurer = newWGUSPConfigurer(t.device, t.name) + err = t.configurer.configureInterface(t.key, t.port) if err != nil { t.device.Close() - return err + t.configurer.close() + return nil, err } - log.Debugf("device is ready to use: %s", name) - return nil + return t.configurer, nil +} +func (t *wgTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { + err := t.device.Up() + if err != nil { + return nil, err + } + + udpMux, err := t.iceBind.GetICEMux() + if err != nil { + return nil, err + } + t.udpMux = udpMux + log.Debugf("device is ready to use: %s", t.name) + return udpMux, nil } -func (t *tunDevice) Device() *device.Device { - return t.device -} - -func (t *tunDevice) DeviceName() string { - return t.name -} - -func (t *tunDevice) WgAddress() WGAddress { - return t.address -} - -func (t *tunDevice) UpdateAddr(addr WGAddress) error { +func (t *wgTunDevice) UpdateAddr(addr WGAddress) error { // todo implement return nil } -func (t *tunDevice) Close() (err error) { - if t.device != nil { - t.device.Close() +func (t *wgTunDevice) Close() error { + if t.configurer != nil { + t.configurer.close() } - return + if t.device != nil { + t.device.Close() + t.device = nil + } + + if t.udpMux != nil { + return t.udpMux.Close() + + } + return nil } -func (t *tunDevice) routesToString(routes []string) string { +func (t *wgTunDevice) Device() *device.Device { + return t.device +} + +func (t *wgTunDevice) DeviceName() string { + return t.name +} + +func (t *wgTunDevice) WgAddress() WGAddress { + return t.address +} + +func (t *wgTunDevice) Wrapper() *DeviceWrapper { + return t.wrapper +} + +func routesToString(routes []string) string { return strings.Join(routes, ";") } -func (t *tunDevice) searchDomainsToString(searchDomains []string) string { +func searchDomainsToString(searchDomains []string) string { return strings.Join(searchDomains, ";") } diff --git a/iface/tun_args.go b/iface/tun_args.go new file mode 100644 index 000000000..0eac2c4c0 --- /dev/null +++ b/iface/tun_args.go @@ -0,0 +1,6 @@ +package iface + +type MobileIFaceArguments struct { + TunAdapter TunAdapter // only for Android + TunFd int // only for iOS +} diff --git a/iface/tun_darwin.go b/iface/tun_darwin.go index 6e917e374..bac14986f 100644 --- a/iface/tun_darwin.go +++ b/iface/tun_darwin.go @@ -6,32 +6,129 @@ package iface import ( "os/exec" + "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun" + + "github.com/netbirdio/netbird/iface/bind" ) -func (c *tunDevice) Create() error { - var err error - c.netInterface, err = c.createWithUserspace() +type tunDevice struct { + name string + address WGAddress + port int + key string + mtu int + iceBind *bind.ICEBind + + device *device.Device + wrapper *DeviceWrapper + udpMux *bind.UniversalUDPMuxDefault + configurer wgConfigurer +} + +func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net) wgTunDevice { + return &tunDevice{ + name: name, + address: address, + port: port, + key: key, + mtu: mtu, + iceBind: bind.NewICEBind(transportNet), + } +} + +func (t *tunDevice) Create() (wgConfigurer, error) { + tunDevice, err := tun.CreateTUN(t.name, t.mtu) if err != nil { - return err + return nil, err + } + t.wrapper = newDeviceWrapper(tunDevice) + + // We need to create a wireguard-go device and listen to configuration requests + t.device = device.NewDevice( + t.wrapper, + t.iceBind, + device.NewLogger(device.LogLevelSilent, "[netbird] "), + ) + + err = t.assignAddr() + if err != nil { + t.device.Close() + return nil, err } - return c.assignAddr() + t.configurer = newWGUSPConfigurer(t.device, t.name) + err = t.configurer.configureInterface(t.key, t.port) + if err != nil { + t.device.Close() + t.configurer.close() + return nil, err + } + return t.configurer, nil +} + +func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { + err := t.device.Up() + if err != nil { + return nil, err + } + + udpMux, err := t.iceBind.GetICEMux() + if err != nil { + return nil, err + } + t.udpMux = udpMux + log.Debugf("device is ready to use: %s", t.name) + return udpMux, nil +} + +func (t *tunDevice) UpdateAddr(address WGAddress) error { + t.address = address + return t.assignAddr() +} + +func (t *tunDevice) Close() error { + if t.configurer != nil { + t.configurer.close() + } + + if t.device != nil { + t.device.Close() + t.device = nil + } + + if t.udpMux != nil { + return t.udpMux.Close() + } + return nil +} + +func (t *tunDevice) WgAddress() WGAddress { + return t.address +} + +func (t *tunDevice) DeviceName() string { + return t.name +} + +func (t *tunDevice) Wrapper() *DeviceWrapper { + return t.wrapper } // assignAddr Adds IP address to the tunnel interface and network route based on the range provided -func (c *tunDevice) assignAddr() error { - cmd := exec.Command("ifconfig", c.name, "inet", c.address.IP.String(), c.address.IP.String()) +func (t *tunDevice) assignAddr() error { + cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String()) if out, err := cmd.CombinedOutput(); err != nil { log.Infof(`adding address command "%v" failed with output %s and error: `, cmd.String(), out) return err } - routeCmd := exec.Command("route", "add", "-net", c.address.Network.String(), "-interface", c.name) + routeCmd := exec.Command("route", "add", "-net", t.address.Network.String(), "-interface", t.name) if out, err := routeCmd.CombinedOutput(); err != nil { log.Printf(`adding route command "%v" failed with output %s and error: `, routeCmd.String(), out) return err } - return nil } diff --git a/iface/tun_ios.go b/iface/tun_ios.go index 7a9ce5622..ea980818d 100644 --- a/iface/tun_ios.go +++ b/iface/tun_ios.go @@ -6,7 +6,7 @@ package iface import ( "os" - "github.com/pion/transport/v2" + "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" "golang.org/x/sys/unix" "golang.zx2c4.com/wireguard/device" @@ -16,63 +16,82 @@ import ( ) type tunDevice struct { - address WGAddress - mtu int - tunAdapter TunAdapter - iceBind *bind.ICEBind - - fd int name string - device *device.Device - wrapper *DeviceWrapper + address WGAddress + port int + key string + iceBind *bind.ICEBind + tunFd int + + device *device.Device + wrapper *DeviceWrapper + udpMux *bind.UniversalUDPMuxDefault + configurer wgConfigurer } -func newTunDevice(name string, address WGAddress, mtu int, tunAdapter TunAdapter, transportNet transport.Net) *tunDevice { +func newTunDevice(name string, address WGAddress, port int, key string, transportNet transport.Net, tunFd int) *tunDevice { return &tunDevice{ - name: name, - address: address, - mtu: mtu, - tunAdapter: tunAdapter, - iceBind: bind.NewICEBind(transportNet), + name: name, + address: address, + port: port, + key: key, + iceBind: bind.NewICEBind(transportNet), + tunFd: tunFd, } } -func (t *tunDevice) Create(tunFd int32) error { +func (t *tunDevice) Create() (wgConfigurer, error) { log.Infof("create tun interface") - dupTunFd, err := unix.Dup(int(tunFd)) + dupTunFd, err := unix.Dup(t.tunFd) if err != nil { log.Errorf("Unable to dup tun fd: %v", err) - return err + return nil, err } err = unix.SetNonblock(dupTunFd, true) if err != nil { log.Errorf("Unable to set tun fd as non blocking: %v", err) - unix.Close(dupTunFd) - return err + _ = unix.Close(dupTunFd) + return nil, err } - tun, err := tun.CreateTUNFromFile(os.NewFile(uintptr(dupTunFd), "/dev/tun"), 0) + tunDevice, err := tun.CreateTUNFromFile(os.NewFile(uintptr(dupTunFd), "/dev/tun"), 0) if err != nil { log.Errorf("Unable to create new tun device from fd: %v", err) - unix.Close(dupTunFd) - return err + _ = unix.Close(dupTunFd) + return nil, err } - t.wrapper = newDeviceWrapper(tun) + t.wrapper = newDeviceWrapper(tunDevice) log.Debug("Attaching to interface") t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(device.LogLevelSilent, "[wiretrustee] ")) // without this property mobile devices can discover remote endpoints if the configured one was wrong. // this helps with support for the older NetBird clients that had a hardcoded direct mode // t.device.DisableSomeRoamingForBrokenMobileSemantics() - err = t.device.Up() + t.configurer = newWGUSPConfigurer(t.device, t.name) + err = t.configurer.configureInterface(t.key, t.port) if err != nil { t.device.Close() - return err + t.configurer.close() + return nil, err } + return t.configurer, nil +} + +func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { + err := t.device.Up() + if err != nil { + return nil, err + } + + udpMux, err := t.iceBind.GetICEMux() + if err != nil { + return nil, err + } + t.udpMux = udpMux log.Debugf("device is ready to use: %s", t.name) - return nil + return udpMux, nil } func (t *tunDevice) Device() *device.Device { @@ -83,6 +102,23 @@ func (t *tunDevice) DeviceName() string { return t.name } +func (t *tunDevice) Close() error { + if t.configurer != nil { + t.configurer.close() + } + + if t.device != nil { + t.device.Close() + t.device = nil + } + + if t.udpMux != nil { + return t.udpMux.Close() + + } + return nil +} + func (t *tunDevice) WgAddress() WGAddress { return t.address } @@ -92,10 +128,6 @@ func (t *tunDevice) UpdateAddr(addr WGAddress) error { return nil } -func (t *tunDevice) Close() (err error) { - if t.device != nil { - t.device.Close() - } - - return +func (t *tunDevice) Wrapper() *DeviceWrapper { + return t.wrapper } diff --git a/iface/tun_kernel_linux.go b/iface/tun_kernel_linux.go new file mode 100644 index 000000000..12adcdf73 --- /dev/null +++ b/iface/tun_kernel_linux.go @@ -0,0 +1,209 @@ +//go:build linux && !android + +package iface + +import ( + "context" + "fmt" + "net" + "os" + + "github.com/pion/transport/v3" + log "github.com/sirupsen/logrus" + "github.com/vishvananda/netlink" + + "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/sharedsock" +) + +type tunKernelDevice struct { + name string + address WGAddress + wgPort int + key string + mtu int + ctx context.Context + ctxCancel context.CancelFunc + transportNet transport.Net + + link *wgLink + udpMuxConn net.PacketConn + udpMux *bind.UniversalUDPMuxDefault +} + +func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) wgTunDevice { + ctx, cancel := context.WithCancel(context.Background()) + return &tunKernelDevice{ + ctx: ctx, + ctxCancel: cancel, + name: name, + address: address, + wgPort: wgPort, + key: key, + mtu: mtu, + transportNet: transportNet, + } +} + +func (t *tunKernelDevice) Create() (wgConfigurer, error) { + link := newWGLink(t.name) + + // check if interface exists + l, err := netlink.LinkByName(t.name) + if err != nil { + switch err.(type) { + case netlink.LinkNotFoundError: + break + default: + return nil, err + } + } + + // remove if interface exists + if l != nil { + err = netlink.LinkDel(link) + if err != nil { + return nil, err + } + } + + log.Debugf("adding device: %s", t.name) + err = netlink.LinkAdd(link) + if os.IsExist(err) { + log.Infof("interface %s already exists. Will reuse.", t.name) + } else if err != nil { + return nil, err + } + + t.link = link + + err = t.assignAddr() + if err != nil { + return nil, err + } + + // todo do a discovery + log.Debugf("setting MTU: %d interface: %s", t.mtu, t.name) + err = netlink.LinkSetMTU(link, t.mtu) + if err != nil { + log.Errorf("error setting MTU on interface: %s", t.name) + return nil, err + } + + configurer := newWGConfigurer(t.name) + err = configurer.configureInterface(t.key, t.wgPort) + if err != nil { + return nil, err + } + return configurer, nil +} + +func (t *tunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) { + if t.udpMux != nil { + return t.udpMux, nil + } + + if t.link == nil { + return nil, fmt.Errorf("device is not ready yet") + } + + log.Debugf("bringing up interface: %s", t.name) + err := netlink.LinkSetUp(t.link) + if err != nil { + log.Errorf("error bringing up interface: %s", t.name) + return nil, err + } + + rawSock, err := sharedsock.Listen(t.wgPort, sharedsock.NewIncomingSTUNFilter()) + if err != nil { + return nil, err + } + bindParams := bind.UniversalUDPMuxParams{ + UDPConn: rawSock, + Net: t.transportNet, + } + mux := bind.NewUniversalUDPMuxDefault(bindParams) + go mux.ReadFromConn(t.ctx) + t.udpMuxConn = rawSock + t.udpMux = mux + + log.Debugf("device is ready to use: %s", t.name) + return t.udpMux, nil +} + +func (t *tunKernelDevice) UpdateAddr(address WGAddress) error { + t.address = address + return t.assignAddr() +} + +func (t *tunKernelDevice) Close() error { + if t.link == nil { + return nil + } + + t.ctxCancel() + + var closErr error + if err := t.link.Close(); err != nil { + log.Debugf("failed to close link: %s", err) + closErr = err + } + + if t.udpMux != nil { + if err := t.udpMux.Close(); err != nil { + log.Debugf("failed to close udp mux: %s", err) + closErr = err + } + + if err := t.udpMuxConn.Close(); err != nil { + log.Debugf("failed to close udp mux connection: %s", err) + closErr = err + } + } + + return closErr +} + +func (t *tunKernelDevice) WgAddress() WGAddress { + return t.address +} + +func (t *tunKernelDevice) DeviceName() string { + return t.name +} + +func (t *tunKernelDevice) Wrapper() *DeviceWrapper { + return nil +} + +// assignAddr Adds IP address to the tunnel interface +func (t *tunKernelDevice) assignAddr() error { + link := newWGLink(t.name) + + //delete existing addresses + list, err := netlink.AddrList(link, 0) + if err != nil { + return err + } + if len(list) > 0 { + for _, a := range list { + addr := a + err = netlink.AddrDel(link, &addr) + if err != nil { + return err + } + } + } + + log.Debugf("adding address %s to interface: %s", t.address.String(), t.name) + addr, _ := netlink.ParseAddr(t.address.String()) + err = netlink.AddrAdd(link, addr) + if os.IsExist(err) { + log.Infof("interface %s already has the address: %s", t.name, t.address.String()) + } else if err != nil { + return err + } + // On linux, the link must be brought up + err = netlink.LinkSetUp(link) + return err +} diff --git a/iface/tun_link_linux.go b/iface/tun_link_linux.go new file mode 100644 index 000000000..ab28b7e38 --- /dev/null +++ b/iface/tun_link_linux.go @@ -0,0 +1,33 @@ +//go:build linux && !android + +package iface + +import "github.com/vishvananda/netlink" + +type wgLink struct { + attrs *netlink.LinkAttrs +} + +func newWGLink(name string) *wgLink { + attrs := netlink.NewLinkAttrs() + attrs.Name = name + + return &wgLink{ + attrs: &attrs, + } +} + +// Attrs returns the Wireguard's default attributes +func (l *wgLink) Attrs() *netlink.LinkAttrs { + return l.attrs +} + +// Type returns the interface type +func (l *wgLink) Type() string { + return "wireguard" +} + +// Close deletes the link interface +func (l *wgLink) Close() error { + return netlink.LinkDel(l) +} diff --git a/iface/tun_linux.go b/iface/tun_linux.go deleted file mode 100644 index 1a3537394..000000000 --- a/iface/tun_linux.go +++ /dev/null @@ -1,149 +0,0 @@ -//go:build linux && !android - -package iface - -import ( - "fmt" - "os" - - log "github.com/sirupsen/logrus" - "github.com/vishvananda/netlink" -) - -func (c *tunDevice) Create() error { - if WireGuardModuleIsLoaded() { - log.Infof("create tun interface with kernel WireGuard support: %s", c.DeviceName()) - return c.createWithKernel() - } - - if !tunModuleIsLoaded() { - return fmt.Errorf("couldn't check or load tun module") - } - log.Infof("create tun interface with userspace WireGuard support: %s", c.DeviceName()) - var err error - c.netInterface, err = c.createWithUserspace() - if err != nil { - return err - } - - return c.assignAddr() - -} - -// createWithKernel Creates a new WireGuard interface using kernel WireGuard module. -// Works for Linux and offers much better network performance -func (c *tunDevice) createWithKernel() error { - - link := newWGLink(c.name) - - // check if interface exists - l, err := netlink.LinkByName(c.name) - if err != nil { - switch err.(type) { - case netlink.LinkNotFoundError: - break - default: - return err - } - } - - // remove if interface exists - if l != nil { - err = netlink.LinkDel(link) - if err != nil { - return err - } - } - - log.Debugf("adding device: %s", c.name) - err = netlink.LinkAdd(link) - if os.IsExist(err) { - log.Infof("interface %s already exists. Will reuse.", c.name) - } else if err != nil { - return err - } - - c.netInterface = link - - err = c.assignAddr() - if err != nil { - return err - } - - // todo do a discovery - log.Debugf("setting MTU: %d interface: %s", c.mtu, c.name) - err = netlink.LinkSetMTU(link, c.mtu) - if err != nil { - log.Errorf("error setting MTU on interface: %s", c.name) - return err - } - - log.Debugf("bringing up interface: %s", c.name) - err = netlink.LinkSetUp(link) - if err != nil { - log.Errorf("error bringing up interface: %s", c.name) - return err - } - - return nil -} - -// assignAddr Adds IP address to the tunnel interface -func (c *tunDevice) assignAddr() error { - link := newWGLink(c.name) - - //delete existing addresses - list, err := netlink.AddrList(link, 0) - if err != nil { - return err - } - if len(list) > 0 { - for _, a := range list { - addr := a - err = netlink.AddrDel(link, &addr) - if err != nil { - return err - } - } - } - - log.Debugf("adding address %s to interface: %s", c.address.String(), c.name) - addr, _ := netlink.ParseAddr(c.address.String()) - err = netlink.AddrAdd(link, addr) - if os.IsExist(err) { - log.Infof("interface %s already has the address: %s", c.name, c.address.String()) - } else if err != nil { - return err - } - // On linux, the link must be brought up - err = netlink.LinkSetUp(link) - return err -} - -type wgLink struct { - attrs *netlink.LinkAttrs -} - -func newWGLink(name string) *wgLink { - attrs := netlink.NewLinkAttrs() - attrs.Name = name - - return &wgLink{ - attrs: &attrs, - } -} - -// Attrs returns the Wireguard's default attributes -func (l *wgLink) Attrs() *netlink.LinkAttrs { - return l.attrs -} - -// Type returns the interface type -func (l *wgLink) Type() string { - return "wireguard" -} - -// Close deletes the link interface -func (l *wgLink) Close() error { - return netlink.LinkDel(l) -} diff --git a/iface/tun_netstack.go b/iface/tun_netstack.go new file mode 100644 index 000000000..e1d01ecc9 --- /dev/null +++ b/iface/tun_netstack.go @@ -0,0 +1,119 @@ +//go:build !android +// +build !android + +package iface + +import ( + "fmt" + + "github.com/pion/transport/v3" + log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/device" + + "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/iface/netstack" +) + +type tunNetstackDevice struct { + name string + address WGAddress + port int + key string + mtu int + listenAddress string + iceBind *bind.ICEBind + + device *device.Device + wrapper *DeviceWrapper + nsTun *netstack.NetStackTun + udpMux *bind.UniversalUDPMuxDefault + configurer wgConfigurer +} + +func newTunNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string) wgTunDevice { + return &tunNetstackDevice{ + name: name, + address: address, + port: wgPort, + key: key, + mtu: mtu, + listenAddress: listenAddress, + iceBind: bind.NewICEBind(transportNet), + } +} + +func (t *tunNetstackDevice) Create() (wgConfigurer, error) { + log.Info("create netstack tun interface") + t.nsTun = netstack.NewNetStackTun(t.listenAddress, t.address.IP.String(), t.mtu) + tunIface, err := t.nsTun.Create() + if err != nil { + return nil, err + } + t.wrapper = newDeviceWrapper(tunIface) + + t.device = device.NewDevice( + t.wrapper, + t.iceBind, + device.NewLogger(device.LogLevelSilent, "[netbird] "), + ) + + t.configurer = newWGUSPConfigurer(t.device, t.name) + err = t.configurer.configureInterface(t.key, t.port) + if err != nil { + _ = tunIface.Close() + return nil, err + } + + log.Debugf("device has been created: %s", t.name) + return t.configurer, nil +} + +func (t *tunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) { + if t.device == nil { + return nil, fmt.Errorf("device is not ready yet") + } + + err := t.device.Up() + if err != nil { + return nil, err + } + + udpMux, err := t.iceBind.GetICEMux() + if err != nil { + return nil, err + } + t.udpMux = udpMux + log.Debugf("netstack device is ready to use") + return udpMux, nil +} + +func (t *tunNetstackDevice) UpdateAddr(WGAddress) error { + return nil +} + +func (t *tunNetstackDevice) Close() error { + if t.configurer != nil { + t.configurer.close() + } + + if t.device != nil { + t.device.Close() + } + + if t.udpMux != nil { + return t.udpMux.Close() + } + return nil +} + +func (t *tunNetstackDevice) WgAddress() WGAddress { + return t.address +} + +func (t *tunNetstackDevice) DeviceName() string { + return t.name +} + +func (t *tunNetstackDevice) Wrapper() *DeviceWrapper { + return t.wrapper +} diff --git a/iface/tun_unix.go b/iface/tun_unix.go deleted file mode 100644 index bc2d8d019..000000000 --- a/iface/tun_unix.go +++ /dev/null @@ -1,145 +0,0 @@ -//go:build (linux || darwin) && !android && !ios - -package iface - -import ( - "net" - "os" - - "github.com/pion/transport/v3" - "golang.zx2c4.com/wireguard/ipc" - - "github.com/netbirdio/netbird/iface/bind" - - log "github.com/sirupsen/logrus" - "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/tun" -) - -type tunDevice struct { - name string - address WGAddress - mtu int - netInterface NetInterface - iceBind *bind.ICEBind - uapi net.Listener - wrapper *DeviceWrapper - close chan struct{} -} - -func newTunDevice(name string, address WGAddress, mtu int, transportNet transport.Net) *tunDevice { - return &tunDevice{ - name: name, - address: address, - mtu: mtu, - iceBind: bind.NewICEBind(transportNet), - close: make(chan struct{}), - } -} - -func (c *tunDevice) UpdateAddr(address WGAddress) error { - c.address = address - return c.assignAddr() -} - -func (c *tunDevice) WgAddress() WGAddress { - return c.address -} - -func (c *tunDevice) DeviceName() string { - return c.name -} - -func (c *tunDevice) Close() error { - - select { - case c.close <- struct{}{}: - default: - } - - var err1, err2, err3 error - if c.netInterface != nil { - err1 = c.netInterface.Close() - } - - if c.uapi != nil { - err2 = c.uapi.Close() - } - - sockPath := "/var/run/wireguard/" + c.name + ".sock" - if _, statErr := os.Stat(sockPath); statErr == nil { - statErr = os.Remove(sockPath) - if statErr != nil { - err3 = statErr - } - } - - if err1 != nil { - return err1 - } - - if err2 != nil { - return err2 - } - - return err3 -} - -// createWithUserspace Creates a new Wireguard interface, using wireguard-go userspace implementation -func (c *tunDevice) createWithUserspace() (NetInterface, error) { - tunIface, err := tun.CreateTUN(c.name, c.mtu) - if err != nil { - return nil, err - } - c.wrapper = newDeviceWrapper(tunIface) - - // We need to create a wireguard-go device and listen to configuration requests - tunDev := device.NewDevice( - c.wrapper, - c.iceBind, - device.NewLogger(device.LogLevelSilent, "[netbird] "), - ) - err = tunDev.Up() - if err != nil { - _ = tunIface.Close() - return nil, err - } - - c.uapi, err = c.getUAPI(c.name) - if err != nil { - _ = tunIface.Close() - return nil, err - } - - go func() { - for { - select { - case <-c.close: - log.Debugf("exit uapi.Accept()") - return - default: - } - uapiConn, uapiErr := c.uapi.Accept() - if uapiErr != nil { - log.Traceln("uapi Accept failed with error: ", uapiErr) - continue - } - go func() { - tunDev.IpcHandle(uapiConn) - log.Debugf("exit tunDevice.IpcHandle") - }() - } - }() - - log.Debugln("UAPI listener started") - return tunIface, nil -} - -// getUAPI returns a Listener -func (c *tunDevice) getUAPI(iface string) (net.Listener, error) { - tunSock, err := ipc.UAPIOpen(iface) - if err != nil { - return nil, err - } - return ipc.UAPIListen(iface, tunSock) -} diff --git a/iface/tun_usp_linux.go b/iface/tun_usp_linux.go new file mode 100644 index 000000000..3ed518d52 --- /dev/null +++ b/iface/tun_usp_linux.go @@ -0,0 +1,157 @@ +//go:build linux && !android + +package iface + +import ( + "fmt" + "os" + + "github.com/pion/transport/v3" + log "github.com/sirupsen/logrus" + "github.com/vishvananda/netlink" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun" + + "github.com/netbirdio/netbird/iface/bind" +) + +type tunUSPDevice struct { + name string + address WGAddress + port int + key string + mtu int + iceBind *bind.ICEBind + + device *device.Device + wrapper *DeviceWrapper + udpMux *bind.UniversalUDPMuxDefault + configurer wgConfigurer +} + +func newTunUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net) wgTunDevice { + log.Infof("using userspace bind mode") + return &tunUSPDevice{ + name: name, + address: address, + port: port, + key: key, + mtu: mtu, + iceBind: bind.NewICEBind(transportNet), + } +} + +func (t *tunUSPDevice) Create() (wgConfigurer, error) { + log.Info("create tun interface") + tunIface, err := tun.CreateTUN(t.name, t.mtu) + if err != nil { + return nil, err + } + t.wrapper = newDeviceWrapper(tunIface) + + // We need to create a wireguard-go device and listen to configuration requests + t.device = device.NewDevice( + t.wrapper, + t.iceBind, + device.NewLogger(device.LogLevelSilent, "[netbird] "), + ) + + err = t.assignAddr() + if err != nil { + t.device.Close() + return nil, err + } + + t.configurer = newWGUSPConfigurer(t.device, t.name) + err = t.configurer.configureInterface(t.key, t.port) + if err != nil { + t.device.Close() + t.configurer.close() + return nil, err + } + return t.configurer, nil +} + +func (t *tunUSPDevice) Up() (*bind.UniversalUDPMuxDefault, error) { + if t.device == nil { + return nil, fmt.Errorf("device is not ready yet") + } + + err := t.device.Up() + if err != nil { + return nil, err + } + + udpMux, err := t.iceBind.GetICEMux() + if err != nil { + return nil, err + } + t.udpMux = udpMux + + log.Debugf("device is ready to use: %s", t.name) + return udpMux, nil +} + +func (t *tunUSPDevice) UpdateAddr(address WGAddress) error { + t.address = address + return t.assignAddr() +} + +func (t *tunUSPDevice) Close() error { + if t.configurer != nil { + t.configurer.close() + } + + if t.device != nil { + t.device.Close() + } + + if t.udpMux != nil { + return t.udpMux.Close() + } + return nil +} + +func (t *tunUSPDevice) WgAddress() WGAddress { + return t.address +} + +func (t *tunUSPDevice) DeviceName() string { + return t.name +} + +func (t *tunUSPDevice) Wrapper() *DeviceWrapper { + return t.wrapper +} + +// assignAddr Adds IP address to the tunnel interface +func (t *tunUSPDevice) assignAddr() error { + link := newWGLink(t.name) + + //delete existing addresses + list, err := netlink.AddrList(link, 0) + if err != nil { + return err + } + if len(list) > 0 { + for _, a := range list { + addr := a + err = netlink.AddrDel(link, &addr) + if err != nil { + return err + } + } + } + + log.Debugf("adding address %s to interface: %s", t.address.String(), t.name) + addr, _ := netlink.ParseAddr(t.address.String()) + err = netlink.AddrAdd(link, addr) + if os.IsExist(err) { + log.Infof("interface %s already has the address: %s", t.name, t.address.String()) + } else if err != nil { + return err + } + // On linux, the link must be brought up + err = netlink.LinkSetUp(link) + return err +} diff --git a/iface/tun_windows.go b/iface/tun_windows.go index a4ddf1d85..900e62fc3 100644 --- a/iface/tun_windows.go +++ b/iface/tun_windows.go @@ -2,14 +2,12 @@ package iface import ( "fmt" - "net" "net/netip" "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/ipc" "golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" @@ -17,139 +15,131 @@ import ( ) type tunDevice struct { - name string - address WGAddress - netInterface NetInterface - iceBind *bind.ICEBind - mtu int - uapi net.Listener - wrapper *DeviceWrapper - close chan struct{} + name string + address WGAddress + port int + key string + mtu int + iceBind *bind.ICEBind + + device *device.Device + nativeTunDevice *tun.NativeTun + wrapper *DeviceWrapper + udpMux *bind.UniversalUDPMuxDefault + configurer wgConfigurer } -func newTunDevice(name string, address WGAddress, mtu int, transportNet transport.Net) *tunDevice { +func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net) wgTunDevice { return &tunDevice{ name: name, address: address, + port: port, + key: key, mtu: mtu, iceBind: bind.NewICEBind(transportNet), - close: make(chan struct{}), } } -func (c *tunDevice) Create() error { - var err error - c.netInterface, err = c.createWithUserspace() - if err != nil { - return err - } - - return c.assignAddr() -} - -// createWithUserspace Creates a new Wireguard interface, using wireguard-go userspace implementation -func (c *tunDevice) createWithUserspace() (NetInterface, error) { - tunIface, err := tun.CreateTUN(c.name, c.mtu) +func (t *tunDevice) Create() (wgConfigurer, error) { + tunDevice, err := tun.CreateTUN(t.name, t.mtu) if err != nil { return nil, err } - c.wrapper = newDeviceWrapper(tunIface) + t.nativeTunDevice = tunDevice.(*tun.NativeTun) + t.wrapper = newDeviceWrapper(tunDevice) // We need to create a wireguard-go device and listen to configuration requests - tunDev := device.NewDevice(c.wrapper, c.iceBind, device.NewLogger(device.LogLevelSilent, "[netbird] ")) - err = tunDev.Up() - if err != nil { - _ = tunIface.Close() - return nil, err - } + t.device = device.NewDevice( + t.wrapper, + t.iceBind, + device.NewLogger(device.LogLevelSilent, "[netbird] "), + ) - luid := winipcfg.LUID(tunIface.(*tun.NativeTun).LUID()) + luid := winipcfg.LUID(t.nativeTunDevice.LUID()) nbiface, err := luid.IPInterface(windows.AF_INET) if err != nil { - _ = tunIface.Close() + t.device.Close() return nil, fmt.Errorf("got error when getting ip interface %s", err) } - nbiface.NLMTU = uint32(c.mtu) + nbiface.NLMTU = uint32(t.mtu) err = nbiface.Set() if err != nil { - _ = tunIface.Close() + t.device.Close() return nil, fmt.Errorf("got error when getting setting the interface mtu: %s", err) } - - c.uapi, err = c.getUAPI(c.name) + err = t.assignAddr() if err != nil { - _ = tunIface.Close() + t.device.Close() return nil, err } - go func() { - for { - select { - case <-c.close: - log.Debugf("exit uapi.Accept()") - return - default: - } - uapiConn, uapiErr := c.uapi.Accept() - if uapiErr != nil { - log.Traceln("uapi Accept failed with error: ", uapiErr) - continue - } - go func() { - tunDev.IpcHandle(uapiConn) - log.Debugf("exit tunDevice.IpcHandle") - }() - } - }() - - log.Debugln("UAPI listener started") - return tunIface, nil + t.configurer = newWGUSPConfigurer(t.device, t.name) + err = t.configurer.configureInterface(t.key, t.port) + if err != nil { + t.device.Close() + t.configurer.close() + return nil, err + } + return t.configurer, nil } -func (c *tunDevice) UpdateAddr(address WGAddress) error { - c.address = address - return c.assignAddr() -} - -func (c *tunDevice) WgAddress() WGAddress { - return c.address -} - -func (c *tunDevice) DeviceName() string { - return c.name -} - -func (c *tunDevice) Close() error { - select { - case c.close <- struct{}{}: - default: +func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { + err := t.device.Up() + if err != nil { + return nil, err } - var err1, err2 error - if c.netInterface != nil { - err1 = c.netInterface.Close() + udpMux, err := t.iceBind.GetICEMux() + if err != nil { + return nil, err } - - if c.uapi != nil { - err2 = c.uapi.Close() - } - - if err1 != nil { - return err1 - } - - return err2 + t.udpMux = udpMux + log.Debugf("device is ready to use: %s", t.name) + return udpMux, nil } -func (c *tunDevice) getInterfaceGUIDString() (string, error) { - if c.netInterface == nil { +func (t *tunDevice) UpdateAddr(address WGAddress) error { + t.address = address + return t.assignAddr() +} + +func (t *tunDevice) Close() error { + if t.configurer != nil { + t.configurer.close() + } + + if t.device != nil { + t.device.Close() + t.device = nil + } + + if t.udpMux != nil { + return t.udpMux.Close() + + } + return nil +} +func (t *tunDevice) WgAddress() WGAddress { + return t.address +} + +func (t *tunDevice) DeviceName() string { + return t.name +} + +func (t *tunDevice) Wrapper() *DeviceWrapper { + return t.wrapper +} + +func (t *tunDevice) getInterfaceGUIDString() (string, error) { + if t.nativeTunDevice == nil { return "", fmt.Errorf("interface has not been initialized yet") } - windowsDevice := c.netInterface.(*tun.NativeTun) - luid := winipcfg.LUID(windowsDevice.LUID()) + + luid := winipcfg.LUID(t.nativeTunDevice.LUID()) guid, err := luid.GUID() if err != nil { return "", err @@ -158,14 +148,8 @@ func (c *tunDevice) getInterfaceGUIDString() (string, error) { } // assignAddr Adds IP address to the tunnel interface and network route based on the range provided -func (c *tunDevice) assignAddr() error { - tunDev := c.netInterface.(*tun.NativeTun) - luid := winipcfg.LUID(tunDev.LUID()) - log.Debugf("adding address %s to interface: %s", c.address.IP, c.name) - return luid.SetIPAddresses([]netip.Prefix{netip.MustParsePrefix(c.address.String())}) -} - -// getUAPI returns a Listener -func (c *tunDevice) getUAPI(iface string) (net.Listener, error) { - return ipc.UAPIListen(iface) +func (t *tunDevice) assignAddr() error { + luid := winipcfg.LUID(t.nativeTunDevice.LUID()) + log.Debugf("adding address %s to interface: %s", t.address.IP, t.name) + return luid.SetIPAddresses([]netip.Prefix{netip.MustParsePrefix(t.address.String())}) } diff --git a/iface/uapi.go b/iface/uapi.go new file mode 100644 index 000000000..d7ff52e7b --- /dev/null +++ b/iface/uapi.go @@ -0,0 +1,26 @@ +//go:build !windows + +package iface + +import ( + "net" + + log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/ipc" +) + +func openUAPI(deviceName string) (net.Listener, error) { + uapiSock, err := ipc.UAPIOpen(deviceName) + if err != nil { + log.Errorf("failed to open uapi socket: %v", err) + return nil, err + } + + listener, err := ipc.UAPIListen(deviceName, uapiSock) + if err != nil { + log.Errorf("failed to listen on uapi socket: %v", err) + return nil, err + } + + return listener, nil +} diff --git a/iface/uapi_windows.go b/iface/uapi_windows.go new file mode 100644 index 000000000..e1f466364 --- /dev/null +++ b/iface/uapi_windows.go @@ -0,0 +1,11 @@ +package iface + +import ( + "net" + + "golang.zx2c4.com/wireguard/ipc" +) + +func openUAPI(deviceName string) (net.Listener, error) { + return ipc.UAPIListen(deviceName) +} diff --git a/iface/wg_configurer.go b/iface/wg_configurer.go new file mode 100644 index 000000000..b56d75084 --- /dev/null +++ b/iface/wg_configurer.go @@ -0,0 +1,17 @@ +package iface + +import ( + "net" + "time" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +type wgConfigurer interface { + configureInterface(privateKey string, port int) error + updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error + removePeer(peerKey string) error + addAllowedIP(peerKey string, allowedIP string) error + removeAllowedIP(peerKey string, allowedIP string) error + close() +} diff --git a/iface/wg_configurer_nonmobile.go b/iface/wg_configurer_kernel.go similarity index 83% rename from iface/wg_configurer_nonmobile.go rename to iface/wg_configurer_kernel.go index c09dda9ad..3192f5a2b 100644 --- a/iface/wg_configurer_nonmobile.go +++ b/iface/wg_configurer_kernel.go @@ -1,4 +1,4 @@ -//go:build !android && !ios +//go:build linux && !android package iface @@ -12,17 +12,18 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -type wGConfigurer struct { +type wgKernelConfigurer struct { deviceName string } -func newWGConfigurer(deviceName string) wGConfigurer { - return wGConfigurer{ +func newWGConfigurer(deviceName string) wgConfigurer { + wgc := &wgKernelConfigurer{ deviceName: deviceName, } + return wgc } -func (c *wGConfigurer) configureInterface(privateKey string, port int) error { +func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) error { log.Debugf("adding Wireguard private key") key, err := wgtypes.ParseKey(privateKey) if err != nil { @@ -43,7 +44,7 @@ func (c *wGConfigurer) configureInterface(privateKey string, port int) error { return nil } -func (c *wGConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { +func (c *wgKernelConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { // parse allowed ips _, ipNet, err := net.ParseCIDR(allowedIps) if err != nil { @@ -73,7 +74,7 @@ func (c *wGConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive t return nil } -func (c *wGConfigurer) removePeer(peerKey string) error { +func (c *wgKernelConfigurer) removePeer(peerKey string) error { peerKeyParsed, err := wgtypes.ParseKey(peerKey) if err != nil { return err @@ -94,7 +95,7 @@ func (c *wGConfigurer) removePeer(peerKey string) error { return nil } -func (c *wGConfigurer) addAllowedIP(peerKey string, allowedIP string) error { +func (c *wgKernelConfigurer) addAllowedIP(peerKey string, allowedIP string) error { _, ipNet, err := net.ParseCIDR(allowedIP) if err != nil { return err @@ -121,7 +122,7 @@ func (c *wGConfigurer) addAllowedIP(peerKey string, allowedIP string) error { return nil } -func (c *wGConfigurer) removeAllowedIP(peerKey string, allowedIP string) error { +func (c *wgKernelConfigurer) removeAllowedIP(peerKey string, allowedIP string) error { _, ipNet, err := net.ParseCIDR(allowedIP) if err != nil { return err @@ -163,7 +164,7 @@ func (c *wGConfigurer) removeAllowedIP(peerKey string, allowedIP string) error { return nil } -func (c *wGConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) { +func (c *wgKernelConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) { wg, err := wgctrl.New() if err != nil { return wgtypes.Peer{}, err @@ -187,7 +188,7 @@ func (c *wGConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, erro return wgtypes.Peer{}, fmt.Errorf("peer not found") } -func (c *wGConfigurer) configure(config wgtypes.Config) error { +func (c *wgKernelConfigurer) configure(config wgtypes.Config) error { wg, err := wgctrl.New() if err != nil { return err @@ -203,3 +204,6 @@ func (c *wGConfigurer) configure(config wgtypes.Config) error { return wg.ConfigureDevice(c.deviceName, config) } + +func (c *wgKernelConfigurer) close() { +} diff --git a/iface/wg_configurer_mobile.go b/iface/wg_configurer_mobile.go deleted file mode 100644 index 7f6e5595d..000000000 --- a/iface/wg_configurer_mobile.go +++ /dev/null @@ -1,165 +0,0 @@ -//go:build ios || android -// +build ios android - -package iface - -import ( - "encoding/hex" - "errors" - "fmt" - "net" - "strings" - "time" - - log "github.com/sirupsen/logrus" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" -) - -var ( - errFuncNotImplemented = errors.New("function not implemented") -) - -type wGConfigurer struct { - tunDevice *tunDevice -} - -func newWGConfigurer(tunDevice *tunDevice) wGConfigurer { - return wGConfigurer{ - tunDevice: tunDevice, - } -} - -func (c *wGConfigurer) configureInterface(privateKey string, port int) error { - log.Debugf("adding Wireguard private key") - key, err := wgtypes.ParseKey(privateKey) - if err != nil { - return err - } - fwmark := 0 - config := wgtypes.Config{ - PrivateKey: &key, - ReplacePeers: true, - FirewallMark: &fwmark, - ListenPort: &port, - } - - return c.tunDevice.Device().IpcSet(toWgUserspaceString(config)) -} - -func (c *wGConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { - // parse allowed ips - _, ipNet, err := net.ParseCIDR(allowedIps) - if err != nil { - return err - } - - peerKeyParsed, err := wgtypes.ParseKey(peerKey) - if err != nil { - return err - } - peer := wgtypes.PeerConfig{ - PublicKey: peerKeyParsed, - ReplaceAllowedIPs: true, - AllowedIPs: []net.IPNet{*ipNet}, - PersistentKeepaliveInterval: &keepAlive, - PresharedKey: preSharedKey, - Endpoint: endpoint, - } - - config := wgtypes.Config{ - Peers: []wgtypes.PeerConfig{peer}, - } - - return c.tunDevice.Device().IpcSet(toWgUserspaceString(config)) -} - -func (c *wGConfigurer) removePeer(peerKey string) error { - peerKeyParsed, err := wgtypes.ParseKey(peerKey) - if err != nil { - return err - } - - peer := wgtypes.PeerConfig{ - PublicKey: peerKeyParsed, - Remove: true, - } - - config := wgtypes.Config{ - Peers: []wgtypes.PeerConfig{peer}, - } - return c.tunDevice.Device().IpcSet(toWgUserspaceString(config)) -} - -func (c *wGConfigurer) addAllowedIP(peerKey string, allowedIP string) error { - _, ipNet, err := net.ParseCIDR(allowedIP) - if err != nil { - return err - } - - peerKeyParsed, err := wgtypes.ParseKey(peerKey) - if err != nil { - return err - } - peer := wgtypes.PeerConfig{ - PublicKey: peerKeyParsed, - UpdateOnly: true, - ReplaceAllowedIPs: false, - AllowedIPs: []net.IPNet{*ipNet}, - } - - config := wgtypes.Config{ - Peers: []wgtypes.PeerConfig{peer}, - } - - return c.tunDevice.Device().IpcSet(toWgUserspaceString(config)) -} - -func (c *wGConfigurer) removeAllowedIP(peerKey string, ip string) error { - ipc, err := c.tunDevice.Device().IpcGet() - if err != nil { - return err - } - - peerKeyParsed, err := wgtypes.ParseKey(peerKey) - hexKey := hex.EncodeToString(peerKeyParsed[:]) - - lines := strings.Split(ipc, "\n") - - output := "" - foundPeer := false - removedAllowedIP := false - for _, line := range lines { - line = strings.TrimSpace(line) - - // If we're within the details of the found peer and encounter another public key, - // this means we're starting another peer's details. So, reset the flag. - if strings.HasPrefix(line, "public_key=") && foundPeer { - foundPeer = false - } - - // Identify the peer with the specific public key - if line == fmt.Sprintf("public_key=%s", hexKey) { - foundPeer = true - } - - // If we're within the details of the found peer and find the specific allowed IP, skip this line - if foundPeer && line == "allowed_ip="+ip { - removedAllowedIP = true - continue - } - - // Append the line to the output string - if strings.HasPrefix(line, "private_key=") || strings.HasPrefix(line, "listen_port=") || - strings.HasPrefix(line, "public_key=") || strings.HasPrefix(line, "preshared_key=") || - strings.HasPrefix(line, "endpoint=") || strings.HasPrefix(line, "persistent_keepalive_interval=") || - strings.HasPrefix(line, "allowed_ip=") { - output += line + "\n" - } - } - - if !removedAllowedIP { - return fmt.Errorf("allowedIP not found") - } else { - return c.tunDevice.Device().IpcSet(output) - } -} diff --git a/iface/wg_configurer_usp.go b/iface/wg_configurer_usp.go new file mode 100644 index 000000000..cf12b9900 --- /dev/null +++ b/iface/wg_configurer_usp.go @@ -0,0 +1,259 @@ +package iface + +import ( + "encoding/hex" + "fmt" + "net" + "os" + "runtime" + "strings" + "time" + + log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +type wgUSPConfigurer struct { + device *device.Device + deviceName string + + uapiListener net.Listener +} + +func newWGUSPConfigurer(device *device.Device, deviceName string) wgConfigurer { + wgCfg := &wgUSPConfigurer{ + device: device, + deviceName: deviceName, + } + wgCfg.startUAPI() + return wgCfg +} + +func (c *wgUSPConfigurer) configureInterface(privateKey string, port int) error { + log.Debugf("adding Wireguard private key") + key, err := wgtypes.ParseKey(privateKey) + if err != nil { + return err + } + fwmark := 0 + config := wgtypes.Config{ + PrivateKey: &key, + ReplacePeers: true, + FirewallMark: &fwmark, + ListenPort: &port, + } + + return c.device.IpcSet(toWgUserspaceString(config)) +} + +func (c *wgUSPConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { + // parse allowed ips + _, ipNet, err := net.ParseCIDR(allowedIps) + if err != nil { + return err + } + + peerKeyParsed, err := wgtypes.ParseKey(peerKey) + if err != nil { + return err + } + peer := wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + ReplaceAllowedIPs: true, + AllowedIPs: []net.IPNet{*ipNet}, + PersistentKeepaliveInterval: &keepAlive, + PresharedKey: preSharedKey, + Endpoint: endpoint, + } + + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peer}, + } + + return c.device.IpcSet(toWgUserspaceString(config)) +} + +func (c *wgUSPConfigurer) removePeer(peerKey string) error { + peerKeyParsed, err := wgtypes.ParseKey(peerKey) + if err != nil { + return err + } + + peer := wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + Remove: true, + } + + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peer}, + } + return c.device.IpcSet(toWgUserspaceString(config)) +} + +func (c *wgUSPConfigurer) addAllowedIP(peerKey string, allowedIP string) error { + _, ipNet, err := net.ParseCIDR(allowedIP) + if err != nil { + return err + } + + peerKeyParsed, err := wgtypes.ParseKey(peerKey) + if err != nil { + return err + } + peer := wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + UpdateOnly: true, + ReplaceAllowedIPs: false, + AllowedIPs: []net.IPNet{*ipNet}, + } + + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peer}, + } + + return c.device.IpcSet(toWgUserspaceString(config)) +} + +func (c *wgUSPConfigurer) removeAllowedIP(peerKey string, ip string) error { + ipc, err := c.device.IpcGet() + if err != nil { + return err + } + + peerKeyParsed, err := wgtypes.ParseKey(peerKey) + if err != nil { + return err + } + hexKey := hex.EncodeToString(peerKeyParsed[:]) + + lines := strings.Split(ipc, "\n") + + output := "" + foundPeer := false + removedAllowedIP := false + for _, line := range lines { + line = strings.TrimSpace(line) + + // If we're within the details of the found peer and encounter another public key, + // this means we're starting another peer's details. So, reset the flag. + if strings.HasPrefix(line, "public_key=") && foundPeer { + foundPeer = false + } + + // Identify the peer with the specific public key + if line == fmt.Sprintf("public_key=%s", hexKey) { + foundPeer = true + } + + // If we're within the details of the found peer and find the specific allowed IP, skip this line + if foundPeer && line == "allowed_ip="+ip { + removedAllowedIP = true + continue + } + + // Append the line to the output string + if strings.HasPrefix(line, "private_key=") || strings.HasPrefix(line, "listen_port=") || + strings.HasPrefix(line, "public_key=") || strings.HasPrefix(line, "preshared_key=") || + strings.HasPrefix(line, "endpoint=") || strings.HasPrefix(line, "persistent_keepalive_interval=") || + strings.HasPrefix(line, "allowed_ip=") { + output += line + "\n" + } + } + + if !removedAllowedIP { + return fmt.Errorf("allowedIP not found") + } else { + return c.device.IpcSet(output) + } +} + +// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool +func (t *wgUSPConfigurer) startUAPI() { + var err error + t.uapiListener, err = openUAPI(t.deviceName) + if err != nil { + log.Errorf("failed to open uapi listener: %v", err) + return + } + + go func(uapi net.Listener) { + for { + uapiConn, uapiErr := uapi.Accept() + if uapiErr != nil { + log.Tracef("%s", uapiErr) + return + } + go func() { + t.device.IpcHandle(uapiConn) + }() + } + }(t.uapiListener) +} + +func (t *wgUSPConfigurer) close() { + if t.uapiListener != nil { + err := t.uapiListener.Close() + if err != nil { + log.Errorf("failed to close uapi listener: %v", err) + } + } + + if runtime.GOOS == "linux" { + sockPath := "/var/run/wireguard/" + t.deviceName + ".sock" + if _, statErr := os.Stat(sockPath); statErr == nil { + _ = os.Remove(sockPath) + } + } +} + +func toWgUserspaceString(wgCfg wgtypes.Config) string { + var sb strings.Builder + if wgCfg.PrivateKey != nil { + hexKey := hex.EncodeToString(wgCfg.PrivateKey[:]) + sb.WriteString(fmt.Sprintf("private_key=%s\n", hexKey)) + } + + if wgCfg.ListenPort != nil { + sb.WriteString(fmt.Sprintf("listen_port=%d\n", *wgCfg.ListenPort)) + } + + if wgCfg.ReplacePeers { + sb.WriteString("replace_peers=true\n") + } + + if wgCfg.FirewallMark != nil { + sb.WriteString(fmt.Sprintf("fwmark=%d\n", *wgCfg.FirewallMark)) + } + + for _, p := range wgCfg.Peers { + hexKey := hex.EncodeToString(p.PublicKey[:]) + sb.WriteString(fmt.Sprintf("public_key=%s\n", hexKey)) + + if p.PresharedKey != nil { + preSharedHexKey := hex.EncodeToString(p.PresharedKey[:]) + sb.WriteString(fmt.Sprintf("preshared_key=%s\n", preSharedHexKey)) + } + + if p.Remove { + sb.WriteString("remove=true") + } + + if p.ReplaceAllowedIPs { + sb.WriteString("replace_allowed_ips=true\n") + } + + for _, aip := range p.AllowedIPs { + sb.WriteString(fmt.Sprintf("allowed_ip=%s\n", aip.String())) + } + + if p.Endpoint != nil { + sb.WriteString(fmt.Sprintf("endpoint=%s\n", p.Endpoint.String())) + } + + if p.PersistentKeepaliveInterval != nil { + sb.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", int(p.PersistentKeepaliveInterval.Seconds()))) + } + } + return sb.String() +}