From 891ba277b101c8455a993dd6e7d7e56d3d607fe8 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 17 Mar 2023 10:37:27 +0100 Subject: [PATCH] Mobile (#735) Initial modification to support mobile client Export necessary interfaces for Android framework --- client/android/client.go | 121 +++++ client/android/login.go | 173 +++++++ client/android/peer_notifier.go | 37 ++ client/android/preferences.go | 78 +++ client/android/preferences_test.go | 120 +++++ client/cmd/up.go | 2 +- client/internal/config.go | 42 +- client/internal/connect.go | 9 +- client/internal/device_auth.go | 1 + client/internal/dns/local.go | 2 + client/internal/dns/server.go | 464 +---------------- client/internal/dns/server_android.go | 32 ++ client/internal/dns/server_nonandroid.go | 465 ++++++++++++++++++ client/internal/dns/server_test.go | 2 +- client/internal/engine.go | 5 +- client/internal/engine_test.go | 6 +- client/internal/login.go | 141 ++++-- client/internal/peer/conn.go | 7 +- client/internal/peer/listener.go | 9 + client/internal/peer/notifier.go | 124 +++++ client/internal/peer/notifier_test.go | 32 ++ client/internal/peer/status.go | 45 +- client/internal/routemanager/manager.go | 183 +------ .../internal/routemanager/manager_android.go | 31 ++ .../routemanager/manager_nonandroid.go | 186 +++++++ client/internal/routemanager/manager_test.go | 2 +- .../internal/routemanager/systemops_test.go | 2 +- client/server/server.go | 4 +- client/system/info.go | 3 + client/system/info_android.go | 63 +++ client/system/info_linux.go | 3 + formatter/formatter.go | 14 +- formatter/logcat.go | 48 ++ formatter/logcat_test.go | 28 ++ formatter/set.go | 9 +- go.mod | 2 +- iface/iface.go | 260 ++-------- iface/iface_android.go | 22 + iface/iface_nonandroid.go | 22 + iface/iface_test.go | 22 +- iface/iface_windows.go | 67 +-- iface/ipc_parser_android.go | 60 +++ iface/module.go | 4 +- iface/module_linux.go | 2 + iface/{ifacename.go => name.go} | 0 iface/{ifacename_darwin.go => name_darwin.go} | 0 iface/tun.go | 6 + iface/tun_adapter.go | 7 + iface/tun_android.go | 112 +++++ iface/{iface_darwin.go => tun_darwin.go} | 18 +- iface/{iface_linux.go => tun_linux.go} | 64 +-- iface/{iface_unix.go => tun_unix.go} | 69 +-- iface/tun_windows.go | 93 ++++ iface/wg_configurer_android.go | 114 +++++ iface/wg_configurer_nonandroid.go | 208 ++++++++ 55 files changed, 2562 insertions(+), 1083 deletions(-) create mode 100644 client/android/client.go create mode 100644 client/android/login.go create mode 100644 client/android/peer_notifier.go create mode 100644 client/android/preferences.go create mode 100644 client/android/preferences_test.go create mode 100644 client/internal/dns/server_android.go create mode 100644 client/internal/dns/server_nonandroid.go create mode 100644 client/internal/peer/listener.go create mode 100644 client/internal/peer/notifier.go create mode 100644 client/internal/peer/notifier_test.go create mode 100644 client/internal/routemanager/manager_android.go create mode 100644 client/internal/routemanager/manager_nonandroid.go create mode 100644 client/system/info_android.go create mode 100644 formatter/logcat.go create mode 100644 formatter/logcat_test.go create mode 100644 iface/iface_android.go create mode 100644 iface/iface_nonandroid.go create mode 100644 iface/ipc_parser_android.go rename iface/{ifacename.go => name.go} (100%) rename iface/{ifacename_darwin.go => name_darwin.go} (100%) create mode 100644 iface/tun.go create mode 100644 iface/tun_adapter.go create mode 100644 iface/tun_android.go rename iface/{iface_darwin.go => tun_darwin.go} (56%) rename iface/{iface_linux.go => tun_linux.go} (60%) rename iface/{iface_unix.go => tun_unix.go} (53%) create mode 100644 iface/tun_windows.go create mode 100644 iface/wg_configurer_android.go create mode 100644 iface/wg_configurer_nonandroid.go diff --git a/client/android/client.go b/client/android/client.go new file mode 100644 index 000000000..778c3d15a --- /dev/null +++ b/client/android/client.go @@ -0,0 +1,121 @@ +package android + +import ( + "context" + "sync" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/system" + "github.com/netbirdio/netbird/formatter" + "github.com/netbirdio/netbird/iface" +) + +// ConnectionListener export internal Listener for mobile +type ConnectionListener interface { + peer.Listener +} + +// TunAdapter export internal TunAdapter for mobile +type TunAdapter interface { + iface.TunAdapter +} + +func init() { + formatter.SetLogcatFormatter(log.StandardLogger()) +} + +// Client struct manage the life circle of background service +type Client struct { + cfgFile string + tunAdapter iface.TunAdapter + recorder *peer.Status + ctxCancel context.CancelFunc + ctxCancelLock *sync.Mutex + deviceName string +} + +// NewClient instantiate a new Client +func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter) *Client { + lvl, _ := log.ParseLevel("trace") + log.SetLevel(lvl) + + return &Client{ + cfgFile: cfgFile, + deviceName: deviceName, + tunAdapter: tunAdapter, + recorder: peer.NewRecorder(""), + ctxCancelLock: &sync.Mutex{}, + } +} + +// Run start the internal client. It is a blocker function +func (c *Client) Run(urlOpener URLOpener) error { + cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{ + ConfigPath: c.cfgFile, + }) + if err != nil { + return err + } + c.recorder.UpdateManagementAddress(cfg.ManagementURL.String()) + + var ctx context.Context + //nolint + ctxWithValues := context.WithValue(context.Background(), system.DeviceNameCtxKey, c.deviceName) + c.ctxCancelLock.Lock() + ctx, c.ctxCancel = context.WithCancel(ctxWithValues) + defer c.ctxCancel() + c.ctxCancelLock.Unlock() + + auth := NewAuthWithConfig(ctx, cfg) + err = auth.Login(urlOpener) + if err != nil { + return err + } + + // todo do not throw error in case of cancelled context + ctx = internal.CtxInitState(ctx) + return internal.RunClient(ctx, cfg, c.recorder, c.tunAdapter) +} + +// Stop the internal client and free the resources +func (c *Client) Stop() { + c.ctxCancelLock.Lock() + defer c.ctxCancelLock.Unlock() + if c.ctxCancel == nil { + return + } + + c.ctxCancel() +} + +// PeersList return with the list of the PeerInfos +func (c *Client) PeersList() *PeerInfoArray { + + fullStatus := c.recorder.GetFullStatus() + + peerInfos := make([]PeerInfo, len(fullStatus.Peers)) + for n, p := range fullStatus.Peers { + pi := PeerInfo{ + p.IP, + p.FQDN, + p.ConnStatus.String(), + p.Direct, + } + peerInfos[n] = pi + } + + return &PeerInfoArray{items: peerInfos} +} + +// AddConnectionListener add new network connection listener +func (c *Client) AddConnectionListener(listener ConnectionListener) { + c.recorder.AddConnectionListener(listener) +} + +// RemoveConnectionListener remove connection listener +func (c *Client) RemoveConnectionListener(listener ConnectionListener) { + c.recorder.RemoveConnectionListener(listener) +} diff --git a/client/android/login.go b/client/android/login.go new file mode 100644 index 000000000..e4cb5513d --- /dev/null +++ b/client/android/login.go @@ -0,0 +1,173 @@ +package android + +import ( + "context" + "fmt" + "github.com/cenkalti/backoff/v4" + "github.com/netbirdio/netbird/client/cmd" + "time" + + "github.com/netbirdio/netbird/client/internal" + log "github.com/sirupsen/logrus" + "google.golang.org/grpc/codes" + gstatus "google.golang.org/grpc/status" +) + +// URLOpener it is a callback interface. The Open function will be triggered if +// the backend want to show an url for the user +type URLOpener interface { + Open(string) +} + +// Auth can register or login new client +type Auth struct { + ctx context.Context + config *internal.Config + cfgPath string +} + +// NewAuth instantiate Auth struct and validate the management URL +func NewAuth(cfgPath string, mgmURL string) (*Auth, error) { + inputCfg := internal.ConfigInput{ + ManagementURL: mgmURL, + } + + cfg, err := internal.CreateInMemoryConfig(inputCfg) + if err != nil { + return nil, err + } + + return &Auth{ + ctx: context.Background(), + config: cfg, + cfgPath: cfgPath, + }, nil +} + +// NewAuthWithConfig instantiate Auth based on existing config +func NewAuthWithConfig(ctx context.Context, config *internal.Config) *Auth { + return &Auth{ + ctx: ctx, + config: config, + } +} + +// LoginAndSaveConfigIfSSOSupported test the connectivity with the management server. +// If the SSO is supported than save the configuration. Return with the SSO login is supported or not. +func (a *Auth) LoginAndSaveConfigIfSSOSupported() (bool, error) { + var needsLogin bool + err := a.withBackOff(a.ctx, func() (err error) { + needsLogin, err = internal.IsLoginRequired(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config.SSHKey) + return + }) + if err != nil { + return false, fmt.Errorf("backoff cycle failed: %v", err) + } + if !needsLogin { + return false, nil + } + err = internal.WriteOutConfig(a.cfgPath, a.config) + return needsLogin, err +} + +// LoginWithSetupKeyAndSaveConfig test the connectivity with the management server with the setup key. +func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string) error { + err := a.withBackOff(a.ctx, func() error { + err := internal.Login(a.ctx, a.config, setupKey, "") + if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) { + return nil + } + return err + }) + if err != nil { + return fmt.Errorf("backoff cycle failed: %v", err) + } + + return internal.WriteOutConfig(a.cfgPath, a.config) +} + +// Login try register the client on the server +func (a *Auth) Login(urlOpener URLOpener) error { + var needsLogin bool + + // check if we need to generate JWT token + err := a.withBackOff(a.ctx, func() (err error) { + needsLogin, err = internal.IsLoginRequired(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config.SSHKey) + return + }) + if err != nil { + return fmt.Errorf("backoff cycle failed: %v", err) + } + + jwtToken := "" + if needsLogin { + tokenInfo, err := a.foregroundGetTokenInfo(urlOpener) + if err != nil { + return fmt.Errorf("interactive sso login failed: %v", err) + } + jwtToken = tokenInfo.AccessToken + } + + err = a.withBackOff(a.ctx, func() error { + err := internal.Login(a.ctx, a.config, "", jwtToken) + if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) { + return nil + } + return err + }) + if err != nil { + return fmt.Errorf("backoff cycle failed: %v", err) + } + + return nil +} + +func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*internal.TokenInfo, error) { + providerConfig, err := internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) + if err != nil { + s, ok := gstatus.FromError(err) + if ok && s.Code() == codes.NotFound { + return nil, fmt.Errorf("no SSO provider returned from management. " + + "If you are using hosting Netbird see documentation at " + + "https://github.com/netbirdio/netbird/tree/main/management for details") + } else if ok && s.Code() == codes.Unimplemented { + return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+ + "please update your servver or use Setup Keys to login", a.config.ManagementURL) + } else { + return nil, fmt.Errorf("getting device authorization flow info failed with error: %v", err) + } + } + + hostedClient := internal.NewHostedDeviceFlow( + providerConfig.ProviderConfig.Audience, + providerConfig.ProviderConfig.ClientID, + providerConfig.ProviderConfig.TokenEndpoint, + providerConfig.ProviderConfig.DeviceAuthEndpoint, + ) + + flowInfo, err := hostedClient.RequestDeviceCode(context.TODO()) + if err != nil { + return nil, fmt.Errorf("getting a request device code failed: %v", err) + } + + go urlOpener.Open(flowInfo.VerificationURIComplete) + + waitTimeout := time.Duration(flowInfo.ExpiresIn) + waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout*time.Second) + defer cancel() + tokenInfo, err := hostedClient.WaitToken(waitCTX, flowInfo) + if err != nil { + return nil, fmt.Errorf("waiting for browser login failed: %v", err) + } + + return &tokenInfo, nil +} + +func (a *Auth) withBackOff(ctx context.Context, bf func() error) error { + return backoff.RetryNotify( + bf, + backoff.WithContext(cmd.CLIBackOffSettings, ctx), + func(err error, duration time.Duration) { + log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err) + }) +} diff --git a/client/android/peer_notifier.go b/client/android/peer_notifier.go new file mode 100644 index 000000000..b998fa00e --- /dev/null +++ b/client/android/peer_notifier.go @@ -0,0 +1,37 @@ +package android + +// PeerInfo describe information about the peers. It designed for the UI usage +type PeerInfo struct { + IP string + FQDN string + ConnStatus string // Todo replace to enum + Direct bool +} + +// PeerInfoCollection made for Java layer to get non default types as collection +type PeerInfoCollection interface { + Add(s string) PeerInfoCollection + Get(i int) string + Size() int +} + +// PeerInfoArray is the implementation of the PeerInfoCollection +type PeerInfoArray struct { + items []PeerInfo +} + +// Add new PeerInfo to the collection +func (array PeerInfoArray) Add(s PeerInfo) PeerInfoArray { + array.items = append(array.items, s) + return array +} + +// Get return an element of the collection +func (array PeerInfoArray) Get(i int) *PeerInfo { + return &array.items[i] +} + +// Size return with the size of the collection +func (array PeerInfoArray) Size() int { + return len(array.items) +} diff --git a/client/android/preferences.go b/client/android/preferences.go new file mode 100644 index 000000000..08485eafc --- /dev/null +++ b/client/android/preferences.go @@ -0,0 +1,78 @@ +package android + +import ( + "github.com/netbirdio/netbird/client/internal" +) + +// Preferences export a subset of the internal config for gomobile +type Preferences struct { + configInput internal.ConfigInput +} + +// NewPreferences create new Preferences instance +func NewPreferences(configPath string) *Preferences { + ci := internal.ConfigInput{ + ConfigPath: configPath, + } + return &Preferences{ci} +} + +// GetManagementURL read url from config file +func (p *Preferences) GetManagementURL() (string, error) { + if p.configInput.ManagementURL != "" { + return p.configInput.ManagementURL, nil + } + + cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return "", err + } + return cfg.ManagementURL.String(), err +} + +// SetManagementURL store the given url and wait for commit +func (p *Preferences) SetManagementURL(url string) { + p.configInput.ManagementURL = url +} + +// GetAdminURL read url from config file +func (p *Preferences) GetAdminURL() (string, error) { + if p.configInput.AdminURL != "" { + return p.configInput.AdminURL, nil + } + + cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return "", err + } + return cfg.AdminURL.String(), err +} + +// SetAdminURL store the given url and wait for commit +func (p *Preferences) SetAdminURL(url string) { + p.configInput.AdminURL = url +} + +// GetPreSharedKey read preshared key from config file +func (p *Preferences) GetPreSharedKey() (string, error) { + if p.configInput.PreSharedKey != nil { + return *p.configInput.PreSharedKey, nil + } + + cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return "", err + } + return cfg.PreSharedKey, err +} + +// SetPreSharedKey store the given key and wait for commit +func (p *Preferences) SetPreSharedKey(key string) { + p.configInput.PreSharedKey = &key +} + +// Commit write out the changes into config file +func (p *Preferences) Commit() error { + _, err := internal.UpdateOrCreateConfig(p.configInput) + return err +} diff --git a/client/android/preferences_test.go b/client/android/preferences_test.go new file mode 100644 index 000000000..73c8692d7 --- /dev/null +++ b/client/android/preferences_test.go @@ -0,0 +1,120 @@ +package android + +import ( + "path/filepath" + "testing" + + "github.com/netbirdio/netbird/client/internal" +) + +func TestPreferences_DefaultValues(t *testing.T) { + cfgFile := filepath.Join(t.TempDir(), "netbird.json") + p := NewPreferences(cfgFile) + defaultVar, err := p.GetAdminURL() + if err != nil { + t.Fatalf("failed to read default value: %s", err) + } + + if defaultVar != internal.DefaultAdminURL { + t.Errorf("invalid default admin url: %s", defaultVar) + } + + defaultVar, err = p.GetManagementURL() + if err != nil { + t.Fatalf("failed to read default management URL: %s", err) + } + + if defaultVar != internal.DefaultManagementURL { + t.Errorf("invalid default management url: %s", defaultVar) + } + + var preSharedKey string + preSharedKey, err = p.GetPreSharedKey() + if err != nil { + t.Fatalf("failed to read default preshared key: %s", err) + } + + if preSharedKey != "" { + t.Errorf("invalid preshared key: %s", preSharedKey) + } +} + +func TestPreferences_ReadUncommitedValues(t *testing.T) { + exampleString := "exampleString" + cfgFile := filepath.Join(t.TempDir(), "netbird.json") + p := NewPreferences(cfgFile) + + p.SetAdminURL(exampleString) + resp, err := p.GetAdminURL() + if err != nil { + t.Fatalf("failed to read admin url: %s", err) + } + + if resp != exampleString { + t.Errorf("unexpected admin url: %s", resp) + } + + p.SetManagementURL(exampleString) + resp, err = p.GetManagementURL() + if err != nil { + t.Fatalf("failed to read managmenet url: %s", err) + } + + if resp != exampleString { + t.Errorf("unexpected managemenet url: %s", resp) + } + + p.SetPreSharedKey(exampleString) + resp, err = p.GetPreSharedKey() + if err != nil { + t.Fatalf("failed to read preshared key: %s", err) + } + + if resp != exampleString { + t.Errorf("unexpected preshared key: %s", resp) + } +} + +func TestPreferences_Commit(t *testing.T) { + exampleURL := "https://myurl.com:443" + examplePresharedKey := "topsecret" + cfgFile := filepath.Join(t.TempDir(), "netbird.json") + p := NewPreferences(cfgFile) + + p.SetAdminURL(exampleURL) + p.SetManagementURL(exampleURL) + p.SetPreSharedKey(examplePresharedKey) + + err := p.Commit() + if err != nil { + t.Fatalf("failed to save changes: %s", err) + } + + p = NewPreferences(cfgFile) + resp, err := p.GetAdminURL() + if err != nil { + t.Fatalf("failed to read admin url: %s", err) + } + + if resp != exampleURL { + t.Errorf("unexpected admin url: %s", resp) + } + + resp, err = p.GetManagementURL() + if err != nil { + t.Fatalf("failed to read managmenet url: %s", err) + } + + if resp != exampleURL { + t.Errorf("unexpected managemenet url: %s", resp) + } + + resp, err = p.GetPreSharedKey() + if err != nil { + t.Fatalf("failed to read preshared key: %s", err) + } + + if resp != examplePresharedKey { + t.Errorf("unexpected preshared key: %s", resp) + } +} diff --git a/client/cmd/up.go b/client/cmd/up.go index afe69b68a..5bbdab690 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -94,7 +94,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error { var cancel context.CancelFunc ctx, cancel = context.WithCancel(ctx) SetupCloseHandler(ctx, cancel) - return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String())) + return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()), nil) } func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error { diff --git a/client/internal/config.go b/client/internal/config.go index a3ff09979..8bee88cf0 100644 --- a/client/internal/config.go +++ b/client/internal/config.go @@ -73,6 +73,25 @@ type Config struct { CustomDNSAddress string } +// ReadConfig read config file and return with Config. If it is not exists create a new with default values +func ReadConfig(configPath string) (*Config, error) { + if configFileIsExists(configPath) { + config := &Config{} + if _, err := util.ReadJson(configPath, config); err != nil { + return nil, err + } + return config, nil + } + + cfg, err := createNewConfig(ConfigInput{ConfigPath: configPath}) + if err != nil { + return nil, err + } + + err = WriteOutConfig(configPath, cfg) + return cfg, err +} + // UpdateConfig update existing configuration according to input configuration and return with the configuration func UpdateConfig(input ConfigInput) (*Config, error) { if !configFileIsExists(input.ConfigPath) { @@ -86,7 +105,12 @@ func UpdateConfig(input ConfigInput) (*Config, error) { func UpdateOrCreateConfig(input ConfigInput) (*Config, error) { if !configFileIsExists(input.ConfigPath) { log.Infof("generating new config %s", input.ConfigPath) - return createNewConfig(input) + cfg, err := createNewConfig(input) + if err != nil { + return nil, err + } + err = WriteOutConfig(input.ConfigPath, cfg) + return cfg, err } if isPreSharedKeyHidden(input.PreSharedKey) { @@ -95,6 +119,16 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) { return update(input) } +// CreateInMemoryConfig generate a new config but do not write out it to the store +func CreateInMemoryConfig(input ConfigInput) (*Config, error) { + return createNewConfig(input) +} + +// WriteOutConfig write put the prepared config to the given path +func WriteOutConfig(path string, config *Config) error { + return util.WriteJson(path, config) +} + // createNewConfig creates a new config generating a new Wireguard key and saving to file func createNewConfig(input ConfigInput) (*Config, error) { wgKey := generateKey() @@ -146,12 +180,6 @@ func createNewConfig(input ConfigInput) (*Config, error) { } config.IFaceBlackList = defaultInterfaceBlacklist - - err = util.WriteJson(input.ConfigPath, config) - if err != nil { - return nil, err - } - return config, nil } diff --git a/client/internal/connect.go b/client/internal/connect.go index 4a3d052b7..eeb0e640e 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -22,7 +22,7 @@ import ( ) // RunClient with main logic. -func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error { +func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter) error { backOff := &backoff.ExponentialBackOff{ InitialInterval: time.Second, RandomizationFactor: 1, @@ -60,6 +60,8 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) statusRecorder.MarkManagementDisconnected() + statusRecorder.ClientStart() + defer statusRecorder.ClientStop() operation := func() error { // if context cancelled we not start new backoff cycle select { @@ -144,7 +146,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) peerConfig := loginResp.GetPeerConfig() - engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig) + engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig, tunAdapter) if err != nil { log.Error(err) return wrapErr(err) @@ -191,11 +193,12 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) } // createEngineConfig converts configuration received from Management Service to EngineConfig -func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) { +func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig, tunAdapter iface.TunAdapter) (*EngineConfig, error) { engineConf := &EngineConfig{ WgIfaceName: config.WgIface, WgAddr: peerConfig.Address, + TunAdapter: tunAdapter, IFaceBlackList: config.IFaceBlackList, DisableIPv6Discovery: config.DisableIPv6Discovery, WgPrivateKey: key, diff --git a/client/internal/device_auth.go b/client/internal/device_auth.go index 2fab17188..d2396242b 100644 --- a/client/internal/device_auth.go +++ b/client/internal/device_auth.go @@ -57,6 +57,7 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmU return DeviceAuthorizationFlow{}, err } log.Debugf("connected to the Management service %s", mgmURL.String()) + defer func() { err = mgmClient.Close() if err != nil { diff --git a/client/internal/dns/local.go b/client/internal/dns/local.go index 9a9857b00..e9fcc37eb 100644 --- a/client/internal/dns/local.go +++ b/client/internal/dns/local.go @@ -8,6 +8,8 @@ import ( "sync" ) +type registrationMap map[string]struct{} + type localResolver struct { registeredMap registrationMap records sync.Map diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 7a9d5301e..53006b164 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -1,27 +1,6 @@ package dns -import ( - "context" - "fmt" - "net" - "net/netip" - "runtime" - "sync" - "time" - - "github.com/miekg/dns" - "github.com/mitchellh/hashstructure/v2" - nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/iface" - log "github.com/sirupsen/logrus" -) - -const ( - defaultPort = 53 - customPort = 5053 - defaultIP = "127.0.0.1" - customIP = "127.0.0.153" -) +import nbdns "github.com/netbirdio/netbird/dns" // Server is a dns server interface type Server interface { @@ -29,444 +8,3 @@ type Server interface { Stop() UpdateDNSServer(serial uint64, update nbdns.Config) error } - -// DefaultServer dns server object -type DefaultServer struct { - ctx context.Context - ctxCancel context.CancelFunc - upstreamCtxCancel context.CancelFunc - mux sync.Mutex - server *dns.Server - dnsMux *dns.ServeMux - dnsMuxMap registrationMap - localResolver *localResolver - wgInterface *iface.WGIface - hostManager hostManager - updateSerial uint64 - listenerIsRunning bool - runtimePort int - runtimeIP string - previousConfigHash uint64 - currentConfig hostDNSConfig - customAddress *netip.AddrPort -} - -type registrationMap map[string]struct{} - -type muxUpdate struct { - domain string - handler dns.Handler -} - -// NewDefaultServer returns a new dns server -func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string) (*DefaultServer, error) { - mux := dns.NewServeMux() - - dnsServer := &dns.Server{ - Net: "udp", - Handler: mux, - UDPSize: 65535, - } - - ctx, stop := context.WithCancel(ctx) - - var addrPort *netip.AddrPort - if customAddress != "" { - parsedAddrPort, err := netip.ParseAddrPort(customAddress) - if err != nil { - stop() - return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err) - } - addrPort = &parsedAddrPort - } - - defaultServer := &DefaultServer{ - ctx: ctx, - ctxCancel: stop, - server: dnsServer, - dnsMux: mux, - dnsMuxMap: make(registrationMap), - localResolver: &localResolver{ - registeredMap: make(registrationMap), - }, - wgInterface: wgInterface, - runtimePort: defaultPort, - customAddress: addrPort, - } - - hostmanager, err := newHostManager(wgInterface) - if err != nil { - stop() - return nil, err - } - defaultServer.hostManager = hostmanager - return defaultServer, err -} - -// Start runs the listener in a go routine -func (s *DefaultServer) Start() { - if s.customAddress != nil { - s.runtimeIP = s.customAddress.Addr().String() - s.runtimePort = int(s.customAddress.Port()) - } else { - ip, port, err := s.getFirstListenerAvailable() - if err != nil { - log.Error(err) - return - } - s.runtimeIP = ip - s.runtimePort = port - } - - s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort) - - log.Debugf("starting dns on %s", s.server.Addr) - - go func() { - s.setListenerStatus(true) - defer s.setListenerStatus(false) - - err := s.server.ListenAndServe() - if err != nil { - log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err) - } - }() -} - -func (s *DefaultServer) getFirstListenerAvailable() (string, int, error) { - ips := []string{defaultIP, customIP} - if runtime.GOOS != "darwin" && s.wgInterface != nil { - ips = append([]string{s.wgInterface.Address().IP.String()}, ips...) - } - ports := []int{defaultPort, customPort} - for _, port := range ports { - for _, ip := range ips { - addrString := fmt.Sprintf("%s:%d", ip, port) - udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString)) - probeListener, err := net.ListenUDP("udp", udpAddr) - if err == nil { - err = probeListener.Close() - if err != nil { - log.Errorf("got an error closing the probe listener, error: %s", err) - } - return ip, port, nil - } - log.Warnf("binding dns on %s is not available, error: %s", addrString, err) - } - } - return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports) -} - -func (s *DefaultServer) setListenerStatus(running bool) { - s.listenerIsRunning = running -} - -// Stop stops the server -func (s *DefaultServer) Stop() { - s.mux.Lock() - defer s.mux.Unlock() - s.ctxCancel() - - err := s.hostManager.restoreHostDNS() - if err != nil { - log.Error(err) - } - - err = s.stopListener() - if err != nil { - log.Error(err) - } -} - -func (s *DefaultServer) stopListener() error { - if !s.listenerIsRunning { - return nil - } - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - err := s.server.ShutdownContext(ctx) - if err != nil { - return fmt.Errorf("stopping dns server listener returned an error: %v", err) - } - return nil -} - -// UpdateDNSServer processes an update received from the management service -func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error { - select { - case <-s.ctx.Done(): - log.Infof("not updating DNS server as context is closed") - return s.ctx.Err() - default: - if serial < s.updateSerial { - return fmt.Errorf("not applying dns update, error: "+ - "network update is %d behind the last applied update", s.updateSerial-serial) - } - s.mux.Lock() - defer s.mux.Unlock() - - hash, err := hashstructure.Hash(update, hashstructure.FormatV2, &hashstructure.HashOptions{ - ZeroNil: true, - IgnoreZeroValue: true, - SlicesAsSets: true, - UseStringer: true, - }) - if err != nil { - log.Errorf("unable to hash the dns configuration update, got error: %s", err) - } - - if s.previousConfigHash == hash { - log.Debugf("not applying the dns configuration update as there is nothing new") - s.updateSerial = serial - return nil - } - - if err := s.applyConfiguration(update); err != nil { - return err - } - - s.updateSerial = serial - s.previousConfigHash = hash - - return nil - } -} - -func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { - // is the service should be disabled, we stop the listener - // and proceed with a regular update to clean up the handlers and records - if !update.ServiceEnable { - err := s.stopListener() - if err != nil { - log.Error(err) - } - } else if !s.listenerIsRunning { - s.Start() - } - - localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones) - if err != nil { - return fmt.Errorf("not applying dns update, error: %v", err) - } - upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups) - if err != nil { - return fmt.Errorf("not applying dns update, error: %v", err) - } - - muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...) - - s.updateMux(muxUpdates) - s.updateLocalResolver(localRecords) - s.currentConfig = dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort) - - if err = s.hostManager.applyDNSConfig(s.currentConfig); err != nil { - log.Error(err) - } - - return nil -} - -func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) { - var muxUpdates []muxUpdate - localRecords := make(map[string]nbdns.SimpleRecord, 0) - - for _, customZone := range customZones { - - if len(customZone.Records) == 0 { - return nil, nil, fmt.Errorf("received an empty list of records") - } - - muxUpdates = append(muxUpdates, muxUpdate{ - domain: customZone.Domain, - handler: s.localResolver, - }) - - for _, record := range customZone.Records { - var class uint16 = dns.ClassINET - if record.Class != nbdns.DefaultClass { - return nil, nil, fmt.Errorf("received an invalid class type: %s", record.Class) - } - key := buildRecordKey(record.Name, class, uint16(record.Type)) - localRecords[key] = record - } - } - return muxUpdates, localRecords, nil -} - -func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) { - // clean up the previous upstream resolver - if s.upstreamCtxCancel != nil { - s.upstreamCtxCancel() - } - - var muxUpdates []muxUpdate - for _, nsGroup := range nameServerGroups { - if len(nsGroup.NameServers) == 0 { - log.Warn("received a nameserver group with empty nameserver list") - continue - } - - var ctx context.Context - ctx, s.upstreamCtxCancel = context.WithCancel(s.ctx) - - handler := newUpstreamResolver(ctx) - for _, ns := range nsGroup.NameServers { - if ns.NSType != nbdns.UDPNameServerType { - log.Warnf("skiping nameserver %s with type %s, this peer supports only %s", - ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String()) - continue - } - handler.upstreamServers = append(handler.upstreamServers, getNSHostPort(ns)) - } - - if len(handler.upstreamServers) == 0 { - log.Errorf("received a nameserver group with an invalid nameserver list") - continue - } - - // when upstream fails to resolve domain several times over all it servers - // it will calls this hook to exclude self from the configuration and - // reapply DNS settings, but it not touch the original configuration and serial number - // because it is temporal deactivation until next try - // - // after some period defined by upstream it trys to reactivate self by calling this hook - // everything we need here is just to re-apply current configuration because it already - // contains this upstream settings (temporal deactivation not removed it) - handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler) - - if nsGroup.Primary { - muxUpdates = append(muxUpdates, muxUpdate{ - domain: nbdns.RootZone, - handler: handler, - }) - continue - } - - if len(nsGroup.Domains) == 0 { - return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list") - } - - for _, domain := range nsGroup.Domains { - if domain == "" { - return nil, fmt.Errorf("received a nameserver group with an empty domain element") - } - muxUpdates = append(muxUpdates, muxUpdate{ - domain: domain, - handler: handler, - }) - } - } - return muxUpdates, nil -} - -func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) { - muxUpdateMap := make(registrationMap) - - for _, update := range muxUpdates { - s.registerMux(update.domain, update.handler) - muxUpdateMap[update.domain] = struct{}{} - } - - for key := range s.dnsMuxMap { - _, found := muxUpdateMap[key] - if !found { - s.deregisterMux(key) - } - } - - s.dnsMuxMap = muxUpdateMap -} - -func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) { - for key := range s.localResolver.registeredMap { - _, found := update[key] - if !found { - s.localResolver.deleteRecord(key) - } - } - - updatedMap := make(registrationMap) - for key, record := range update { - err := s.localResolver.registerRecord(record) - if err != nil { - log.Warnf("got an error while registering the record (%s), error: %v", record.String(), err) - } - updatedMap[key] = struct{}{} - } - - s.localResolver.registeredMap = updatedMap -} - -func getNSHostPort(ns nbdns.NameServer) string { - return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port) -} - -func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) { - s.dnsMux.Handle(pattern, handler) -} - -func (s *DefaultServer) deregisterMux(pattern string) { - s.dnsMux.HandleRemove(pattern) -} - -// upstreamCallbacks returns two functions, the first one is used to deactivate -// the upstream resolver from the configuration, the second one is used to -// reactivate it. Not allowed to call reactivate before deactivate. -func (s *DefaultServer) upstreamCallbacks( - nsGroup *nbdns.NameServerGroup, - handler dns.Handler, -) (deactivate func(), reactivate func()) { - var removeIndex map[string]int - deactivate = func() { - s.mux.Lock() - defer s.mux.Unlock() - - l := log.WithField("nameservers", nsGroup.NameServers) - l.Info("temporary deactivate nameservers group due timeout") - - removeIndex = make(map[string]int) - for _, domain := range nsGroup.Domains { - removeIndex[domain] = -1 - } - if nsGroup.Primary { - removeIndex[nbdns.RootZone] = -1 - s.currentConfig.routeAll = false - } - - for i, item := range s.currentConfig.domains { - if _, found := removeIndex[item.domain]; found { - s.currentConfig.domains[i].disabled = true - s.deregisterMux(item.domain) - removeIndex[item.domain] = i - } - } - if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { - l.WithError(err).Error("fail to apply nameserver deactivation on the host") - } - } - reactivate = func() { - s.mux.Lock() - defer s.mux.Unlock() - - for domain, i := range removeIndex { - if i == -1 || i >= len(s.currentConfig.domains) || s.currentConfig.domains[i].domain != domain { - continue - } - s.currentConfig.domains[i].disabled = false - s.registerMux(domain, handler) - } - - l := log.WithField("nameservers", nsGroup.NameServers) - l.Debug("reactivate temporary disabled nameserver group") - - if nsGroup.Primary { - s.currentConfig.routeAll = true - } - if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { - l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply") - } - } - return -} diff --git a/client/internal/dns/server_android.go b/client/internal/dns/server_android.go new file mode 100644 index 000000000..dddbc65a2 --- /dev/null +++ b/client/internal/dns/server_android.go @@ -0,0 +1,32 @@ +package dns + +import ( + "context" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/iface" +) + +// DefaultServer dummy dns server +type DefaultServer struct { +} + +// NewDefaultServer On Android the DNS feature is not supported yet +func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string) (*DefaultServer, error) { + return &DefaultServer{}, nil +} + +// Start dummy implementation +func (s DefaultServer) Start() { + +} + +// Stop dummy implementation +func (s DefaultServer) Stop() { + +} + +// UpdateDNSServer dummy implementation +func (s DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error { + return nil +} diff --git a/client/internal/dns/server_nonandroid.go b/client/internal/dns/server_nonandroid.go new file mode 100644 index 000000000..55ba28f01 --- /dev/null +++ b/client/internal/dns/server_nonandroid.go @@ -0,0 +1,465 @@ +//go:build !android + +package dns + +import ( + "context" + "fmt" + "net" + "net/netip" + "runtime" + "sync" + "time" + + "github.com/miekg/dns" + "github.com/mitchellh/hashstructure/v2" + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/iface" + log "github.com/sirupsen/logrus" +) + +const ( + defaultPort = 53 + customPort = 5053 + defaultIP = "127.0.0.1" + customIP = "127.0.0.153" +) + +// DefaultServer dns server object +type DefaultServer struct { + ctx context.Context + ctxCancel context.CancelFunc + upstreamCtxCancel context.CancelFunc + mux sync.Mutex + server *dns.Server + dnsMux *dns.ServeMux + dnsMuxMap registrationMap + localResolver *localResolver + wgInterface *iface.WGIface + hostManager hostManager + updateSerial uint64 + listenerIsRunning bool + runtimePort int + runtimeIP string + previousConfigHash uint64 + currentConfig hostDNSConfig + customAddress *netip.AddrPort +} + +type muxUpdate struct { + domain string + handler dns.Handler +} + +// NewDefaultServer returns a new dns server +func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string) (*DefaultServer, error) { + mux := dns.NewServeMux() + + dnsServer := &dns.Server{ + Net: "udp", + Handler: mux, + UDPSize: 65535, + } + + ctx, stop := context.WithCancel(ctx) + + var addrPort *netip.AddrPort + if customAddress != "" { + parsedAddrPort, err := netip.ParseAddrPort(customAddress) + if err != nil { + stop() + return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err) + } + addrPort = &parsedAddrPort + } + + defaultServer := &DefaultServer{ + ctx: ctx, + ctxCancel: stop, + server: dnsServer, + dnsMux: mux, + dnsMuxMap: make(registrationMap), + localResolver: &localResolver{ + registeredMap: make(registrationMap), + }, + wgInterface: wgInterface, + runtimePort: defaultPort, + customAddress: addrPort, + } + + hostmanager, err := newHostManager(wgInterface) + if err != nil { + stop() + return nil, err + } + defaultServer.hostManager = hostmanager + return defaultServer, err +} + +// Start runs the listener in a go routine +func (s *DefaultServer) Start() { + if s.customAddress != nil { + s.runtimeIP = s.customAddress.Addr().String() + s.runtimePort = int(s.customAddress.Port()) + } else { + ip, port, err := s.getFirstListenerAvailable() + if err != nil { + log.Error(err) + return + } + s.runtimeIP = ip + s.runtimePort = port + } + + s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort) + + log.Debugf("starting dns on %s", s.server.Addr) + + go func() { + s.setListenerStatus(true) + defer s.setListenerStatus(false) + + err := s.server.ListenAndServe() + if err != nil { + log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err) + } + }() +} + +func (s *DefaultServer) getFirstListenerAvailable() (string, int, error) { + ips := []string{defaultIP, customIP} + if runtime.GOOS != "darwin" && s.wgInterface != nil { + ips = append([]string{s.wgInterface.Address().IP.String()}, ips...) + } + ports := []int{defaultPort, customPort} + for _, port := range ports { + for _, ip := range ips { + addrString := fmt.Sprintf("%s:%d", ip, port) + udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString)) + probeListener, err := net.ListenUDP("udp", udpAddr) + if err == nil { + err = probeListener.Close() + if err != nil { + log.Errorf("got an error closing the probe listener, error: %s", err) + } + return ip, port, nil + } + log.Warnf("binding dns on %s is not available, error: %s", addrString, err) + } + } + return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports) +} + +func (s *DefaultServer) setListenerStatus(running bool) { + s.listenerIsRunning = running +} + +// Stop stops the server +func (s *DefaultServer) Stop() { + s.mux.Lock() + defer s.mux.Unlock() + s.ctxCancel() + + err := s.hostManager.restoreHostDNS() + if err != nil { + log.Error(err) + } + + err = s.stopListener() + if err != nil { + log.Error(err) + } +} + +func (s *DefaultServer) stopListener() error { + if !s.listenerIsRunning { + return nil + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err := s.server.ShutdownContext(ctx) + if err != nil { + return fmt.Errorf("stopping dns server listener returned an error: %v", err) + } + return nil +} + +// UpdateDNSServer processes an update received from the management service +func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error { + select { + case <-s.ctx.Done(): + log.Infof("not updating DNS server as context is closed") + return s.ctx.Err() + default: + if serial < s.updateSerial { + return fmt.Errorf("not applying dns update, error: "+ + "network update is %d behind the last applied update", s.updateSerial-serial) + } + s.mux.Lock() + defer s.mux.Unlock() + + hash, err := hashstructure.Hash(update, hashstructure.FormatV2, &hashstructure.HashOptions{ + ZeroNil: true, + IgnoreZeroValue: true, + SlicesAsSets: true, + UseStringer: true, + }) + if err != nil { + log.Errorf("unable to hash the dns configuration update, got error: %s", err) + } + + if s.previousConfigHash == hash { + log.Debugf("not applying the dns configuration update as there is nothing new") + s.updateSerial = serial + return nil + } + + if err := s.applyConfiguration(update); err != nil { + return err + } + + s.updateSerial = serial + s.previousConfigHash = hash + + return nil + } +} + +func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { + // is the service should be disabled, we stop the listener + // and proceed with a regular update to clean up the handlers and records + if !update.ServiceEnable { + err := s.stopListener() + if err != nil { + log.Error(err) + } + } else if !s.listenerIsRunning { + s.Start() + } + + localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones) + if err != nil { + return fmt.Errorf("not applying dns update, error: %v", err) + } + upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups) + if err != nil { + return fmt.Errorf("not applying dns update, error: %v", err) + } + + muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...) + + s.updateMux(muxUpdates) + s.updateLocalResolver(localRecords) + s.currentConfig = dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort) + + if err = s.hostManager.applyDNSConfig(s.currentConfig); err != nil { + log.Error(err) + } + + return nil +} + +func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) { + var muxUpdates []muxUpdate + localRecords := make(map[string]nbdns.SimpleRecord, 0) + + for _, customZone := range customZones { + + if len(customZone.Records) == 0 { + return nil, nil, fmt.Errorf("received an empty list of records") + } + + muxUpdates = append(muxUpdates, muxUpdate{ + domain: customZone.Domain, + handler: s.localResolver, + }) + + for _, record := range customZone.Records { + var class uint16 = dns.ClassINET + if record.Class != nbdns.DefaultClass { + return nil, nil, fmt.Errorf("received an invalid class type: %s", record.Class) + } + key := buildRecordKey(record.Name, class, uint16(record.Type)) + localRecords[key] = record + } + } + return muxUpdates, localRecords, nil +} + +func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) { + // clean up the previous upstream resolver + if s.upstreamCtxCancel != nil { + s.upstreamCtxCancel() + } + + var muxUpdates []muxUpdate + for _, nsGroup := range nameServerGroups { + if len(nsGroup.NameServers) == 0 { + log.Warn("received a nameserver group with empty nameserver list") + continue + } + + var ctx context.Context + ctx, s.upstreamCtxCancel = context.WithCancel(s.ctx) + + handler := newUpstreamResolver(ctx) + for _, ns := range nsGroup.NameServers { + if ns.NSType != nbdns.UDPNameServerType { + log.Warnf("skiping nameserver %s with type %s, this peer supports only %s", + ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String()) + continue + } + handler.upstreamServers = append(handler.upstreamServers, getNSHostPort(ns)) + } + + if len(handler.upstreamServers) == 0 { + log.Errorf("received a nameserver group with an invalid nameserver list") + continue + } + + // when upstream fails to resolve domain several times over all it servers + // it will calls this hook to exclude self from the configuration and + // reapply DNS settings, but it not touch the original configuration and serial number + // because it is temporal deactivation until next try + // + // after some period defined by upstream it trys to reactivate self by calling this hook + // everything we need here is just to re-apply current configuration because it already + // contains this upstream settings (temporal deactivation not removed it) + handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler) + + if nsGroup.Primary { + muxUpdates = append(muxUpdates, muxUpdate{ + domain: nbdns.RootZone, + handler: handler, + }) + continue + } + + if len(nsGroup.Domains) == 0 { + return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list") + } + + for _, domain := range nsGroup.Domains { + if domain == "" { + return nil, fmt.Errorf("received a nameserver group with an empty domain element") + } + muxUpdates = append(muxUpdates, muxUpdate{ + domain: domain, + handler: handler, + }) + } + } + return muxUpdates, nil +} + +func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) { + muxUpdateMap := make(registrationMap) + + for _, update := range muxUpdates { + s.registerMux(update.domain, update.handler) + muxUpdateMap[update.domain] = struct{}{} + } + + for key := range s.dnsMuxMap { + _, found := muxUpdateMap[key] + if !found { + s.deregisterMux(key) + } + } + + s.dnsMuxMap = muxUpdateMap +} + +func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) { + for key := range s.localResolver.registeredMap { + _, found := update[key] + if !found { + s.localResolver.deleteRecord(key) + } + } + + updatedMap := make(registrationMap) + for key, record := range update { + err := s.localResolver.registerRecord(record) + if err != nil { + log.Warnf("got an error while registering the record (%s), error: %v", record.String(), err) + } + updatedMap[key] = struct{}{} + } + + s.localResolver.registeredMap = updatedMap +} + +func getNSHostPort(ns nbdns.NameServer) string { + return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port) +} + +func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) { + s.dnsMux.Handle(pattern, handler) +} + +func (s *DefaultServer) deregisterMux(pattern string) { + s.dnsMux.HandleRemove(pattern) +} + +// upstreamCallbacks returns two functions, the first one is used to deactivate +// the upstream resolver from the configuration, the second one is used to +// reactivate it. Not allowed to call reactivate before deactivate. +func (s *DefaultServer) upstreamCallbacks( + nsGroup *nbdns.NameServerGroup, + handler dns.Handler, +) (deactivate func(), reactivate func()) { + var removeIndex map[string]int + deactivate = func() { + s.mux.Lock() + defer s.mux.Unlock() + + l := log.WithField("nameservers", nsGroup.NameServers) + l.Info("temporary deactivate nameservers group due timeout") + + removeIndex = make(map[string]int) + for _, domain := range nsGroup.Domains { + removeIndex[domain] = -1 + } + if nsGroup.Primary { + removeIndex[nbdns.RootZone] = -1 + s.currentConfig.routeAll = false + } + + for i, item := range s.currentConfig.domains { + if _, found := removeIndex[item.domain]; found { + s.currentConfig.domains[i].disabled = true + s.deregisterMux(item.domain) + removeIndex[item.domain] = i + } + } + if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { + l.WithError(err).Error("fail to apply nameserver deactivation on the host") + } + } + reactivate = func() { + s.mux.Lock() + defer s.mux.Unlock() + + for domain, i := range removeIndex { + if i == -1 || i >= len(s.currentConfig.domains) || s.currentConfig.domains[i].domain != domain { + continue + } + s.currentConfig.domains[i].disabled = false + s.registerMux(domain, handler) + } + + l := log.WithField("nameservers", nsGroup.NameServers) + l.Debug("reactivate temporary disabled nameserver group") + + if nsGroup.Primary { + s.currentConfig.routeAll = true + } + if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { + l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply") + } + } + return +} diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 395652733..208007236 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -199,7 +199,7 @@ func TestUpdateDNSServer(t *testing.T) { for n, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU) + wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU, nil) if err != nil { t.Fatal(err) } diff --git a/client/internal/engine.go b/client/internal/engine.go index bc90244bf..10d74d931 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -46,6 +46,8 @@ var ErrResetConnection = fmt.Errorf("reset connection") type EngineConfig struct { WgPort int WgIfaceName string + // TunAdapter is option. It is necessary for mobile version. + TunAdapter iface.TunAdapter // WgAddr is a Wireguard local address (Netbird Network IP) WgAddr string @@ -173,7 +175,7 @@ func (e *Engine) Start() error { myPrivateKey := e.config.WgPrivateKey var err error - e.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU) + e.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, e.config.TunAdapter) if err != nil { log.Errorf("failed creating wireguard interface instance %s: [%s]", wgIfaceName, err.Error()) return err @@ -614,6 +616,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { if protoDNSConfig == nil { protoDNSConfig = &mgmProto.DNSConfig{} } + err = e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig)) if err != nil { log.Errorf("failed to update dns server, err: %v", err) diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 2e57a16c2..f46fcf78e 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -207,7 +207,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { WgPrivateKey: key, WgPort: 33100, }, peer.NewRecorder("https://mgm")) - engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU) + engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU, nil) engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder) engine.dnsServer = &dns.MockServer{ UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, @@ -549,7 +549,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { WgPrivateKey: key, WgPort: 33100, }, peer.NewRecorder("https://mgm")) - engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU) + engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil) assert.NoError(t, err, "shouldn't return error") input := struct { inputSerial uint64 @@ -714,7 +714,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { WgPrivateKey: key, WgPort: 33100, }, peer.NewRecorder("https://mgm")) - engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU) + engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil) assert.NoError(t, err, "shouldn't return error") mockRouteManager := &routemanager.MockManager{ diff --git a/client/internal/login.go b/client/internal/login.go index 677ff1edc..08efd5147 100644 --- a/client/internal/login.go +++ b/client/internal/login.go @@ -2,37 +2,26 @@ package internal import ( "context" + "net/url" + "github.com/google/uuid" - "github.com/netbirdio/netbird/client/ssh" - "github.com/netbirdio/netbird/client/system" - mgm "github.com/netbirdio/netbird/management/client" - mgmProto "github.com/netbirdio/netbird/management/proto" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + + "github.com/netbirdio/netbird/client/ssh" + "github.com/netbirdio/netbird/client/system" + mgm "github.com/netbirdio/netbird/management/client" + mgmProto "github.com/netbirdio/netbird/management/proto" ) -func Login(ctx context.Context, config *Config, setupKey string, jwtToken string) error { - // validate our peer's Wireguard PRIVATE key - myPrivateKey, err := wgtypes.ParseKey(config.PrivateKey) +// IsLoginRequired check that the server is support SSO or not +func IsLoginRequired(ctx context.Context, privateKey string, mgmURL *url.URL, sshKey string) (bool, error) { + mgmClient, err := getMgmClient(ctx, privateKey, mgmURL) if err != nil { - log.Errorf("failed parsing Wireguard key %s: [%s]", config.PrivateKey, err.Error()) - return err + return false, err } - - var mgmTlsEnabled bool - if config.ManagementURL.Scheme == "https" { - mgmTlsEnabled = true - } - - log.Debugf("connecting to the Management service %s", config.ManagementURL.String()) - mgmClient, err := mgm.NewClient(ctx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled) - if err != nil { - log.Errorf("failed connecting to the Management service %s %v", config.ManagementURL.String(), err) - return err - } - log.Debugf("connected to the Management service %s", config.ManagementURL.String()) defer func() { err = mgmClient.Close() if err != nil { @@ -42,40 +31,84 @@ func Login(ctx context.Context, config *Config, setupKey string, jwtToken string } } }() + log.Debugf("connected to the Management service %s", mgmURL.String()) - serverKey, err := mgmClient.GetServerPublicKey() + pubSSHKey, err := ssh.GeneratePublicKey([]byte(sshKey)) + if err != nil { + return false, err + } + + _, err = doMgmLogin(ctx, mgmClient, pubSSHKey) + if isLoginNeeded(err) { + return true, nil + } + return false, err +} + +// Login or register the client +func Login(ctx context.Context, config *Config, setupKey string, jwtToken string) error { + mgmClient, err := getMgmClient(ctx, config.PrivateKey, config.ManagementURL) if err != nil { - log.Errorf("failed while getting Management Service public key: %v", err) return err } + defer func() { + err = mgmClient.Close() + if err != nil { + cStatus, ok := status.FromError(err) + if !ok || ok && cStatus.Code() != codes.Canceled { + log.Warnf("failed to close the Management service client, err: %v", err) + } + } + }() + log.Debugf("connected to the Management service %s", config.ManagementURL.String()) pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey)) if err != nil { return err } - _, err = loginPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey) - if err != nil { - log.Errorf("failed logging-in peer on Management Service : %v", err) + + serverKey, err := doMgmLogin(ctx, mgmClient, pubSSHKey) + if isRegistrationNeeded(err) { + log.Debugf("peer registration required") + _, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey) return err } - log.Infof("peer has successfully logged-in to the Management service %s", config.ManagementURL.String()) - return nil + + return err } -// loginPeer attempts to login to Management Service. If peer wasn't registered, tries the registration flow. -func loginPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) { - sysInfo := system.GetInfo(ctx) - loginResp, err := client.Login(serverPublicKey, sysInfo, pubSSHKey) +func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) { + // validate our peer's Wireguard PRIVATE key + myPrivateKey, err := wgtypes.ParseKey(privateKey) if err != nil { - if s, ok := status.FromError(err); ok && s.Code() == codes.PermissionDenied { - log.Debugf("peer registration required") - return registerPeer(ctx, serverPublicKey, client, setupKey, jwtToken, pubSSHKey) - } else { - return nil, err - } + log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error()) + return nil, err } - return loginResp, nil + var mgmTlsEnabled bool + if mgmURL.Scheme == "https" { + mgmTlsEnabled = true + } + + log.Debugf("connecting to the Management service %s", mgmURL.String()) + mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTlsEnabled) + if err != nil { + log.Errorf("failed connecting to the Management service %s %v", mgmURL.String(), err) + return nil, err + } + return mgmClient, err +} + +func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte) (*wgtypes.Key, error) { + serverKey, err := mgmClient.GetServerPublicKey() + if err != nil { + log.Errorf("failed while getting Management Service public key: %v", err) + return nil, err + } + + sysInfo := system.GetInfo(ctx) + _, err = mgmClient.Login(*serverKey, sysInfo, pubSSHKey) + return serverKey, err } // registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key. @@ -98,3 +131,31 @@ func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm. return loginResp, nil } + +func isLoginNeeded(err error) bool { + if err == nil { + return false + } + s, ok := status.FromError(err) + if !ok { + return false + } + if s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied { + return true + } + return false +} + +func isRegistrationNeeded(err error) bool { + if err == nil { + return false + } + s, ok := status.FromError(err) + if !ok { + return false + } + if s.Code() == codes.PermissionDenied { + return true + } + return false +} diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index c5d43b34c..e42e6305d 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -9,6 +9,7 @@ import ( "time" "github.com/pion/ice/v2" + "github.com/pion/transport/v2/stdnet" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl" @@ -161,7 +162,10 @@ func (conn *Conn) reCreateAgent() error { defer conn.mu.Unlock() failedTimeout := 6 * time.Second - var err error + transportNet, err := stdnet.NewNet() + if err != nil { + log.Warnf("failed to create pion's stdnet: %s", err) + } agentConfig := &ice.AgentConfig{ MulticastDNSMode: ice.MulticastDNSModeDisabled, NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}, @@ -172,6 +176,7 @@ func (conn *Conn) reCreateAgent() error { UDPMux: conn.config.UDPMux, UDPMuxSrflx: conn.config.UDPMuxSrflx, NAT1To1IPs: conn.config.NATExternalIPs, + Net: transportNet, } if conn.config.DisableIPv6Discovery { diff --git a/client/internal/peer/listener.go b/client/internal/peer/listener.go new file mode 100644 index 000000000..9324c6773 --- /dev/null +++ b/client/internal/peer/listener.go @@ -0,0 +1,9 @@ +package peer + +// Listener is a callback type about the NetBird network connection state +type Listener interface { + OnConnected() + OnDisconnected() + OnConnecting() + OnPeersListChanged(int) +} diff --git a/client/internal/peer/notifier.go b/client/internal/peer/notifier.go new file mode 100644 index 000000000..db1c32e97 --- /dev/null +++ b/client/internal/peer/notifier.go @@ -0,0 +1,124 @@ +package peer + +import ( + "sync" +) + +const ( + stateDisconnected = iota + stateConnected + stateConnecting +) + +type notifier struct { + serverStateLock sync.Mutex + listenersLock sync.Mutex + listeners map[Listener]struct{} + currentServerState bool + currentClientState bool + lastNotification int +} + +func newNotifier() *notifier { + return ¬ifier{ + listeners: make(map[Listener]struct{}), + } +} + +func (n *notifier) addListener(listener Listener) { + n.listenersLock.Lock() + defer n.listenersLock.Unlock() + + n.serverStateLock.Lock() + go n.notifyListener(listener, n.lastNotification) + n.serverStateLock.Unlock() + n.listeners[listener] = struct{}{} +} + +func (n *notifier) removeListener(listener Listener) { + n.listenersLock.Lock() + defer n.listenersLock.Unlock() + delete(n.listeners, listener) +} + +func (n *notifier) updateServerStates(mgmState bool, signalState bool) { + n.serverStateLock.Lock() + defer n.serverStateLock.Unlock() + + var newState bool + if mgmState && signalState { + newState = true + } else { + newState = false + } + + if !n.isServerStateChanged(newState) { + return + } + + n.currentServerState = newState + n.lastNotification = n.calculateState(newState, n.currentClientState) + + go n.notifyAll(n.lastNotification) +} + +func (n *notifier) clientStart() { + n.serverStateLock.Lock() + defer n.serverStateLock.Unlock() + n.currentClientState = true + n.lastNotification = n.calculateState(n.currentServerState, true) + go n.notifyAll(n.lastNotification) +} + +func (n *notifier) clientStop() { + n.serverStateLock.Lock() + defer n.serverStateLock.Unlock() + n.currentClientState = false + n.lastNotification = n.calculateState(n.currentServerState, false) + go n.notifyAll(n.lastNotification) +} + +func (n *notifier) isServerStateChanged(newState bool) bool { + return n.currentServerState != newState +} + +func (n *notifier) notifyAll(state int) { + n.listenersLock.Lock() + defer n.listenersLock.Unlock() + + for l := range n.listeners { + n.notifyListener(l, state) + } +} + +func (n *notifier) notifyListener(l Listener, state int) { + switch state { + case stateDisconnected: + l.OnDisconnected() + case stateConnected: + l.OnConnected() + case stateConnecting: + l.OnConnecting() + } +} + +func (n *notifier) calculateState(serverState bool, clientState bool) int { + if serverState && clientState { + return stateConnected + } + + if !clientState { + return stateDisconnected + } + + return stateConnecting +} + +func (n *notifier) peerListChanged(numOfPeers int) { + n.listenersLock.Lock() + defer n.listenersLock.Unlock() + + for l := range n.listeners { + l.OnPeersListChanged(numOfPeers) + } +} diff --git a/client/internal/peer/notifier_test.go b/client/internal/peer/notifier_test.go new file mode 100644 index 000000000..f21193e06 --- /dev/null +++ b/client/internal/peer/notifier_test.go @@ -0,0 +1,32 @@ +package peer + +import ( + "testing" +) + +func Test_notifier_serverState(t *testing.T) { + + type scenario struct { + name string + expected bool + mgmState bool + signalState bool + } + scenarios := []scenario{ + {"connected", true, true, true}, + {"mgm down", false, false, true}, + {"signal down", false, true, false}, + {"disconnected", false, false, false}, + } + + for _, tt := range scenarios { + t.Run(tt.name, func(t *testing.T) { + n := newNotifier() + n.updateServerStates(tt.mgmState, tt.signalState) + if n.currentServerState != tt.expected { + t.Errorf("invalid serverstate: %t, expected: %t", n.currentServerState, tt.expected) + } + + }) + } +} diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index 6515cceb2..b0a3f338e 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -58,6 +58,7 @@ type Status struct { offlinePeers []State mgmAddress string signalAddress string + notifier *notifier } // NewRecorder returns a new Status instance @@ -66,6 +67,7 @@ func NewRecorder(mgmAddress string) *Status { peers: make(map[string]State), changeNotify: make(map[string]chan struct{}), offlinePeers: make([]State, 0), + notifier: newNotifier(), mgmAddress: mgmAddress, } } @@ -114,6 +116,7 @@ func (d *Status) RemovePeer(peerPubKey string) error { return nil } + d.notifyPeerListChanged() return errors.New("no peer with to remove") } @@ -148,6 +151,7 @@ func (d *Status) UpdatePeerState(receivedState State) error { d.changeNotify[receivedState.PubKey] = nil } + d.notifyPeerListChanged() return nil } @@ -164,6 +168,7 @@ func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error { peerState.FQDN = fqdn d.peers[peerPubKey] = peerState + d.notifyPeerListChanged() return nil } @@ -199,6 +204,8 @@ func (d *Status) CleanLocalPeerState() { func (d *Status) MarkManagementDisconnected() { d.mux.Lock() defer d.mux.Unlock() + defer d.onConnectionChanged() + d.managementState = false } @@ -206,7 +213,9 @@ func (d *Status) MarkManagementDisconnected() { func (d *Status) MarkManagementConnected() { d.mux.Lock() defer d.mux.Unlock() - d.managementState = true + defer d.onConnectionChanged() + + d.managementState = true } // UpdateSignalAddress update the address of the signal server @@ -227,13 +236,17 @@ func (d *Status) UpdateManagementAddress(mgmAddress string) { func (d *Status) MarkSignalDisconnected() { d.mux.Lock() defer d.mux.Unlock() - d.signalState = false + defer d.onConnectionChanged() + + d.signalState = false } // MarkSignalConnected sets SignalState to connected func (d *Status) MarkSignalConnected() { d.mux.Lock() defer d.mux.Unlock() + defer d.onConnectionChanged() + d.signalState = true } @@ -262,3 +275,31 @@ func (d *Status) GetFullStatus() FullStatus { return fullStatus } + +// ClientStart will notify all listeners about the new service state +func (d *Status) ClientStart() { + d.notifier.clientStart() +} + +// ClientStop will notify all listeners about the new service state +func (d *Status) ClientStop() { + d.notifier.clientStop() +} + +// AddConnectionListener add a listener to the notifier +func (d *Status) AddConnectionListener(listener Listener) { + d.notifier.addListener(listener) +} + +// RemoveConnectionListener remove a listener from the notifier +func (d *Status) RemoveConnectionListener(listener Listener) { + d.notifier.removeListener(listener) +} + +func (d *Status) onConnectionChanged() { + d.notifier.updateServerStates(d.managementState, d.signalState) +} + +func (d *Status) notifyPeerListChanged() { + d.notifier.peerListChanged(len(d.peers)) +} diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index b52d97e97..3a5ead85a 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -1,190 +1,9 @@ package routemanager -import ( - "context" - "fmt" - "runtime" - "sync" - - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" - "github.com/netbirdio/netbird/route" - "github.com/netbirdio/netbird/version" -) +import "github.com/netbirdio/netbird/route" // Manager is a route manager interface type Manager interface { UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error Stop() } - -// DefaultManager is the default instance of a route manager -type DefaultManager struct { - ctx context.Context - stop context.CancelFunc - mux sync.Mutex - clientNetworks map[string]*clientNetwork - serverRoutes map[string]*route.Route - serverRouter *serverRouter - statusRecorder *peer.Status - wgInterface *iface.WGIface - pubKey string -} - -// NewManager returns a new route manager -func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status) *DefaultManager { - mCTX, cancel := context.WithCancel(ctx) - return &DefaultManager{ - ctx: mCTX, - stop: cancel, - clientNetworks: make(map[string]*clientNetwork), - serverRoutes: make(map[string]*route.Route), - serverRouter: &serverRouter{ - routes: make(map[string]*route.Route), - netForwardHistoryEnabled: isNetForwardHistoryEnabled(), - firewall: NewFirewall(ctx), - }, - statusRecorder: statusRecorder, - wgInterface: wgInterface, - pubKey: pubKey, - } -} - -// Stop stops the manager watchers and clean firewall rules -func (m *DefaultManager) Stop() { - m.stop() - m.serverRouter.firewall.CleanRoutingRules() -} - -func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) { - // removing routes that do not exist as per the update from the Management service. - for id, client := range m.clientNetworks { - _, found := networks[id] - if !found { - log.Debugf("stopping client network watcher, %s", id) - client.stop() - delete(m.clientNetworks, id) - } - } - - for id, routes := range networks { - clientNetworkWatcher, found := m.clientNetworks[id] - if !found { - clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network) - m.clientNetworks[id] = clientNetworkWatcher - go clientNetworkWatcher.peersStateAndUpdateWatcher() - } - update := routesUpdate{ - updateSerial: updateSerial, - routes: routes, - } - - clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update) - } -} - -func (m *DefaultManager) updateServerRoutes(routesMap map[string]*route.Route) error { - serverRoutesToRemove := make([]string, 0) - - if len(routesMap) > 0 { - err := m.serverRouter.firewall.RestoreOrCreateContainers() - if err != nil { - return fmt.Errorf("couldn't initialize firewall containers, got err: %v", err) - } - } - - for routeID := range m.serverRoutes { - update, found := routesMap[routeID] - if !found || !update.IsEqual(m.serverRoutes[routeID]) { - serverRoutesToRemove = append(serverRoutesToRemove, routeID) - continue - } - } - - for _, routeID := range serverRoutesToRemove { - oldRoute := m.serverRoutes[routeID] - err := m.removeFromServerNetwork(oldRoute) - if err != nil { - log.Errorf("unable to remove route id: %s, network %s, from server, got: %v", - oldRoute.ID, oldRoute.Network, err) - } - delete(m.serverRoutes, routeID) - } - - for id, newRoute := range routesMap { - _, found := m.serverRoutes[id] - if found { - continue - } - - err := m.addToServerNetwork(newRoute) - if err != nil { - log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err) - continue - } - m.serverRoutes[id] = newRoute - } - - if len(m.serverRoutes) > 0 { - err := enableIPForwarding() - if err != nil { - return err - } - } - - return nil -} - -// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps -func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error { - select { - case <-m.ctx.Done(): - log.Infof("not updating routes as context is closed") - return m.ctx.Err() - default: - m.mux.Lock() - defer m.mux.Unlock() - - newClientRoutesIDMap := make(map[string][]*route.Route) - newServerRoutesMap := make(map[string]*route.Route) - ownNetworkIDs := make(map[string]bool) - - for _, newRoute := range newRoutes { - networkID := route.GetHAUniqueID(newRoute) - if newRoute.Peer == m.pubKey { - ownNetworkIDs[networkID] = true - // only linux is supported for now - if runtime.GOOS != "linux" { - log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS) - continue - } - newServerRoutesMap[newRoute.ID] = newRoute - } - } - - for _, newRoute := range newRoutes { - networkID := route.GetHAUniqueID(newRoute) - if !ownNetworkIDs[networkID] { - // if prefix is too small, lets assume is a possible default route which is not yet supported - // we skip this route management - if newRoute.Network.Bits() < 7 { - log.Errorf("this agent version: %s, doesn't support default routes, received %s, skiping this route", - version.NetbirdVersion(), newRoute.Network) - continue - } - newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute) - } - } - - m.updateClientNetworks(updateSerial, newClientRoutesIDMap) - - err := m.updateServerRoutes(newServerRoutesMap) - if err != nil { - return err - } - - return nil - } -} diff --git a/client/internal/routemanager/manager_android.go b/client/internal/routemanager/manager_android.go new file mode 100644 index 000000000..31cba102c --- /dev/null +++ b/client/internal/routemanager/manager_android.go @@ -0,0 +1,31 @@ +package routemanager + +import ( + "context" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/route" +) + +// DefaultManager dummy router manager for Android +type DefaultManager struct { + ctx context.Context + serverRouter *serverRouter + wgInterface *iface.WGIface +} + +// NewManager returns a new dummy route manager what doing nothing +func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status) *DefaultManager { + return &DefaultManager{} +} + +// UpdateRoutes ... +func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error { + return nil +} + +// Stop ... +func (m *DefaultManager) Stop() { + +} diff --git a/client/internal/routemanager/manager_nonandroid.go b/client/internal/routemanager/manager_nonandroid.go new file mode 100644 index 000000000..361eba549 --- /dev/null +++ b/client/internal/routemanager/manager_nonandroid.go @@ -0,0 +1,186 @@ +//go:build !android + +package routemanager + +import ( + "context" + "fmt" + "runtime" + "sync" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/version" +) + +// DefaultManager is the default instance of a route manager +type DefaultManager struct { + ctx context.Context + stop context.CancelFunc + mux sync.Mutex + clientNetworks map[string]*clientNetwork + serverRoutes map[string]*route.Route + serverRouter *serverRouter + statusRecorder *peer.Status + wgInterface *iface.WGIface + pubKey string +} + +// NewManager returns a new route manager +func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status) *DefaultManager { + mCTX, cancel := context.WithCancel(ctx) + return &DefaultManager{ + ctx: mCTX, + stop: cancel, + clientNetworks: make(map[string]*clientNetwork), + serverRoutes: make(map[string]*route.Route), + serverRouter: &serverRouter{ + routes: make(map[string]*route.Route), + netForwardHistoryEnabled: isNetForwardHistoryEnabled(), + firewall: NewFirewall(ctx), + }, + statusRecorder: statusRecorder, + wgInterface: wgInterface, + pubKey: pubKey, + } +} + +// Stop stops the manager watchers and clean firewall rules +func (m *DefaultManager) Stop() { + m.stop() + m.serverRouter.firewall.CleanRoutingRules() +} + +func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) { + // removing routes that do not exist as per the update from the Management service. + for id, client := range m.clientNetworks { + _, found := networks[id] + if !found { + log.Debugf("stopping client network watcher, %s", id) + client.stop() + delete(m.clientNetworks, id) + } + } + + for id, routes := range networks { + clientNetworkWatcher, found := m.clientNetworks[id] + if !found { + clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network) + m.clientNetworks[id] = clientNetworkWatcher + go clientNetworkWatcher.peersStateAndUpdateWatcher() + } + update := routesUpdate{ + updateSerial: updateSerial, + routes: routes, + } + + clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update) + } +} + +func (m *DefaultManager) updateServerRoutes(routesMap map[string]*route.Route) error { + serverRoutesToRemove := make([]string, 0) + + if len(routesMap) > 0 { + err := m.serverRouter.firewall.RestoreOrCreateContainers() + if err != nil { + return fmt.Errorf("couldn't initialize firewall containers, got err: %v", err) + } + } + + for routeID := range m.serverRoutes { + update, found := routesMap[routeID] + if !found || !update.IsEqual(m.serverRoutes[routeID]) { + serverRoutesToRemove = append(serverRoutesToRemove, routeID) + continue + } + } + + for _, routeID := range serverRoutesToRemove { + oldRoute := m.serverRoutes[routeID] + err := m.removeFromServerNetwork(oldRoute) + if err != nil { + log.Errorf("unable to remove route id: %s, network %s, from server, got: %v", + oldRoute.ID, oldRoute.Network, err) + } + delete(m.serverRoutes, routeID) + } + + for id, newRoute := range routesMap { + _, found := m.serverRoutes[id] + if found { + continue + } + + err := m.addToServerNetwork(newRoute) + if err != nil { + log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err) + continue + } + m.serverRoutes[id] = newRoute + } + + if len(m.serverRoutes) > 0 { + err := enableIPForwarding() + if err != nil { + return err + } + } + + return nil +} + +// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps +func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error { + select { + case <-m.ctx.Done(): + log.Infof("not updating routes as context is closed") + return m.ctx.Err() + default: + m.mux.Lock() + defer m.mux.Unlock() + + newClientRoutesIDMap := make(map[string][]*route.Route) + newServerRoutesMap := make(map[string]*route.Route) + ownNetworkIDs := make(map[string]bool) + + for _, newRoute := range newRoutes { + networkID := route.GetHAUniqueID(newRoute) + if newRoute.Peer == m.pubKey { + ownNetworkIDs[networkID] = true + // only linux is supported for now + if runtime.GOOS != "linux" { + log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS) + continue + } + newServerRoutesMap[newRoute.ID] = newRoute + } + } + + for _, newRoute := range newRoutes { + networkID := route.GetHAUniqueID(newRoute) + if !ownNetworkIDs[networkID] { + // if prefix is too small, lets assume is a possible default route which is not yet supported + // we skip this route management + if newRoute.Network.Bits() < 7 { + log.Errorf("this agent version: %s, doesn't support default routes, received %s, skiping this route", + version.NetbirdVersion(), newRoute.Network) + continue + } + newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute) + } + } + + m.updateClientNetworks(updateSerial, newClientRoutesIDMap) + + err := m.updateServerRoutes(newServerRoutesMap) + if err != nil { + return err + } + + return nil + } +} diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 9d3b5f5ff..47b43b396 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -391,7 +391,7 @@ func TestManagerUpdateRoutes(t *testing.T) { for n, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU) + wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU, nil) require.NoError(t, err, "should create testing WGIface interface") defer wgInterface.Close() diff --git a/client/internal/routemanager/systemops_test.go b/client/internal/routemanager/systemops_test.go index 234a4973d..fe0c861da 100644 --- a/client/internal/routemanager/systemops_test.go +++ b/client/internal/routemanager/systemops_test.go @@ -32,7 +32,7 @@ func TestAddRemoveRoutes(t *testing.T) { for n, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU) + wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU, nil) require.NoError(t, err, "should create testing WGIface interface") defer wgInterface.Close() diff --git a/client/server/server.go b/client/server/server.go index e1adc0239..238b15acc 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -102,7 +102,7 @@ func (s *Server) Start() error { } go func() { - if err := internal.RunClient(ctx, config, s.statusRecorder); err != nil { + if err := internal.RunClient(ctx, config, s.statusRecorder, nil); err != nil { log.Errorf("init connections: %v", err) } }() @@ -394,7 +394,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes } go func() { - if err := internal.RunClient(ctx, s.config, s.statusRecorder); err != nil { + if err := internal.RunClient(ctx, s.config, s.statusRecorder, nil); err != nil { log.Errorf("run client connection: %v", err) return } diff --git a/client/system/info.go b/client/system/info.go index d948a185f..15b26d0e2 100644 --- a/client/system/info.go +++ b/client/system/info.go @@ -9,6 +9,9 @@ import ( "github.com/netbirdio/netbird/version" ) +// DeviceNameCtxKey context key for device name +const DeviceNameCtxKey = "deviceName" + // Info is an object that contains machine information // Most of the code is taken from https://github.com/matishsiao/goInfo type Info struct { diff --git a/client/system/info_android.go b/client/system/info_android.go new file mode 100644 index 000000000..65fb409f6 --- /dev/null +++ b/client/system/info_android.go @@ -0,0 +1,63 @@ +//go:build android +// +build android + +package system + +import ( + "bytes" + "context" + "os/exec" + "runtime" + "strings" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/version" +) + +// GetInfo retrieves and parses the system information +func GetInfo(ctx context.Context) *Info { + kernel := "android" + osInfo := uname() + if len(osInfo) == 2 { + kernel = osInfo[1] + } + + gio := &Info{Kernel: kernel, Core: osVersion(), Platform: "unknown", OS: "android", OSVersion: osVersion(), GoOS: runtime.GOOS, CPUs: runtime.NumCPU()} + gio.Hostname = extractDeviceName(ctx) + gio.WiretrusteeVersion = version.NetbirdVersion() + gio.UIVersion = extractUserAgent(ctx) + + return gio +} + +func extractDeviceName(ctx context.Context) string { + v, ok := ctx.Value(DeviceNameCtxKey).(string) + if !ok { + return "" + } + return v +} + +func uname() []string { + res := run("/system/bin/uname", "-a") + return strings.Split(res, " ") +} + +func osVersion() string { + return run("/system/bin/getprop", "ro.build.version.release") +} + +func run(name string, arg ...string) string { + cmd := exec.Command(name, arg...) + cmd.Stdin = strings.NewReader("some") + var out bytes.Buffer + var stderr bytes.Buffer + cmd.Stdout = &out + cmd.Stderr = &stderr + err := cmd.Run() + if err != nil { + log.Errorf("getInfo: %s", err) + } + return out.String() +} diff --git a/client/system/info_linux.go b/client/system/info_linux.go index e0546cf88..e3215cd09 100644 --- a/client/system/info_linux.go +++ b/client/system/info_linux.go @@ -1,3 +1,6 @@ +//go:build !android +// +build !android + package system import ( diff --git a/formatter/formatter.go b/formatter/formatter.go index 591a9193b..a37c67914 100644 --- a/formatter/formatter.go +++ b/formatter/formatter.go @@ -10,15 +10,15 @@ import ( // TextFormatter formats logs into text with included source code's path type TextFormatter struct { - TimestampFormat string - LevelDesc []string + timestampFormat string + levelDesc []string } // NewTextFormatter create new MyTextFormatter instance func NewTextFormatter() *TextFormatter { return &TextFormatter{ - LevelDesc: []string{"PANC", "FATL", "ERRO", "WARN", "INFO", "DEBG", "TRAC"}, - TimestampFormat: time.RFC3339, // or RFC3339 + levelDesc: []string{"PANC", "FATL", "ERRO", "WARN", "INFO", "DEBG", "TRAC"}, + timestampFormat: time.RFC3339, // or RFC3339 } } @@ -39,13 +39,13 @@ func (f *TextFormatter) Format(entry *logrus.Entry) ([]byte, error) { level := f.parseLevel(entry.Level) - return []byte(fmt.Sprintf("%s %s %s%s: %s\n", entry.Time.Format(f.TimestampFormat), level, fields, entry.Data["source"], entry.Message)), nil + return []byte(fmt.Sprintf("%s %s %s%s: %s\n", entry.Time.Format(f.timestampFormat), level, fields, entry.Data["source"], entry.Message)), nil } func (f *TextFormatter) parseLevel(level logrus.Level) string { - if len(f.LevelDesc) < int(level) { + if len(f.levelDesc) < int(level) { return "" } - return f.LevelDesc[level] + return f.levelDesc[level] } diff --git a/formatter/logcat.go b/formatter/logcat.go new file mode 100644 index 000000000..e8f606229 --- /dev/null +++ b/formatter/logcat.go @@ -0,0 +1,48 @@ +package formatter + +import ( + "fmt" + "strings" + + "github.com/sirupsen/logrus" +) + +// LogcatFormatter formats logs into text what is fit for logcat +type LogcatFormatter struct { + levelDesc []string +} + +// NewLogcatFormatter create new LogcatFormatter instance +func NewLogcatFormatter() *LogcatFormatter { + return &LogcatFormatter{ + levelDesc: []string{"PANC", "FATL", "ERRO", "WARN", "INFO", "DEBG", "TRAC"}, + } +} + +// Format renders a single log entry +func (f *LogcatFormatter) Format(entry *logrus.Entry) ([]byte, error) { + var fields string + keys := make([]string, 0, len(entry.Data)) + for k, v := range entry.Data { + if k == "source" { + continue + } + keys = append(keys, fmt.Sprintf("%s: %v", k, v)) + } + + if len(keys) > 0 { + fields = fmt.Sprintf("[%s] ", strings.Join(keys, ", ")) + } + + level := f.parseLevel(entry.Level) + + return []byte(fmt.Sprintf("[%s] %s%s %s\n", level, fields, entry.Data["source"], entry.Message)), nil +} + +func (f *LogcatFormatter) parseLevel(level logrus.Level) string { + if len(f.levelDesc) < int(level) { + return "" + } + + return f.levelDesc[level] +} diff --git a/formatter/logcat_test.go b/formatter/logcat_test.go new file mode 100644 index 000000000..45ba5bc46 --- /dev/null +++ b/formatter/logcat_test.go @@ -0,0 +1,28 @@ +package formatter + +import ( + "testing" + "time" + + "github.com/sirupsen/logrus" +) + +func TestLogcatMessageFormat(t *testing.T) { + + someEntry := &logrus.Entry{ + Data: logrus.Fields{"att1": 1, "att2": 2, "source": "some/fancy/path.go:46"}, + Time: time.Date(2021, time.Month(2), 21, 1, 10, 30, 0, time.UTC), + Level: 3, + Message: "Some Message", + } + + formatter := NewLogcatFormatter() + result, _ := formatter.Format(someEntry) + + expectedString := "[WARN] [att1: 1, att2: 2] some/fancy/path.go:46 Some Message\n" + expectedStringVariant := "[WARN] [att2: 2, att1: 1] some/fancy/path.go:46 Some Message\n" + parsedString := string(result) + if parsedString != expectedString && parsedString != expectedStringVariant { + t.Errorf("The log messages don't match. Expected: '%s', got: '%s'", expectedString, parsedString) + } +} diff --git a/formatter/set.go b/formatter/set.go index 2f8e0331f..cceeef860 100644 --- a/formatter/set.go +++ b/formatter/set.go @@ -2,9 +2,16 @@ package formatter import "github.com/sirupsen/logrus" -// SetTextFormatter set the formatter for given logger. +// SetTextFormatter set the text formatter for given logger. func SetTextFormatter(logger *logrus.Logger) { logger.Formatter = NewTextFormatter() logger.ReportCaller = true logger.AddHook(NewContextHook()) } + +// SetLogcatFormatter set the logcat formatter for given logger. +func SetLogcatFormatter(logger *logrus.Logger) { + logger.Formatter = NewLogcatFormatter() + logger.ReportCaller = true + logger.AddHook(NewContextHook()) +} diff --git a/go.mod b/go.mod index e65b4e302..d6518467f 100644 --- a/go.mod +++ b/go.mod @@ -47,6 +47,7 @@ require ( github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/open-policy-agent/opa v0.49.0 github.com/patrickmn/go-cache v2.1.0+incompatible + github.com/pion/transport/v2 v2.0.2 github.com/prometheus/client_golang v1.14.0 github.com/rs/xid v1.3.0 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 @@ -105,7 +106,6 @@ require ( github.com/pion/mdns v0.0.7 // indirect github.com/pion/randutil v0.1.0 // indirect github.com/pion/stun v0.4.0 // indirect - github.com/pion/transport/v2 v2.0.2 // indirect github.com/pion/turn/v2 v2.1.0 // indirect github.com/pion/udp/v2 v2.0.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/iface/iface.go b/iface/iface.go index 4e88f57e7..131558e77 100644 --- a/iface/iface.go +++ b/iface/iface.go @@ -1,13 +1,11 @@ package iface import ( - "fmt" "net" "sync" "time" log "github.com/sirupsen/logrus" - "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) @@ -16,46 +14,30 @@ const ( DefaultWgPort = 51820 ) -// NetInterface represents a generic network tunnel interface -type NetInterface interface { - Close() error -} - // WGIface represents a interface instance type WGIface struct { - name string - address WGAddress - mtu int - netInterface NetInterface - mu sync.Mutex + tun *tunDevice + configurer wGConfigurer + mu sync.Mutex } -// NewWGIFace Creates a new Wireguard interface instance -func NewWGIFace(iface string, address string, mtu int) (*WGIface, error) { - wgIface := &WGIface{ - name: iface, - mtu: mtu, - mu: sync.Mutex{}, - } - - wgAddress, err := parseWGAddress(address) - if err != nil { - return wgIface, err - } - - wgIface.address = wgAddress - - return wgIface, nil +// Create creates a new Wireguard interface, sets a given IP and brings it up. +// Will reuse an existing one. +func (w *WGIface) Create() error { + w.mu.Lock() + defer w.mu.Unlock() + log.Debugf("create Wireguard interface %s", w.tun.DeviceName()) + return w.tun.Create() } // Name returns the interface name func (w *WGIface) Name() string { - return w.name + return w.tun.DeviceName() } // Address returns the interface address func (w *WGIface) Address() WGAddress { - return w.address + return w.tun.WgAddress() } // Configure configures a Wireguard interface @@ -63,27 +45,8 @@ func (w *WGIface) Address() WGAddress { func (w *WGIface) Configure(privateKey string, port int) error { w.mu.Lock() defer w.mu.Unlock() - - log.Debugf("configuring Wireguard interface %s", w.name) - - log.Debugf("adding Wireguard private key") - key, err := wgtypes.ParseKey(privateKey) - if err != nil { - return err - } - fwmark := 0 - config := wgtypes.Config{ - PrivateKey: &key, - ReplacePeers: true, - FirewallMark: &fwmark, - ListenPort: &port, - } - - err = w.configureDevice(config) - if err != nil { - return fmt.Errorf(`received error "%w" while configuring interface %s with port %d`, err, w.name, port) - } - return nil + log.Debugf("configuring Wireguard interface %s", w.tun.DeviceName()) + return w.configurer.configureInterface(privateKey, port) } // UpdateAddr updates address of the interface @@ -96,8 +59,7 @@ func (w *WGIface) UpdateAddr(newAddr string) error { return err } - w.address = addr - return w.assignAddr() + return w.tun.UpdateAddr(addr) } // UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist @@ -106,119 +68,8 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.D w.mu.Lock() defer w.mu.Unlock() - log.Debugf("updating interface %s peer %s: endpoint %s ", w.name, peerKey, endpoint) - - //parse allowed ips - _, ipNet, err := net.ParseCIDR(allowedIps) - if err != nil { - return err - } - - peerKeyParsed, err := wgtypes.ParseKey(peerKey) - if err != nil { - return err - } - peer := wgtypes.PeerConfig{ - PublicKey: peerKeyParsed, - ReplaceAllowedIPs: true, - AllowedIPs: []net.IPNet{*ipNet}, - PersistentKeepaliveInterval: &keepAlive, - PresharedKey: preSharedKey, - Endpoint: endpoint, - } - - config := wgtypes.Config{ - Peers: []wgtypes.PeerConfig{peer}, - } - err = w.configureDevice(config) - if err != nil { - return fmt.Errorf(`received error "%w" while updating peer on interface %s with settings: allowed ips %s, endpoint %s`, err, w.name, allowedIps, endpoint.String()) - } - return nil -} - -// AddAllowedIP adds a prefix to the allowed IPs list of peer -func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error { - w.mu.Lock() - defer w.mu.Unlock() - - log.Debugf("adding allowed IP to interface %s and peer %s: allowed IP %s ", w.name, peerKey, allowedIP) - - _, ipNet, err := net.ParseCIDR(allowedIP) - if err != nil { - return err - } - - peerKeyParsed, err := wgtypes.ParseKey(peerKey) - if err != nil { - return err - } - peer := wgtypes.PeerConfig{ - PublicKey: peerKeyParsed, - UpdateOnly: true, - ReplaceAllowedIPs: false, - AllowedIPs: []net.IPNet{*ipNet}, - } - - config := wgtypes.Config{ - Peers: []wgtypes.PeerConfig{peer}, - } - err = w.configureDevice(config) - if err != nil { - return fmt.Errorf(`received error "%w" while adding allowed Ip to peer on interface %s with settings: allowed ips %s`, err, w.name, allowedIP) - } - return nil -} - -// RemoveAllowedIP removes a prefix from the allowed IPs list of peer -func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error { - w.mu.Lock() - defer w.mu.Unlock() - - log.Debugf("removing allowed IP from interface %s and peer %s: allowed IP %s ", w.name, peerKey, allowedIP) - - _, ipNet, err := net.ParseCIDR(allowedIP) - if err != nil { - return err - } - - peerKeyParsed, err := wgtypes.ParseKey(peerKey) - if err != nil { - return err - } - - existingPeer, err := getPeer(w.name, peerKey) - if err != nil { - return err - } - - newAllowedIPs := existingPeer.AllowedIPs - - for i, existingAllowedIP := range existingPeer.AllowedIPs { - if existingAllowedIP.String() == ipNet.String() { - newAllowedIPs = append(existingPeer.AllowedIPs[:i], existingPeer.AllowedIPs[i+1:]...) - break - } - } - - if err != nil { - return err - } - peer := wgtypes.PeerConfig{ - PublicKey: peerKeyParsed, - UpdateOnly: true, - ReplaceAllowedIPs: true, - AllowedIPs: newAllowedIPs, - } - - config := wgtypes.Config{ - Peers: []wgtypes.PeerConfig{peer}, - } - err = w.configureDevice(config) - if err != nil { - return fmt.Errorf(`received error "%w" while removing allowed IP from peer on interface %s with settings: allowed ips %s`, err, w.name, allowedIP) - } - return nil + log.Debugf("updating interface %s peer %s: endpoint %s ", w.tun.DeviceName(), peerKey, endpoint) + return w.configurer.updatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) } // RemovePeer removes a Wireguard Peer from the interface iface @@ -226,66 +77,31 @@ func (w *WGIface) RemovePeer(peerKey string) error { w.mu.Lock() defer w.mu.Unlock() - log.Debugf("Removing peer %s from interface %s ", peerKey, w.name) - - peerKeyParsed, err := wgtypes.ParseKey(peerKey) - if err != nil { - return err - } - - peer := wgtypes.PeerConfig{ - PublicKey: peerKeyParsed, - Remove: true, - } - - config := wgtypes.Config{ - Peers: []wgtypes.PeerConfig{peer}, - } - err = w.configureDevice(config) - if err != nil { - return fmt.Errorf(`received error "%w" while removing peer %s from interface %s`, err, peerKey, w.name) - } - return nil + log.Debugf("Removing peer %s from interface %s ", peerKey, w.tun.DeviceName()) + return w.configurer.removePeer(peerKey) } -func getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) { - wg, err := wgctrl.New() - if err != nil { - return wgtypes.Peer{}, err - } - defer func() { - err = wg.Close() - if err != nil { - log.Errorf("got error while closing wgctl: %v", err) - } - }() +// AddAllowedIP adds a prefix to the allowed IPs list of peer +func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error { + w.mu.Lock() + defer w.mu.Unlock() - wgDevice, err := wg.Device(ifaceName) - if err != nil { - return wgtypes.Peer{}, err - } - for _, peer := range wgDevice.Peers { - if peer.PublicKey.String() == peerPubKey { - return peer, nil - } - } - return wgtypes.Peer{}, fmt.Errorf("peer not found") + log.Debugf("adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP) + return w.configurer.addAllowedIP(peerKey, allowedIP) } -// configureDevice configures the wireguard device -func (w *WGIface) configureDevice(config wgtypes.Config) error { - wg, err := wgctrl.New() - if err != nil { - return err - } - defer wg.Close() +// RemoveAllowedIP removes a prefix from the allowed IPs list of peer +func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error { + w.mu.Lock() + defer w.mu.Unlock() - // validate if device with name exists - _, err = wg.Device(w.name) - if err != nil { - return err - } - log.Debugf("got Wireguard device %s", w.name) - - return wg.ConfigureDevice(w.name, config) + log.Debugf("removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP) + return w.configurer.removeAllowedIP(peerKey, allowedIP) +} + +// Close closes the tunnel interface +func (w *WGIface) Close() error { + w.mu.Lock() + defer w.mu.Unlock() + return w.tun.Close() } diff --git a/iface/iface_android.go b/iface/iface_android.go new file mode 100644 index 000000000..f8d3be5c0 --- /dev/null +++ b/iface/iface_android.go @@ -0,0 +1,22 @@ +package iface + +import "sync" + +// NewWGIFace Creates a new Wireguard interface instance +func NewWGIFace(ifaceName string, address string, mtu int, tunAdapter TunAdapter) (*WGIface, error) { + wgIface := &WGIface{ + mu: sync.Mutex{}, + } + + wgAddress, err := parseWGAddress(address) + if err != nil { + return wgIface, err + } + + tun := newTunDevice(wgAddress, mtu, tunAdapter) + wgIface.tun = tun + + wgIface.configurer = newWGConfigurer(tun) + + return wgIface, nil +} diff --git a/iface/iface_nonandroid.go b/iface/iface_nonandroid.go new file mode 100644 index 000000000..13bbe184b --- /dev/null +++ b/iface/iface_nonandroid.go @@ -0,0 +1,22 @@ +//go:build !android + +package iface + +import "sync" + +// NewWGIFace Creates a new Wireguard interface instance +func NewWGIFace(ifaceName string, address string, mtu int, tunAdapter TunAdapter) (*WGIface, error) { + wgIface := &WGIface{ + mu: sync.Mutex{}, + } + + wgAddress, err := parseWGAddress(address) + if err != nil { + return wgIface, err + } + + wgIface.tun = newTunDevice(ifaceName, wgAddress, mtu) + + wgIface.configurer = newWGConfigurer(ifaceName) + return wgIface, nil +} diff --git a/iface/iface_test.go b/iface/iface_test.go index c34e2959c..ec1ea114a 100644 --- a/iface/iface_test.go +++ b/iface/iface_test.go @@ -32,7 +32,7 @@ func init() { func TestWGIface_UpdateAddr(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4) addr := "100.64.0.1/8" - iface, err := NewWGIFace(ifaceName, addr, DefaultMTU) + iface, err := NewWGIFace(ifaceName, addr, DefaultMTU, nil) if err != nil { t.Fatal(err) } @@ -92,7 +92,7 @@ func getIfaceAddrs(ifaceName string) ([]net.Addr, error) { func Test_CreateInterface(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1) wgIP := "10.99.99.1/32" - iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU) + iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil) if err != nil { t.Fatal(err) } @@ -121,7 +121,7 @@ func Test_CreateInterface(t *testing.T) { func Test_Close(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2) wgIP := "10.99.99.2/32" - iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU) + iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil) if err != nil { t.Fatal(err) } @@ -149,7 +149,7 @@ func Test_Close(t *testing.T) { func Test_ConfigureInterface(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3) wgIP := "10.99.99.5/30" - iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU) + iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil) if err != nil { t.Fatal(err) } @@ -196,7 +196,7 @@ func Test_ConfigureInterface(t *testing.T) { func Test_UpdatePeer(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4) wgIP := "10.99.99.9/30" - iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU) + iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil) if err != nil { t.Fatal(err) } @@ -228,7 +228,7 @@ func Test_UpdatePeer(t *testing.T) { if err != nil { t.Fatal(err) } - peer, err := getPeer(ifaceName, peerPubKey) + peer, err := iface.configurer.getPeer(ifaceName, peerPubKey) if err != nil { t.Fatal(err) } @@ -255,7 +255,7 @@ func Test_UpdatePeer(t *testing.T) { func Test_RemovePeer(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4) wgIP := "10.99.99.13/30" - iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU) + iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil) if err != nil { t.Fatal(err) } @@ -288,7 +288,7 @@ func Test_RemovePeer(t *testing.T) { if err != nil { t.Fatal(err) } - _, err = getPeer(ifaceName, peerPubKey) + _, err = iface.configurer.getPeer(ifaceName, peerPubKey) if err.Error() != "peer not found" { t.Fatal(err) } @@ -305,7 +305,7 @@ func Test_ConnectPeers(t *testing.T) { keepAlive := 1 * time.Second - iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, DefaultMTU) + iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, DefaultMTU, nil) if err != nil { t.Fatal(err) } @@ -322,7 +322,7 @@ func Test_ConnectPeers(t *testing.T) { t.Fatal(err) } - iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, DefaultMTU) + iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, DefaultMTU, nil) if err != nil { t.Fatal(err) } @@ -375,7 +375,7 @@ func Test_ConnectPeers(t *testing.T) { t.Fatalf("waiting for peer handshake timeout after %s", timeout.String()) default: } - peer, gpErr := getPeer(peer1ifaceName, peer2Key.PublicKey().String()) + peer, gpErr := iface1.configurer.getPeer(peer1ifaceName, peer2Key.PublicKey().String()) if gpErr != nil { t.Fatal(gpErr) } diff --git a/iface/iface_windows.go b/iface/iface_windows.go index 2fefe3402..a67df296c 100644 --- a/iface/iface_windows.go +++ b/iface/iface_windows.go @@ -1,69 +1,6 @@ package iface -import ( - "fmt" - "net" - - log "github.com/sirupsen/logrus" - "golang.org/x/sys/windows" - "golang.zx2c4.com/wireguard/windows/driver" -) - -// Create Creates a new Wireguard interface, sets a given IP and brings it up. -func (w *WGIface) Create() error { - w.mu.Lock() - defer w.mu.Unlock() - - WintunStaticRequestedGUID, _ := windows.GenerateGUID() - adapter, err := driver.CreateAdapter(w.name, "WireGuard", &WintunStaticRequestedGUID) - if err != nil { - err = fmt.Errorf("error creating adapter: %w", err) - return err - } - w.netInterface = adapter - err = adapter.SetAdapterState(driver.AdapterStateUp) - if err != nil { - return err - } - state, _ := adapter.LUID().GUID() - log.Debugln("device guid: ", state.String()) - return w.assignAddr() -} - -// GetInterfaceGUIDString returns an interface GUID string +// GetInterfaceGUIDString returns an interface GUID. This is useful on Windows only func (w *WGIface) GetInterfaceGUIDString() (string, error) { - if w.netInterface == nil { - return "", fmt.Errorf("interface has not been initialized yet") - } - windowsDevice := w.netInterface.(*driver.Adapter) - luid := windowsDevice.LUID() - guid, err := luid.GUID() - if err != nil { - return "", err - } - return guid.String(), nil -} - -// Close closes the tunnel interface -func (w *WGIface) Close() error { - w.mu.Lock() - defer w.mu.Unlock() - if w.netInterface == nil { - return nil - } - - return w.netInterface.Close() -} - -// assignAddr Adds IP address to the tunnel interface and network route based on the range provided -func (w *WGIface) assignAddr() error { - luid := w.netInterface.(*driver.Adapter).LUID() - - log.Debugf("adding address %s to interface: %s", w.address.IP, w.name) - err := luid.SetIPAddresses([]net.IPNet{{w.address.IP, w.address.Network.Mask}}) - if err != nil { - return err - } - - return nil + return w.tun.getInterfaceGUIDString() } diff --git a/iface/ipc_parser_android.go b/iface/ipc_parser_android.go new file mode 100644 index 000000000..ef757a638 --- /dev/null +++ b/iface/ipc_parser_android.go @@ -0,0 +1,60 @@ +package iface + +import ( + "encoding/hex" + "fmt" + "strings" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +func toWgUserspaceString(wgCfg wgtypes.Config) string { + var sb strings.Builder + if wgCfg.PrivateKey != nil { + hexKey := hex.EncodeToString(wgCfg.PrivateKey[:]) + sb.WriteString(fmt.Sprintf("private_key=%s\n", hexKey)) + } + + if wgCfg.ListenPort != nil { + sb.WriteString(fmt.Sprintf("listen_port=%d\n", *wgCfg.ListenPort)) + } + + if wgCfg.ReplacePeers { + sb.WriteString("replace_peers=true\n") + } + + if wgCfg.FirewallMark != nil { + sb.WriteString(fmt.Sprintf("fwmark=%d\n", *wgCfg.FirewallMark)) + } + + for _, p := range wgCfg.Peers { + hexKey := hex.EncodeToString(p.PublicKey[:]) + sb.WriteString(fmt.Sprintf("public_key=%s\n", hexKey)) + + if p.PresharedKey != nil { + preSharedHexKey := hex.EncodeToString(p.PresharedKey[:]) + sb.WriteString(fmt.Sprintf("public_key=%s\n", preSharedHexKey)) + } + + if p.Remove { + sb.WriteString("remove=true") + } + + if p.ReplaceAllowedIPs { + sb.WriteString("replace_allowed_ips=true\n") + } + + for _, aip := range p.AllowedIPs { + sb.WriteString(fmt.Sprintf("allowed_ip=%s\n", aip.String())) + } + + if p.Endpoint != nil { + sb.WriteString(fmt.Sprintf("endpoint=%s\n", p.Endpoint.String())) + } + + if p.PersistentKeepaliveInterval != nil { + sb.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", int(p.PersistentKeepaliveInterval.Seconds()))) + } + } + return sb.String() +} diff --git a/iface/module.go b/iface/module.go index 08aa4c5e3..31635fa65 100644 --- a/iface/module.go +++ b/iface/module.go @@ -1,5 +1,5 @@ -//go:build !linux -// +build !linux +//go:build !linux || android +// +build !linux android package iface diff --git a/iface/module_linux.go b/iface/module_linux.go index 2001eeb95..af47b6ef5 100644 --- a/iface/module_linux.go +++ b/iface/module_linux.go @@ -1,3 +1,5 @@ +//go:build linux && !android + // Package iface provides wireguard network interface creation and management package iface diff --git a/iface/ifacename.go b/iface/name.go similarity index 100% rename from iface/ifacename.go rename to iface/name.go diff --git a/iface/ifacename_darwin.go b/iface/name_darwin.go similarity index 100% rename from iface/ifacename_darwin.go rename to iface/name_darwin.go diff --git a/iface/tun.go b/iface/tun.go new file mode 100644 index 000000000..f81222cdb --- /dev/null +++ b/iface/tun.go @@ -0,0 +1,6 @@ +package iface + +// NetInterface represents a generic network tunnel interface +type NetInterface interface { + Close() error +} diff --git a/iface/tun_adapter.go b/iface/tun_adapter.go new file mode 100644 index 000000000..d37302387 --- /dev/null +++ b/iface/tun_adapter.go @@ -0,0 +1,7 @@ +package iface + +// TunAdapter is an interface for create tun device from externel service +type TunAdapter interface { + ConfigureInterface(address string, mtu int) (int, error) + UpdateAddr(address string) error +} diff --git a/iface/tun_android.go b/iface/tun_android.go new file mode 100644 index 000000000..da258e8ec --- /dev/null +++ b/iface/tun_android.go @@ -0,0 +1,112 @@ +package iface + +import ( + "net" + + log "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/ipc" + "golang.zx2c4.com/wireguard/tun" +) + +type tunDevice struct { + address WGAddress + mtu int + tunAdapter TunAdapter + + fd int + name string + device *device.Device + uapi net.Listener +} + +func newTunDevice(address WGAddress, mtu int, tunAdapter TunAdapter) *tunDevice { + return &tunDevice{ + address: address, + mtu: mtu, + tunAdapter: tunAdapter, + } +} + +func (t *tunDevice) Create() error { + var err error + t.fd, err = t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu) + if err != nil { + log.Errorf("failed to create Android interface: %s", err) + return err + } + + tunDevice, name, err := tun.CreateUnmonitoredTUNFromFD(t.fd) + if err != nil { + unix.Close(t.fd) + return err + } + t.name = name + + log.Debugf("attaching to interface %v", name) + t.device = device.NewDevice(tunDevice, conn.NewStdNetBind(), device.NewLogger(device.LogLevelSilent, "[wiretrustee] ")) + t.device.DisableSomeRoamingForBrokenMobileSemantics() + + log.Debugf("create uapi") + tunSock, err := ipc.UAPIOpen(name) + if err != nil { + return err + } + + t.uapi, err = ipc.UAPIListen(name, tunSock) + if err != nil { + tunSock.Close() + unix.Close(t.fd) + return err + } + + go func() { + for { + uapiConn, err := t.uapi.Accept() + if err != nil { + return + } + go t.device.IpcHandle(uapiConn) + } + }() + + err = t.device.Up() + if err != nil { + tunSock.Close() + t.device.Close() + return err + } + log.Debugf("device is ready to use: %s", name) + return nil +} + +func (t *tunDevice) Device() *device.Device { + return t.device +} + +func (t *tunDevice) DeviceName() string { + return t.name +} + +func (t *tunDevice) WgAddress() WGAddress { + return t.address +} + +func (t *tunDevice) UpdateAddr(addr WGAddress) error { + // todo implement + return nil +} + +func (t *tunDevice) Close() (err error) { + if t.uapi != nil { + err = t.uapi.Close() + } + + if t.device != nil { + t.device.Close() + } + + return +} diff --git a/iface/iface_darwin.go b/iface/tun_darwin.go similarity index 56% rename from iface/iface_darwin.go rename to iface/tun_darwin.go index fd1b6334a..4cf3712bd 100644 --- a/iface/iface_darwin.go +++ b/iface/tun_darwin.go @@ -6,23 +6,25 @@ import ( log "github.com/sirupsen/logrus" ) -// Create Creates a new Wireguard interface, sets a given IP and brings it up. -func (w *WGIface) Create() error { - w.mu.Lock() - defer w.mu.Unlock() +func (c *tunDevice) Create() error { + var err error + c.netInterface, err = c.createWithUserspace() + if err != nil { + return err + } - return w.createWithUserspace() + return c.assignAddr() } // assignAddr Adds IP address to the tunnel interface and network route based on the range provided -func (w *WGIface) assignAddr() error { - cmd := exec.Command("ifconfig", w.name, "inet", w.address.IP.String(), w.address.IP.String()) +func (c *tunDevice) assignAddr() error { + cmd := exec.Command("ifconfig", c.name, "inet", c.address.IP.String(), c.address.IP.String()) if out, err := cmd.CombinedOutput(); err != nil { log.Infof(`adding addreess command "%v" failed with output %s and error: `, cmd.String(), out) return err } - routeCmd := exec.Command("route", "add", "-net", w.address.Network.String(), "-interface", w.name) + routeCmd := exec.Command("route", "add", "-net", c.address.Network.String(), "-interface", c.name) if out, err := routeCmd.CombinedOutput(); err != nil { log.Printf(`adding route command "%v" failed with output %s and error: `, routeCmd.String(), out) return err diff --git a/iface/iface_linux.go b/iface/tun_linux.go similarity index 60% rename from iface/iface_linux.go rename to iface/tun_linux.go index 042ed5bd6..9bc7da754 100644 --- a/iface/iface_linux.go +++ b/iface/tun_linux.go @@ -1,3 +1,5 @@ +//go:build linux && !android + package iface import ( @@ -8,32 +10,34 @@ import ( "github.com/vishvananda/netlink" ) -// Create creates a new Wireguard interface, sets a given IP and brings it up. -// Will reuse an existing one. -func (w *WGIface) Create() error { - w.mu.Lock() - defer w.mu.Unlock() - +func (c *tunDevice) Create() error { if WireguardModuleIsLoaded() { log.Info("using kernel WireGuard") - return w.createWithKernel() - } else { - if !tunModuleIsLoaded() { - return fmt.Errorf("couldn't check or load tun module") - } - log.Info("using userspace WireGuard") - return w.createWithUserspace() + return c.createWithKernel() } + + if !tunModuleIsLoaded() { + return fmt.Errorf("couldn't check or load tun module") + } + log.Info("using userspace WireGuard") + var err error + c.netInterface, err = c.createWithUserspace() + if err != nil { + return err + } + + return c.assignAddr() + } // createWithKernel Creates a new Wireguard interface using kernel Wireguard module. // Works for Linux and offers much better network performance -func (w *WGIface) createWithKernel() error { +func (c *tunDevice) createWithKernel() error { - link := newWGLink(w.name) + link := newWGLink(c.name) // check if interface exists - l, err := netlink.LinkByName(w.name) + l, err := netlink.LinkByName(c.name) if err != nil { switch err.(type) { case netlink.LinkNotFoundError: @@ -51,33 +55,33 @@ func (w *WGIface) createWithKernel() error { } } - log.Debugf("adding device: %s", w.name) + log.Debugf("adding device: %s", c.name) err = netlink.LinkAdd(link) if os.IsExist(err) { - log.Infof("interface %s already exists. Will reuse.", w.name) + log.Infof("interface %s already exists. Will reuse.", c.name) } else if err != nil { return err } - w.netInterface = link + c.netInterface = link - err = w.assignAddr() + err = c.assignAddr() if err != nil { return err } // todo do a discovery - log.Debugf("setting MTU: %d interface: %s", w.mtu, w.name) - err = netlink.LinkSetMTU(link, w.mtu) + log.Debugf("setting MTU: %d interface: %s", c.mtu, c.name) + err = netlink.LinkSetMTU(link, c.mtu) if err != nil { - log.Errorf("error setting MTU on interface: %s", w.name) + log.Errorf("error setting MTU on interface: %s", c.name) return err } - log.Debugf("bringing up interface: %s", w.name) + log.Debugf("bringing up interface: %s", c.name) err = netlink.LinkSetUp(link) if err != nil { - log.Errorf("error bringing up interface: %s", w.name) + log.Errorf("error bringing up interface: %s", c.name) return err } @@ -85,8 +89,8 @@ func (w *WGIface) createWithKernel() error { } // assignAddr Adds IP address to the tunnel interface -func (w *WGIface) assignAddr() error { - link := newWGLink(w.name) +func (c *tunDevice) assignAddr() error { + link := newWGLink(c.name) //delete existing addresses list, err := netlink.AddrList(link, 0) @@ -102,11 +106,11 @@ func (w *WGIface) assignAddr() error { } } - log.Debugf("adding address %s to interface: %s", w.address.String(), w.name) - addr, _ := netlink.ParseAddr(w.address.String()) + log.Debugf("adding address %s to interface: %s", c.address.String(), c.name) + addr, _ := netlink.ParseAddr(c.address.String()) err = netlink.AddrAdd(link, addr) if os.IsExist(err) { - log.Infof("interface %s already has the address: %s", w.name, w.address.String()) + log.Infof("interface %s already has the address: %s", c.name, c.address.String()) } else if err != nil { return err } diff --git a/iface/iface_unix.go b/iface/tun_unix.go similarity index 53% rename from iface/iface_unix.go rename to iface/tun_unix.go index be09afa9e..2f38b5523 100644 --- a/iface/iface_unix.go +++ b/iface/tun_unix.go @@ -1,5 +1,4 @@ -//go:build linux || darwin -// +build linux darwin +//go:build (linux || darwin) && !android package iface @@ -14,24 +13,44 @@ import ( "golang.zx2c4.com/wireguard/tun" ) -// GetInterfaceGUIDString returns an interface GUID. This is useful on Windows only -func (w *WGIface) GetInterfaceGUIDString() (string, error) { - return "", nil +type tunDevice struct { + name string + address WGAddress + mtu int + netInterface NetInterface } -// Close closes the tunnel interface -func (w *WGIface) Close() error { - w.mu.Lock() - defer w.mu.Unlock() - if w.netInterface == nil { +func newTunDevice(name string, address WGAddress, mtu int) *tunDevice { + return &tunDevice{ + name: name, + address: address, + mtu: mtu, + } +} + +func (c *tunDevice) UpdateAddr(address WGAddress) error { + c.address = address + return c.assignAddr() +} + +func (c *tunDevice) WgAddress() WGAddress { + return c.address +} + +func (c *tunDevice) DeviceName() string { + return c.name +} + +func (c *tunDevice) Close() error { + if c.netInterface == nil { return nil } - err := w.netInterface.Close() + err := c.netInterface.Close() if err != nil { return err } - sockPath := "/var/run/wireguard/" + w.name + ".sock" + sockPath := "/var/run/wireguard/" + c.name + ".sock" if _, statErr := os.Stat(sockPath); statErr == nil { statErr = os.Remove(sockPath) if statErr != nil { @@ -43,24 +62,23 @@ func (w *WGIface) Close() error { } // createWithUserspace Creates a new Wireguard interface, using wireguard-go userspace implementation -func (w *WGIface) createWithUserspace() error { - - tunIface, err := tun.CreateTUN(w.name, w.mtu) +func (c *tunDevice) createWithUserspace() (NetInterface, error) { + tunIface, err := tun.CreateTUN(c.name, c.mtu) if err != nil { - return err + return nil, err } - w.netInterface = tunIface - // We need to create a wireguard-go device and listen to configuration requests tunDevice := device.NewDevice(tunIface, conn.NewDefaultBind(), device.NewLogger(device.LogLevelSilent, "[wiretrustee] ")) err = tunDevice.Up() if err != nil { - return err + return tunIface, err } - uapi, err := getUAPI(w.name) + + // todo: after this line in case of error close the tunSock + uapi, err := c.getUAPI(c.name) if err != nil { - return err + return tunIface, err } go func() { @@ -75,16 +93,11 @@ func (w *WGIface) createWithUserspace() error { }() log.Debugln("UAPI listener started") - - err = w.assignAddr() - if err != nil { - return err - } - return nil + return tunIface, nil } // getUAPI returns a Listener -func getUAPI(iface string) (net.Listener, error) { +func (c *tunDevice) getUAPI(iface string) (net.Listener, error) { tunSock, err := ipc.UAPIOpen(iface) if err != nil { return nil, err diff --git a/iface/tun_windows.go b/iface/tun_windows.go new file mode 100644 index 000000000..d25b6fc9c --- /dev/null +++ b/iface/tun_windows.go @@ -0,0 +1,93 @@ +package iface + +import ( + "fmt" + "net" + + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" + "golang.zx2c4.com/wireguard/windows/driver" +) + +type tunDevice struct { + name string + address WGAddress + netInterface NetInterface +} + +func newTunDevice(name string, address WGAddress, mtu int) *tunDevice { + return &tunDevice{name: name, address: address} +} + +func (c *tunDevice) Create() error { + var err error + c.netInterface, err = c.createAdapter() + if err != nil { + return err + } + + return c.assignAddr() +} + +func (c *tunDevice) UpdateAddr(address WGAddress) error { + c.address = address + return c.assignAddr() +} + +func (c *tunDevice) WgAddress() WGAddress { + return c.address +} + +func (c *tunDevice) DeviceName() string { + return c.name +} + +func (c *tunDevice) Close() error { + if c.netInterface == nil { + return nil + } + + return c.netInterface.Close() +} + +func (c *tunDevice) getInterfaceGUIDString() (string, error) { + if c.netInterface == nil { + return "", fmt.Errorf("interface has not been initialized yet") + } + windowsDevice := c.netInterface.(*driver.Adapter) + luid := windowsDevice.LUID() + guid, err := luid.GUID() + if err != nil { + return "", err + } + return guid.String(), nil +} + +func (c *tunDevice) createAdapter() (NetInterface, error) { + WintunStaticRequestedGUID, _ := windows.GenerateGUID() + adapter, err := driver.CreateAdapter(c.name, "WireGuard", &WintunStaticRequestedGUID) + if err != nil { + err = fmt.Errorf("error creating adapter: %w", err) + return nil, err + } + err = adapter.SetAdapterState(driver.AdapterStateUp) + if err != nil { + return adapter, err + } + state, _ := adapter.LUID().GUID() + log.Debugln("device guid: ", state.String()) + return adapter, nil +} + +// assignAddr Adds IP address to the tunnel interface and network route based on the range provided +func (c *tunDevice) assignAddr() error { + luid := c.netInterface.(*driver.Adapter).LUID() + + log.Debugf("adding address %s to interface: %s", c.address.IP, c.name) + err := luid.SetIPAddresses([]net.IPNet{{c.address.IP, c.address.Network.Mask}}) + if err != nil { + return err + } + + return nil +} diff --git a/iface/wg_configurer_android.go b/iface/wg_configurer_android.go new file mode 100644 index 000000000..9328467a6 --- /dev/null +++ b/iface/wg_configurer_android.go @@ -0,0 +1,114 @@ +package iface + +import ( + "errors" + "net" + "time" + + log "github.com/sirupsen/logrus" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +var ( + errFuncNotImplemented = errors.New("function not implemented") +) + +type wGConfigurer struct { + tunDevice *tunDevice +} + +func newWGConfigurer(tunDevice *tunDevice) wGConfigurer { + return wGConfigurer{ + tunDevice: tunDevice, + } +} + +func (c *wGConfigurer) configureInterface(privateKey string, port int) error { + log.Debugf("adding Wireguard private key") + key, err := wgtypes.ParseKey(privateKey) + if err != nil { + return err + } + fwmark := 0 + config := wgtypes.Config{ + PrivateKey: &key, + ReplacePeers: true, + FirewallMark: &fwmark, + ListenPort: &port, + } + + return c.tunDevice.Device().IpcSet(toWgUserspaceString(config)) +} + +func (c *wGConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { + //parse allowed ips + _, ipNet, err := net.ParseCIDR(allowedIps) + if err != nil { + return err + } + + peerKeyParsed, err := wgtypes.ParseKey(peerKey) + if err != nil { + return err + } + peer := wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + ReplaceAllowedIPs: true, + AllowedIPs: []net.IPNet{*ipNet}, + PersistentKeepaliveInterval: &keepAlive, + PresharedKey: preSharedKey, + Endpoint: endpoint, + } + + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peer}, + } + + return c.tunDevice.Device().IpcSet(toWgUserspaceString(config)) +} + +func (c *wGConfigurer) removePeer(peerKey string) error { + peerKeyParsed, err := wgtypes.ParseKey(peerKey) + if err != nil { + return err + } + + peer := wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + Remove: true, + } + + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peer}, + } + return c.tunDevice.Device().IpcSet(toWgUserspaceString(config)) +} + +func (c *wGConfigurer) addAllowedIP(peerKey string, allowedIP string) error { + _, ipNet, err := net.ParseCIDR(allowedIP) + if err != nil { + return err + } + + peerKeyParsed, err := wgtypes.ParseKey(peerKey) + if err != nil { + return err + } + peer := wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + UpdateOnly: true, + ReplaceAllowedIPs: false, + AllowedIPs: []net.IPNet{*ipNet}, + } + + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peer}, + } + + return c.tunDevice.Device().IpcSet(toWgUserspaceString(config)) +} + +func (c *wGConfigurer) removeAllowedIP(peerKey string, allowedIP string) error { + return errFuncNotImplemented +} diff --git a/iface/wg_configurer_nonandroid.go b/iface/wg_configurer_nonandroid.go new file mode 100644 index 000000000..5a2a70ea3 --- /dev/null +++ b/iface/wg_configurer_nonandroid.go @@ -0,0 +1,208 @@ +//go:build !android + +package iface + +import ( + "fmt" + "net" + "time" + + log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/wgctrl" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +type wGConfigurer struct { + deviceName string +} + +func newWGConfigurer(deviceName string) wGConfigurer { + return wGConfigurer{ + deviceName: deviceName, + } +} + +func (c *wGConfigurer) configureInterface(privateKey string, port int) error { + log.Debugf("adding Wireguard private key") + key, err := wgtypes.ParseKey(privateKey) + if err != nil { + return err + } + fwmark := 0 + config := wgtypes.Config{ + PrivateKey: &key, + ReplacePeers: true, + FirewallMark: &fwmark, + ListenPort: &port, + } + + err = c.configure(config) + if err != nil { + return fmt.Errorf(`received error "%w" while configuring interface %s with port %d`, err, c.deviceName, port) + } + return nil +} + +func (c *wGConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { + //parse allowed ips + _, ipNet, err := net.ParseCIDR(allowedIps) + if err != nil { + return err + } + + peerKeyParsed, err := wgtypes.ParseKey(peerKey) + if err != nil { + return err + } + peer := wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + ReplaceAllowedIPs: true, + AllowedIPs: []net.IPNet{*ipNet}, + PersistentKeepaliveInterval: &keepAlive, + PresharedKey: preSharedKey, + Endpoint: endpoint, + } + + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peer}, + } + err = c.configure(config) + if err != nil { + return fmt.Errorf(`received error "%w" while updating peer on interface %s with settings: allowed ips %s, endpoint %s`, err, c.deviceName, allowedIps, endpoint.String()) + } + return nil +} + +func (c *wGConfigurer) removePeer(peerKey string) error { + peerKeyParsed, err := wgtypes.ParseKey(peerKey) + if err != nil { + return err + } + + peer := wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + Remove: true, + } + + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peer}, + } + err = c.configure(config) + if err != nil { + return fmt.Errorf(`received error "%w" while removing peer %s from interface %s`, err, peerKey, c.deviceName) + } + return nil +} + +func (c *wGConfigurer) addAllowedIP(peerKey string, allowedIP string) error { + _, ipNet, err := net.ParseCIDR(allowedIP) + if err != nil { + return err + } + + peerKeyParsed, err := wgtypes.ParseKey(peerKey) + if err != nil { + return err + } + peer := wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + UpdateOnly: true, + ReplaceAllowedIPs: false, + AllowedIPs: []net.IPNet{*ipNet}, + } + + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peer}, + } + err = c.configure(config) + if err != nil { + return fmt.Errorf(`received error "%w" while adding allowed Ip to peer on interface %s with settings: allowed ips %s`, err, c.deviceName, allowedIP) + } + return nil +} + +func (c *wGConfigurer) removeAllowedIP(peerKey string, allowedIP string) error { + _, ipNet, err := net.ParseCIDR(allowedIP) + if err != nil { + return err + } + + peerKeyParsed, err := wgtypes.ParseKey(peerKey) + if err != nil { + return err + } + + existingPeer, err := c.getPeer(c.deviceName, peerKey) + if err != nil { + return err + } + + newAllowedIPs := existingPeer.AllowedIPs + + for i, existingAllowedIP := range existingPeer.AllowedIPs { + if existingAllowedIP.String() == ipNet.String() { + newAllowedIPs = append(existingPeer.AllowedIPs[:i], existingPeer.AllowedIPs[i+1:]...) + break + } + } + + if err != nil { + return err + } + peer := wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + UpdateOnly: true, + ReplaceAllowedIPs: true, + AllowedIPs: newAllowedIPs, + } + + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peer}, + } + err = c.configure(config) + if err != nil { + return fmt.Errorf(`received error "%w" while removing allowed IP from peer on interface %s with settings: allowed ips %s`, err, c.deviceName, allowedIP) + } + return nil +} + +func (c *wGConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) { + wg, err := wgctrl.New() + if err != nil { + return wgtypes.Peer{}, err + } + defer func() { + err = wg.Close() + if err != nil { + log.Errorf("got error while closing wgctl: %v", err) + } + }() + + wgDevice, err := wg.Device(ifaceName) + if err != nil { + return wgtypes.Peer{}, err + } + for _, peer := range wgDevice.Peers { + if peer.PublicKey.String() == peerPubKey { + return peer, nil + } + } + return wgtypes.Peer{}, fmt.Errorf("peer not found") +} + +func (c *wGConfigurer) configure(config wgtypes.Config) error { + wg, err := wgctrl.New() + if err != nil { + return err + } + defer wg.Close() + + // validate if device with name exists + _, err = wg.Device(c.deviceName) + if err != nil { + return err + } + log.Debugf("got Wireguard device %s", c.deviceName) + + return wg.ConfigureDevice(c.deviceName, config) +}