From 631ef4ed28a5b1fcd4dfe53d645c169deb2882b0 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 20 Feb 2025 13:22:03 +0100 Subject: [PATCH] [client] Add embeddable library (#3239) --- client/embed/doc.go | 167 ++++++++++++ client/embed/embed.go | 296 +++++++++++++++++++++ client/firewall/uspfilter/uspfilter.go | 27 +- client/iface/device.go | 3 + client/iface/device/device_android.go | 5 + client/iface/device/device_darwin.go | 5 + client/iface/device/device_ios.go | 5 + client/iface/device/device_kernel_unix.go | 5 + client/iface/device/device_netstack.go | 24 +- client/iface/device/device_usp_unix.go | 5 + client/iface/device/device_windows.go | 5 + client/iface/device_android.go | 3 + client/iface/iface.go | 9 + client/iface/iface_moc.go | 6 + client/iface/iwginterface.go | 2 + client/iface/iwginterface_windows.go | 2 + client/iface/netstack/env.go | 4 +- client/iface/netstack/tun.go | 42 ++- client/internal/dns/service_memory.go | 24 +- client/internal/dns/service_memory_test.go | 4 +- client/internal/engine.go | 36 ++- util/net/net.go | 20 ++ 22 files changed, 648 insertions(+), 51 deletions(-) create mode 100644 client/embed/doc.go create mode 100644 client/embed/embed.go diff --git a/client/embed/doc.go b/client/embed/doc.go new file mode 100644 index 000000000..069d53ebf --- /dev/null +++ b/client/embed/doc.go @@ -0,0 +1,167 @@ +// Package embed provides a way to embed the NetBird client directly +// into Go programs without requiring a separate NetBird client installation. +package embed + +// Basic Usage: +// +// client, err := embed.New(embed.Options{ +// DeviceName: "my-service", +// SetupKey: os.Getenv("NB_SETUP_KEY"), +// ManagementURL: os.Getenv("NB_MANAGEMENT_URL"), +// }) +// if err != nil { +// log.Fatal(err) +// } +// +// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) +// defer cancel() +// if err := client.Start(ctx); err != nil { +// log.Fatal(err) +// } +// +// Complete HTTP Server Example: +// +// package main +// +// import ( +// "context" +// "fmt" +// "log" +// "net/http" +// "os" +// "os/signal" +// "syscall" +// "time" +// +// netbird "github.com/netbirdio/netbird/client/embed" +// ) +// +// func main() { +// // Create client with setup key and device name +// client, err := netbird.New(netbird.Options{ +// DeviceName: "http-server", +// SetupKey: os.Getenv("NB_SETUP_KEY"), +// ManagementURL: os.Getenv("NB_MANAGEMENT_URL"), +// LogOutput: io.Discard, +// }) +// if err != nil { +// log.Fatal(err) +// } +// +// // Start with timeout +// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) +// defer cancel() +// if err := client.Start(ctx); err != nil { +// log.Fatal(err) +// } +// +// // Create HTTP server +// mux := http.NewServeMux() +// mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { +// fmt.Printf("Request from %s: %s %s\n", r.RemoteAddr, r.Method, r.URL.Path) +// fmt.Fprintf(w, "Hello from netbird!") +// }) +// +// // Listen on netbird network +// l, err := client.ListenTCP(":8080") +// if err != nil { +// log.Fatal(err) +// } +// +// server := &http.Server{Handler: mux} +// go func() { +// if err := server.Serve(l); !errors.Is(err, http.ErrServerClosed) { +// log.Printf("HTTP server error: %v", err) +// } +// }() +// +// log.Printf("HTTP server listening on netbird network port 8080") +// +// // Handle shutdown +// stop := make(chan os.Signal, 1) +// signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM) +// <-stop +// +// shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) +// defer cancel() +// +// if err := server.Shutdown(shutdownCtx); err != nil { +// log.Printf("HTTP shutdown error: %v", err) +// } +// if err := client.Stop(shutdownCtx); err != nil { +// log.Printf("Netbird shutdown error: %v", err) +// } +// } +// +// Complete HTTP Client Example: +// +// package main +// +// import ( +// "context" +// "fmt" +// "io" +// "log" +// "os" +// "time" +// +// netbird "github.com/netbirdio/netbird/client/embed" +// ) +// +// func main() { +// // Create client with setup key and device name +// client, err := netbird.New(netbird.Options{ +// DeviceName: "http-client", +// SetupKey: os.Getenv("NB_SETUP_KEY"), +// ManagementURL: os.Getenv("NB_MANAGEMENT_URL"), +// LogOutput: io.Discard, +// }) +// if err != nil { +// log.Fatal(err) +// } +// +// // Start with timeout +// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) +// defer cancel() +// +// if err := client.Start(ctx); err != nil { +// log.Fatal(err) +// } +// +// // Create HTTP client that uses netbird network +// httpClient := client.NewHTTPClient() +// httpClient.Timeout = 10 * time.Second +// +// // Make request to server in netbird network +// target := os.Getenv("NB_TARGET") +// resp, err := httpClient.Get(target) +// if err != nil { +// log.Fatal(err) +// } +// defer resp.Body.Close() +// +// // Read and print response +// body, err := io.ReadAll(resp.Body) +// if err != nil { +// log.Fatal(err) +// } +// +// fmt.Printf("Response from server: %s\n", string(body)) +// +// // Clean shutdown +// shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) +// defer cancel() +// +// if err := client.Stop(shutdownCtx); err != nil { +// log.Printf("Netbird shutdown error: %v", err) +// } +// } +// +// The package provides several methods for network operations: +// - Dial: Creates outbound connections +// - ListenTCP: Creates TCP listeners +// - ListenUDP: Creates UDP listeners +// +// By default, the embed package uses userspace networking mode, which doesn't +// require root/admin privileges. For production deployments, consider setting +// appropriate config and state paths for persistence. diff --git a/client/embed/embed.go b/client/embed/embed.go new file mode 100644 index 000000000..9ded618c5 --- /dev/null +++ b/client/embed/embed.go @@ -0,0 +1,296 @@ +package embed + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/netip" + "os" + "sync" + + "github.com/sirupsen/logrus" + wgnetstack "golang.zx2c4.com/wireguard/tun/netstack" + + "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/system" +) + +var ErrClientAlreadyStarted = errors.New("client already started") +var ErrClientNotStarted = errors.New("client not started") + +// Client manages a netbird embedded client instance +type Client struct { + deviceName string + config *internal.Config + mu sync.Mutex + cancel context.CancelFunc + setupKey string + connect *internal.ConnectClient +} + +// Options configures a new Client +type Options struct { + // DeviceName is this peer's name in the network + DeviceName string + // SetupKey is used for authentication + SetupKey string + // ManagementURL overrides the default management server URL + ManagementURL string + // PreSharedKey is the pre-shared key for the WireGuard interface + PreSharedKey string + // LogOutput is the output destination for logs (defaults to os.Stderr if nil) + LogOutput io.Writer + // LogLevel sets the logging level (defaults to info if empty) + LogLevel string + // NoUserspace disables the userspace networking mode. Needs admin/root privileges + NoUserspace bool + // ConfigPath is the path to the netbird config file. If empty, the config will be stored in memory and not persisted. + ConfigPath string + // StatePath is the path to the netbird state file + StatePath string + // DisableClientRoutes disables the client routes + DisableClientRoutes bool +} + +// New creates a new netbird embedded client +func New(opts Options) (*Client, error) { + if opts.LogOutput != nil { + logrus.SetOutput(opts.LogOutput) + } + + if opts.LogLevel != "" { + level, err := logrus.ParseLevel(opts.LogLevel) + if err != nil { + return nil, fmt.Errorf("parse log level: %w", err) + } + logrus.SetLevel(level) + } + + if !opts.NoUserspace { + if err := os.Setenv(netstack.EnvUseNetstackMode, "true"); err != nil { + return nil, fmt.Errorf("setenv: %w", err) + } + if err := os.Setenv(netstack.EnvSkipProxy, "true"); err != nil { + return nil, fmt.Errorf("setenv: %w", err) + } + } + + if opts.StatePath != "" { + // TODO: Disable state if path not provided + if err := os.Setenv("NB_DNS_STATE_FILE", opts.StatePath); err != nil { + return nil, fmt.Errorf("setenv: %w", err) + } + } + + t := true + var config *internal.Config + var err error + input := internal.ConfigInput{ + ConfigPath: opts.ConfigPath, + ManagementURL: opts.ManagementURL, + PreSharedKey: &opts.PreSharedKey, + DisableServerRoutes: &t, + DisableClientRoutes: &opts.DisableClientRoutes, + } + if opts.ConfigPath != "" { + config, err = internal.UpdateOrCreateConfig(input) + } else { + config, err = internal.CreateInMemoryConfig(input) + } + if err != nil { + return nil, fmt.Errorf("create config: %w", err) + } + + return &Client{ + deviceName: opts.DeviceName, + setupKey: opts.SetupKey, + config: config, + }, nil +} + +// Start begins client operation and blocks until the engine has been started successfully or a startup error occurs. +// Pass a context with a deadline to limit the time spent waiting for the engine to start. +func (c *Client) Start(startCtx context.Context) error { + c.mu.Lock() + defer c.mu.Unlock() + if c.cancel != nil { + return ErrClientAlreadyStarted + } + + ctx := internal.CtxInitState(context.Background()) + // nolint:staticcheck + ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName) + if err := internal.Login(ctx, c.config, c.setupKey, ""); err != nil { + return fmt.Errorf("login: %w", err) + } + + recorder := peer.NewRecorder(c.config.ManagementURL.String()) + client := internal.NewConnectClient(ctx, c.config, recorder) + + // either startup error (permanent backoff err) or nil err (successful engine up) + // TODO: make after-startup backoff err available + run := make(chan error, 1) + go func() { + if err := client.Run(run); err != nil { + run <- err + } + }() + + select { + case <-startCtx.Done(): + if stopErr := client.Stop(); stopErr != nil { + return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err()) + } + return startCtx.Err() + case err := <-run: + if err != nil { + if stopErr := client.Stop(); stopErr != nil { + return fmt.Errorf("stop error after failed to startup. Stop error: %w. Start error: %w", stopErr, err) + } + return fmt.Errorf("startup: %w", err) + } + } + + c.connect = client + + return nil +} + +// Stop gracefully stops the client. +// Pass a context with a deadline to limit the time spent waiting for the engine to stop. +func (c *Client) Stop(ctx context.Context) error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.connect == nil { + return ErrClientNotStarted + } + + done := make(chan error, 1) + go func() { + done <- c.connect.Stop() + }() + + select { + case <-ctx.Done(): + c.cancel = nil + return ctx.Err() + case err := <-done: + c.cancel = nil + if err != nil { + return fmt.Errorf("stop: %w", err) + } + return nil + } +} + +// Dial dials a network address in the netbird network. +// Not applicable if the userspace networking mode is disabled. +func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) { + c.mu.Lock() + connect := c.connect + if connect == nil { + c.mu.Unlock() + return nil, ErrClientNotStarted + } + c.mu.Unlock() + + engine := connect.Engine() + if engine == nil { + return nil, errors.New("engine not started") + } + + nsnet, err := engine.GetNet() + if err != nil { + return nil, fmt.Errorf("get net: %w", err) + } + + return nsnet.DialContext(ctx, network, address) +} + +// ListenTCP listens on the given address in the netbird network +// Not applicable if the userspace networking mode is disabled. +func (c *Client) ListenTCP(address string) (net.Listener, error) { + nsnet, addr, err := c.getNet() + if err != nil { + return nil, err + } + + _, port, err := net.SplitHostPort(address) + if err != nil { + return nil, fmt.Errorf("split host port: %w", err) + } + listenAddr := fmt.Sprintf("%s:%s", addr, port) + + tcpAddr, err := net.ResolveTCPAddr("tcp", listenAddr) + if err != nil { + return nil, fmt.Errorf("resolve: %w", err) + } + return nsnet.ListenTCP(tcpAddr) +} + +// ListenUDP listens on the given address in the netbird network +// Not applicable if the userspace networking mode is disabled. +func (c *Client) ListenUDP(address string) (net.PacketConn, error) { + nsnet, addr, err := c.getNet() + if err != nil { + return nil, err + } + + _, port, err := net.SplitHostPort(address) + if err != nil { + return nil, fmt.Errorf("split host port: %w", err) + } + listenAddr := fmt.Sprintf("%s:%s", addr, port) + + udpAddr, err := net.ResolveUDPAddr("udp", listenAddr) + if err != nil { + return nil, fmt.Errorf("resolve: %w", err) + } + + return nsnet.ListenUDP(udpAddr) +} + +// NewHTTPClient returns a configured http.Client that uses the netbird network for requests. +// Not applicable if the userspace networking mode is disabled. +func (c *Client) NewHTTPClient() *http.Client { + transport := &http.Transport{ + DialContext: c.Dial, + } + + return &http.Client{ + Transport: transport, + } +} + +func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) { + c.mu.Lock() + connect := c.connect + if connect == nil { + c.mu.Unlock() + return nil, netip.Addr{}, errors.New("client not started") + } + c.mu.Unlock() + + engine := connect.Engine() + if engine == nil { + return nil, netip.Addr{}, errors.New("engine not started") + } + + addr, err := engine.Address() + if err != nil { + return nil, netip.Addr{}, fmt.Errorf("engine address: %w", err) + } + + nsnet, err := engine.GetNet() + if err != nil { + return nil, netip.Addr{}, fmt.Errorf("get net: %w", err) + } + + return nsnet, addr, nil +} diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 5bb225ccd..50f48a5c4 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -173,8 +173,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe stateful: !disableConntrack, logger: nblog.NewFromLogrus(log.StandardLogger()), netstack: netstack.IsEnabled(), - // default true for non-netstack, for netstack only if explicitly enabled - localForwarding: !netstack.IsEnabled() || enableLocalForwarding, + localForwarding: enableLocalForwarding, } if err := m.localipmanager.UpdateLocalIPs(iface); err != nil { @@ -647,11 +646,6 @@ func (m *Manager) dropFilter(packetData []byte) bool { // handleLocalTraffic handles local traffic. // If it returns true, the packet should be dropped. func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool { - if !m.localForwarding { - m.logger.Trace("Dropping local packet (local forwarding disabled): src=%s dst=%s", srcIP, dstIP) - return true - } - if m.peerACLsBlock(srcIP, packetData, m.incomingRules, d) { m.logger.Trace("Dropping local packet (ACL denied): src=%s dst=%s", srcIP, dstIP) @@ -660,22 +654,29 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData // if running in netstack mode we need to pass this to the forwarder if m.netstack { - m.handleNetstackLocalTraffic(packetData) - - // don't process this packet further - return true + return m.handleNetstackLocalTraffic(packetData) } return false } -func (m *Manager) handleNetstackLocalTraffic(packetData []byte) { + +func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool { + if !m.localForwarding { + // pass to virtual tcp/ip stack to be picked up by listeners + return false + } + if m.forwarder == nil { - return + m.logger.Trace("Dropping local packet (forwarder not initialized)") + return true } if err := m.forwarder.InjectIncomingPacket(packetData); err != nil { m.logger.Error("Failed to inject local packet: %v", err) } + + // don't process this packet further + return true } // handleRoutedTraffic handles routed traffic. diff --git a/client/iface/device.go b/client/iface/device.go index 2a170adfb..86e9dab4b 100644 --- a/client/iface/device.go +++ b/client/iface/device.go @@ -3,6 +3,8 @@ package iface import ( + "golang.zx2c4.com/wireguard/tun/netstack" + wgdevice "golang.zx2c4.com/wireguard/device" "github.com/netbirdio/netbird/client/iface/bind" @@ -18,4 +20,5 @@ type WGTunDevice interface { Close() error FilteredDevice() *device.FilteredDevice Device() *wgdevice.Device + GetNet() *netstack.Net } diff --git a/client/iface/device/device_android.go b/client/iface/device/device_android.go index 772722b83..55081e181 100644 --- a/client/iface/device/device_android.go +++ b/client/iface/device/device_android.go @@ -9,6 +9,7 @@ import ( "golang.org/x/sys/unix" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" + "golang.zx2c4.com/wireguard/tun/netstack" "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" @@ -130,6 +131,10 @@ func (t *WGTunDevice) FilteredDevice() *FilteredDevice { return t.filteredDevice } +func (t *WGTunDevice) GetNet() *netstack.Net { + return nil +} + func routesToString(routes []string) string { return strings.Join(routes, ";") } diff --git a/client/iface/device/device_darwin.go b/client/iface/device/device_darwin.go index fe7ed1752..1a5635ff2 100644 --- a/client/iface/device/device_darwin.go +++ b/client/iface/device/device_darwin.go @@ -9,6 +9,7 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" + "golang.zx2c4.com/wireguard/tun/netstack" "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" @@ -143,3 +144,7 @@ func (t *TunDevice) assignAddr() error { } return nil } + +func (t *TunDevice) GetNet() *netstack.Net { + return nil +} diff --git a/client/iface/device/device_ios.go b/client/iface/device/device_ios.go index cdabd2c85..b106d475c 100644 --- a/client/iface/device/device_ios.go +++ b/client/iface/device/device_ios.go @@ -10,6 +10,7 @@ import ( "golang.org/x/sys/unix" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" + "golang.zx2c4.com/wireguard/tun/netstack" "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" @@ -131,3 +132,7 @@ func (t *TunDevice) UpdateAddr(addr WGAddress) error { func (t *TunDevice) FilteredDevice() *FilteredDevice { return t.filteredDevice } + +func (t *TunDevice) GetNet() *netstack.Net { + return nil +} diff --git a/client/iface/device/device_kernel_unix.go b/client/iface/device/device_kernel_unix.go index 3314b576b..fe1d1147f 100644 --- a/client/iface/device/device_kernel_unix.go +++ b/client/iface/device/device_kernel_unix.go @@ -10,6 +10,7 @@ import ( "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun/netstack" "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" @@ -165,3 +166,7 @@ func (t *TunKernelDevice) FilteredDevice() *FilteredDevice { func (t *TunKernelDevice) assignAddr() error { return t.link.assignAddr(t.address) } + +func (t *TunKernelDevice) GetNet() *netstack.Net { + return nil +} diff --git a/client/iface/device/device_netstack.go b/client/iface/device/device_netstack.go index c7d297187..0cb02fd19 100644 --- a/client/iface/device/device_netstack.go +++ b/client/iface/device/device_netstack.go @@ -8,10 +8,12 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun/netstack" "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" - "github.com/netbirdio/netbird/client/iface/netstack" + nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" + nbnet "github.com/netbirdio/netbird/util/net" ) type TunNetstackDevice struct { @@ -25,9 +27,11 @@ type TunNetstackDevice struct { device *device.Device filteredDevice *FilteredDevice - nsTun *netstack.NetStackTun + nsTun *nbnetstack.NetStackTun udpMux *bind.UniversalUDPMuxDefault configurer WGConfigurer + + net *netstack.Net } func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice { @@ -43,13 +47,19 @@ func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, m } func (t *TunNetstackDevice) Create() (WGConfigurer, error) { - log.Info("create netstack tun interface") - t.nsTun = netstack.NewNetStackTun(t.listenAddress, t.address.IP.String(), t.mtu) - tunIface, err := t.nsTun.Create() + log.Info("create nbnetstack tun interface") + + // TODO: get from service listener runtime IP + dnsAddr := nbnet.GetLastIPFromNetwork(t.address.Network, 1) + log.Debugf("netstack using address: %s", t.address.IP) + t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu) + log.Debugf("netstack using dns address: %s", dnsAddr) + tunIface, net, err := t.nsTun.Create() if err != nil { return nil, fmt.Errorf("error creating tun device: %s", err) } t.filteredDevice = newDeviceFilter(tunIface) + t.net = net t.device = device.NewDevice( t.filteredDevice, @@ -122,3 +132,7 @@ func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice { func (t *TunNetstackDevice) Device() *device.Device { return t.device } + +func (t *TunNetstackDevice) GetNet() *netstack.Net { + return t.net +} diff --git a/client/iface/device/device_usp_unix.go b/client/iface/device/device_usp_unix.go index 4ac87aecb..07570617a 100644 --- a/client/iface/device/device_usp_unix.go +++ b/client/iface/device/device_usp_unix.go @@ -8,6 +8,7 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" + "golang.zx2c4.com/wireguard/tun/netstack" "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" @@ -135,3 +136,7 @@ func (t *USPDevice) assignAddr() error { return link.assignAddr(t.address) } + +func (t *USPDevice) GetNet() *netstack.Net { + return nil +} diff --git a/client/iface/device/device_windows.go b/client/iface/device/device_windows.go index e603d7696..0fd1b3326 100644 --- a/client/iface/device/device_windows.go +++ b/client/iface/device/device_windows.go @@ -8,6 +8,7 @@ import ( "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" + "golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" "github.com/netbirdio/netbird/client/iface/bind" @@ -174,3 +175,7 @@ func (t *TunDevice) assignAddr() error { log.Debugf("adding address %s to interface: %s", t.address.IP, t.name) return luid.SetIPAddresses([]netip.Prefix{netip.MustParsePrefix(t.address.String())}) } + +func (t *TunDevice) GetNet() *netstack.Net { + return nil +} diff --git a/client/iface/device_android.go b/client/iface/device_android.go index 028f6fa7d..5cbeb70f8 100644 --- a/client/iface/device_android.go +++ b/client/iface/device_android.go @@ -3,6 +3,8 @@ package iface import ( wgdevice "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun/netstack" + "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" ) @@ -16,4 +18,5 @@ type WGTunDevice interface { Close() error FilteredDevice() *device.FilteredDevice Device() *wgdevice.Device + GetNet() *netstack.Net } diff --git a/client/iface/iface.go b/client/iface/iface.go index 64219975f..8056dd9a6 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -9,6 +9,7 @@ import ( "github.com/hashicorp/go-multierror" "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" wgdevice "golang.zx2c4.com/wireguard/device" @@ -241,3 +242,11 @@ func (w *WGIface) waitUntilRemoved() error { } } } + +// GetNet returns the netstack.Net for the netstack device +func (w *WGIface) GetNet() *netstack.Net { + w.mu.Lock() + defer w.mu.Unlock() + + return w.tun.GetNet() +} diff --git a/client/iface/iface_moc.go b/client/iface/iface_moc.go index 5f57bc821..f92a8cfc8 100644 --- a/client/iface/iface_moc.go +++ b/client/iface/iface_moc.go @@ -5,6 +5,7 @@ import ( "time" wgdevice "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/iface/bind" @@ -34,6 +35,7 @@ type MockWGIface struct { GetStatsFunc func(peerKey string) (configurer.WGStats, error) GetInterfaceGUIDStringFunc func() (string, error) GetProxyFunc func() wgproxy.Proxy + GetNetFunc func() *netstack.Net } func (m *MockWGIface) GetInterfaceGUIDString() (string, error) { @@ -115,3 +117,7 @@ func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) { func (m *MockWGIface) GetProxy() wgproxy.Proxy { return m.GetProxyFunc() } + +func (m *MockWGIface) GetNet() *netstack.Net { + return m.GetNetFunc() +} diff --git a/client/iface/iwginterface.go b/client/iface/iwginterface.go index 472ab45f9..2b919ac9e 100644 --- a/client/iface/iwginterface.go +++ b/client/iface/iwginterface.go @@ -7,6 +7,7 @@ import ( "time" wgdevice "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/iface/bind" @@ -35,4 +36,5 @@ type IWGIface interface { GetDevice() *device.FilteredDevice GetWGDevice() *wgdevice.Device GetStats(peerKey string) (configurer.WGStats, error) + GetNet() *netstack.Net } diff --git a/client/iface/iwginterface_windows.go b/client/iface/iwginterface_windows.go index c9183cafd..cac096b54 100644 --- a/client/iface/iwginterface_windows.go +++ b/client/iface/iwginterface_windows.go @@ -5,6 +5,7 @@ import ( "time" wgdevice "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/iface/bind" @@ -34,4 +35,5 @@ type IWGIface interface { GetWGDevice() *wgdevice.Device GetStats(peerKey string) (configurer.WGStats, error) GetInterfaceGUIDString() (string, error) + GetNet() *netstack.Net } diff --git a/client/iface/netstack/env.go b/client/iface/netstack/env.go index 09889a57e..cdbf975b1 100644 --- a/client/iface/netstack/env.go +++ b/client/iface/netstack/env.go @@ -8,9 +8,11 @@ import ( log "github.com/sirupsen/logrus" ) +const EnvUseNetstackMode = "NB_USE_NETSTACK_MODE" + // IsEnabled todo: move these function to cmd layer func IsEnabled() bool { - return os.Getenv("NB_USE_NETSTACK_MODE") == "true" + return os.Getenv(EnvUseNetstackMode) == "true" } func ListenAddr() string { diff --git a/client/iface/netstack/tun.go b/client/iface/netstack/tun.go index c180e4ef5..01f19875e 100644 --- a/client/iface/netstack/tun.go +++ b/client/iface/netstack/tun.go @@ -1,15 +1,22 @@ package netstack import ( + "fmt" + "net" "net/netip" + "os" + "strconv" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun/netstack" ) +const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY" + type NetStackTun struct { //nolint:revive - address string + address net.IP + dnsAddress net.IP mtu int listenAddress string @@ -17,29 +24,48 @@ type NetStackTun struct { //nolint:revive tundev tun.Device } -func NewNetStackTun(listenAddress string, address string, mtu int) *NetStackTun { +func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu int) *NetStackTun { return &NetStackTun{ address: address, + dnsAddress: dnsAddress, mtu: mtu, listenAddress: listenAddress, } } -func (t *NetStackTun) Create() (tun.Device, error) { +func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) { + addr, ok := netip.AddrFromSlice(t.address) + if !ok { + return nil, nil, fmt.Errorf("convert address to netip.Addr: %v", t.address) + } + + dnsAddr, ok := netip.AddrFromSlice(t.dnsAddress) + if !ok { + return nil, nil, fmt.Errorf("convert dns address to netip.Addr: %v", t.dnsAddress) + } + nsTunDev, tunNet, err := netstack.CreateNetTUN( - []netip.Addr{netip.MustParseAddr(t.address)}, - []netip.Addr{}, + []netip.Addr{addr.Unmap()}, + []netip.Addr{dnsAddr.Unmap()}, t.mtu) if err != nil { - return nil, err + return nil, nil, err } t.tundev = nsTunDev + skipProxy, err := strconv.ParseBool(os.Getenv(EnvSkipProxy)) + if err != nil { + log.Errorf("failed to parse NB_ETSTACK_SKIP_PROXY: %s", err) + } + if skipProxy { + return nsTunDev, tunNet, nil + } + dialer := NewNSDialer(tunNet) t.proxy, err = NewSocks5(dialer) if err != nil { _ = t.tundev.Close() - return nil, err + return nil, nil, err } go func() { @@ -49,7 +75,7 @@ func (t *NetStackTun) Create() (tun.Device, error) { } }() - return nsTunDev, nil + return nsTunDev, tunNet, nil } func (t *NetStackTun) Close() error { diff --git a/client/internal/dns/service_memory.go b/client/internal/dns/service_memory.go index 729b90cc0..250f3ab2e 100644 --- a/client/internal/dns/service_memory.go +++ b/client/internal/dns/service_memory.go @@ -2,7 +2,6 @@ package dns import ( "fmt" - "math/big" "net" "sync" @@ -10,6 +9,8 @@ import ( "github.com/google/gopacket/layers" "github.com/miekg/dns" log "github.com/sirupsen/logrus" + + nbnet "github.com/netbirdio/netbird/util/net" ) type ServiceViaMemory struct { @@ -27,7 +28,7 @@ func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory { wgInterface: wgIface, dnsMux: dns.NewServeMux(), - runtimeIP: getLastIPFromNetwork(wgIface.Address().Network, 1), + runtimeIP: nbnet.GetLastIPFromNetwork(wgIface.Address().Network, 1).String(), runtimePort: defaultPort, } return s @@ -118,22 +119,3 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) { return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook), nil } - -func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string { - // Calculate the last IP in the CIDR range - var endIP net.IP - for i := 0; i < len(network.IP); i++ { - endIP = append(endIP, network.IP[i]|^network.Mask[i]) - } - - // convert to big.Int - endInt := big.NewInt(0) - endInt.SetBytes(endIP) - - // subtract fromEnd from the last ip - fromEndBig := big.NewInt(int64(fromEnd)) - resultInt := big.NewInt(0) - resultInt.Sub(endInt, fromEndBig) - - return net.IP(resultInt.Bytes()).String() -} diff --git a/client/internal/dns/service_memory_test.go b/client/internal/dns/service_memory_test.go index bea4f4ce8..244adfaef 100644 --- a/client/internal/dns/service_memory_test.go +++ b/client/internal/dns/service_memory_test.go @@ -3,6 +3,8 @@ package dns import ( "net" "testing" + + nbnet "github.com/netbirdio/netbird/util/net" ) func TestGetLastIPFromNetwork(t *testing.T) { @@ -23,7 +25,7 @@ func TestGetLastIPFromNetwork(t *testing.T) { return } - lastIP := getLastIPFromNetwork(ipnet, 1) + lastIP := nbnet.GetLastIPFromNetwork(ipnet, 1).String() if lastIP != tt.ip { t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP) } diff --git a/client/internal/engine.go b/client/internal/engine.go index 14e0d348f..d590c0db6 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -19,6 +19,7 @@ import ( "github.com/pion/ice/v3" "github.com/pion/stun/v2" log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/protobuf/proto" @@ -28,7 +29,7 @@ import ( "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" - "github.com/netbirdio/netbird/client/iface/netstack" + nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dnsfwd" @@ -724,7 +725,7 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error { // start SSH server if it wasn't running if isNil(e.sshServer) { listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort) - if netstack.IsEnabled() { + if nbnetstack.IsEnabled() { listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort) } // nil sshServer means it has not yet been started @@ -1716,6 +1717,37 @@ func (e *Engine) updateDNSForwarder(enabled bool, domains []string) { } } +func (e *Engine) GetNet() (*netstack.Net, error) { + e.syncMsgMux.Lock() + intf := e.wgInterface + e.syncMsgMux.Unlock() + if intf == nil { + return nil, errors.New("wireguard interface not initialized") + } + + nsnet := intf.GetNet() + if nsnet == nil { + return nil, errors.New("failed to get netstack") + } + return nsnet, nil +} + +func (e *Engine) Address() (netip.Addr, error) { + e.syncMsgMux.Lock() + intf := e.wgInterface + e.syncMsgMux.Unlock() + if intf == nil { + return netip.Addr{}, errors.New("wireguard interface not initialized") + } + + addr := e.wgInterface.Address() + ip, ok := netip.AddrFromSlice(addr.IP) + if !ok { + return netip.Addr{}, errors.New("failed to convert address to netip.Addr") + } + return ip.Unmap(), nil +} + // isChecksEqual checks if two slices of checks are equal. func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool { for _, check := range checks { diff --git a/util/net/net.go b/util/net/net.go index 403aa87e7..7b43b952f 100644 --- a/util/net/net.go +++ b/util/net/net.go @@ -1,6 +1,7 @@ package net import ( + "math/big" "net" "github.com/google/uuid" @@ -26,3 +27,22 @@ type RemoveHookFunc func(connID ConnectionID) error func GenerateConnID() ConnectionID { return ConnectionID(uuid.NewString()) } + +func GetLastIPFromNetwork(network *net.IPNet, fromEnd int) net.IP { + // Calculate the last IP in the CIDR range + var endIP net.IP + for i := 0; i < len(network.IP); i++ { + endIP = append(endIP, network.IP[i]|^network.Mask[i]) + } + + // convert to big.Int + endInt := big.NewInt(0) + endInt.SetBytes(endIP) + + // subtract fromEnd from the last ip + fromEndBig := big.NewInt(int64(fromEnd)) + resultInt := big.NewInt(0) + resultInt.Sub(endInt, fromEndBig) + + return resultInt.Bytes() +}