Files
netbird/client/embed/embed.go
Zoltan Papp aaa23beeec [client] Prevent to block channel writing (#3474)
The "runningChan" provides feedback to the UI or any client about whether the service is up and running. If the client exits earlier than when the service successfully starts, then this channel causes a block.

- Added timeout for reading the channel to ensure we don't cause blocks for too long for the caller
- Modified channel writing operations to be non-blocking
2025-03-10 13:17:09 +01:00

294 lines
7.5 KiB
Go

package embed
import (
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"os"
"sync"
"github.com/sirupsen/logrus"
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/system"
)
var ErrClientAlreadyStarted = errors.New("client already started")
var ErrClientNotStarted = errors.New("client not started")
// Client manages a netbird embedded client instance
type Client struct {
deviceName string
config *internal.Config
mu sync.Mutex
cancel context.CancelFunc
setupKey string
connect *internal.ConnectClient
}
// Options configures a new Client
type Options struct {
// DeviceName is this peer's name in the network
DeviceName string
// SetupKey is used for authentication
SetupKey string
// ManagementURL overrides the default management server URL
ManagementURL string
// PreSharedKey is the pre-shared key for the WireGuard interface
PreSharedKey string
// LogOutput is the output destination for logs (defaults to os.Stderr if nil)
LogOutput io.Writer
// LogLevel sets the logging level (defaults to info if empty)
LogLevel string
// NoUserspace disables the userspace networking mode. Needs admin/root privileges
NoUserspace bool
// ConfigPath is the path to the netbird config file. If empty, the config will be stored in memory and not persisted.
ConfigPath string
// StatePath is the path to the netbird state file
StatePath string
// DisableClientRoutes disables the client routes
DisableClientRoutes bool
}
// New creates a new netbird embedded client
func New(opts Options) (*Client, error) {
if opts.LogOutput != nil {
logrus.SetOutput(opts.LogOutput)
}
if opts.LogLevel != "" {
level, err := logrus.ParseLevel(opts.LogLevel)
if err != nil {
return nil, fmt.Errorf("parse log level: %w", err)
}
logrus.SetLevel(level)
}
if !opts.NoUserspace {
if err := os.Setenv(netstack.EnvUseNetstackMode, "true"); err != nil {
return nil, fmt.Errorf("setenv: %w", err)
}
if err := os.Setenv(netstack.EnvSkipProxy, "true"); err != nil {
return nil, fmt.Errorf("setenv: %w", err)
}
}
if opts.StatePath != "" {
// TODO: Disable state if path not provided
if err := os.Setenv("NB_DNS_STATE_FILE", opts.StatePath); err != nil {
return nil, fmt.Errorf("setenv: %w", err)
}
}
t := true
var config *internal.Config
var err error
input := internal.ConfigInput{
ConfigPath: opts.ConfigPath,
ManagementURL: opts.ManagementURL,
PreSharedKey: &opts.PreSharedKey,
DisableServerRoutes: &t,
DisableClientRoutes: &opts.DisableClientRoutes,
}
if opts.ConfigPath != "" {
config, err = internal.UpdateOrCreateConfig(input)
} else {
config, err = internal.CreateInMemoryConfig(input)
}
if err != nil {
return nil, fmt.Errorf("create config: %w", err)
}
return &Client{
deviceName: opts.DeviceName,
setupKey: opts.SetupKey,
config: config,
}, nil
}
// Start begins client operation and blocks until the engine has been started successfully or a startup error occurs.
// Pass a context with a deadline to limit the time spent waiting for the engine to start.
func (c *Client) Start(startCtx context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.cancel != nil {
return ErrClientAlreadyStarted
}
ctx := internal.CtxInitState(context.Background())
// nolint:staticcheck
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
if err := internal.Login(ctx, c.config, c.setupKey, ""); err != nil {
return fmt.Errorf("login: %w", err)
}
recorder := peer.NewRecorder(c.config.ManagementURL.String())
client := internal.NewConnectClient(ctx, c.config, recorder)
// either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available
run := make(chan struct{}, 1)
clientErr := make(chan error, 1)
go func() {
if err := client.Run(run); err != nil {
clientErr <- err
}
}()
select {
case <-startCtx.Done():
if stopErr := client.Stop(); stopErr != nil {
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
}
return startCtx.Err()
case err := <-clientErr:
return fmt.Errorf("startup: %w", err)
case <-run:
}
c.connect = client
return nil
}
// Stop gracefully stops the client.
// Pass a context with a deadline to limit the time spent waiting for the engine to stop.
func (c *Client) Stop(ctx context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.connect == nil {
return ErrClientNotStarted
}
done := make(chan error, 1)
go func() {
done <- c.connect.Stop()
}()
select {
case <-ctx.Done():
c.cancel = nil
return ctx.Err()
case err := <-done:
c.cancel = nil
if err != nil {
return fmt.Errorf("stop: %w", err)
}
return nil
}
}
// Dial dials a network address in the netbird network.
// Not applicable if the userspace networking mode is disabled.
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
c.mu.Lock()
connect := c.connect
if connect == nil {
c.mu.Unlock()
return nil, ErrClientNotStarted
}
c.mu.Unlock()
engine := connect.Engine()
if engine == nil {
return nil, errors.New("engine not started")
}
nsnet, err := engine.GetNet()
if err != nil {
return nil, fmt.Errorf("get net: %w", err)
}
return nsnet.DialContext(ctx, network, address)
}
// ListenTCP listens on the given address in the netbird network
// Not applicable if the userspace networking mode is disabled.
func (c *Client) ListenTCP(address string) (net.Listener, error) {
nsnet, addr, err := c.getNet()
if err != nil {
return nil, err
}
_, port, err := net.SplitHostPort(address)
if err != nil {
return nil, fmt.Errorf("split host port: %w", err)
}
listenAddr := fmt.Sprintf("%s:%s", addr, port)
tcpAddr, err := net.ResolveTCPAddr("tcp", listenAddr)
if err != nil {
return nil, fmt.Errorf("resolve: %w", err)
}
return nsnet.ListenTCP(tcpAddr)
}
// ListenUDP listens on the given address in the netbird network
// Not applicable if the userspace networking mode is disabled.
func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
nsnet, addr, err := c.getNet()
if err != nil {
return nil, err
}
_, port, err := net.SplitHostPort(address)
if err != nil {
return nil, fmt.Errorf("split host port: %w", err)
}
listenAddr := fmt.Sprintf("%s:%s", addr, port)
udpAddr, err := net.ResolveUDPAddr("udp", listenAddr)
if err != nil {
return nil, fmt.Errorf("resolve: %w", err)
}
return nsnet.ListenUDP(udpAddr)
}
// NewHTTPClient returns a configured http.Client that uses the netbird network for requests.
// Not applicable if the userspace networking mode is disabled.
func (c *Client) NewHTTPClient() *http.Client {
transport := &http.Transport{
DialContext: c.Dial,
}
return &http.Client{
Transport: transport,
}
}
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
c.mu.Lock()
connect := c.connect
if connect == nil {
c.mu.Unlock()
return nil, netip.Addr{}, errors.New("client not started")
}
c.mu.Unlock()
engine := connect.Engine()
if engine == nil {
return nil, netip.Addr{}, errors.New("engine not started")
}
addr, err := engine.Address()
if err != nil {
return nil, netip.Addr{}, fmt.Errorf("engine address: %w", err)
}
nsnet, err := engine.GetNet()
if err != nil {
return nil, netip.Addr{}, fmt.Errorf("get net: %w", err)
}
return nsnet, addr, nil
}