mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-07 08:44:07 +01:00
Mobile (#735)
Initial modification to support mobile client Export necessary interfaces for Android framework
This commit is contained in:
parent
747797271e
commit
891ba277b1
121
client/android/client.go
Normal file
121
client/android/client.go
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
package android
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/system"
|
||||||
|
"github.com/netbirdio/netbird/formatter"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConnectionListener export internal Listener for mobile
|
||||||
|
type ConnectionListener interface {
|
||||||
|
peer.Listener
|
||||||
|
}
|
||||||
|
|
||||||
|
// TunAdapter export internal TunAdapter for mobile
|
||||||
|
type TunAdapter interface {
|
||||||
|
iface.TunAdapter
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
formatter.SetLogcatFormatter(log.StandardLogger())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client struct manage the life circle of background service
|
||||||
|
type Client struct {
|
||||||
|
cfgFile string
|
||||||
|
tunAdapter iface.TunAdapter
|
||||||
|
recorder *peer.Status
|
||||||
|
ctxCancel context.CancelFunc
|
||||||
|
ctxCancelLock *sync.Mutex
|
||||||
|
deviceName string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClient instantiate a new Client
|
||||||
|
func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter) *Client {
|
||||||
|
lvl, _ := log.ParseLevel("trace")
|
||||||
|
log.SetLevel(lvl)
|
||||||
|
|
||||||
|
return &Client{
|
||||||
|
cfgFile: cfgFile,
|
||||||
|
deviceName: deviceName,
|
||||||
|
tunAdapter: tunAdapter,
|
||||||
|
recorder: peer.NewRecorder(""),
|
||||||
|
ctxCancelLock: &sync.Mutex{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run start the internal client. It is a blocker function
|
||||||
|
func (c *Client) Run(urlOpener URLOpener) error {
|
||||||
|
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
||||||
|
ConfigPath: c.cfgFile,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
|
||||||
|
|
||||||
|
var ctx context.Context
|
||||||
|
//nolint
|
||||||
|
ctxWithValues := context.WithValue(context.Background(), system.DeviceNameCtxKey, c.deviceName)
|
||||||
|
c.ctxCancelLock.Lock()
|
||||||
|
ctx, c.ctxCancel = context.WithCancel(ctxWithValues)
|
||||||
|
defer c.ctxCancel()
|
||||||
|
c.ctxCancelLock.Unlock()
|
||||||
|
|
||||||
|
auth := NewAuthWithConfig(ctx, cfg)
|
||||||
|
err = auth.Login(urlOpener)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// todo do not throw error in case of cancelled context
|
||||||
|
ctx = internal.CtxInitState(ctx)
|
||||||
|
return internal.RunClient(ctx, cfg, c.recorder, c.tunAdapter)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop the internal client and free the resources
|
||||||
|
func (c *Client) Stop() {
|
||||||
|
c.ctxCancelLock.Lock()
|
||||||
|
defer c.ctxCancelLock.Unlock()
|
||||||
|
if c.ctxCancel == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.ctxCancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
// PeersList return with the list of the PeerInfos
|
||||||
|
func (c *Client) PeersList() *PeerInfoArray {
|
||||||
|
|
||||||
|
fullStatus := c.recorder.GetFullStatus()
|
||||||
|
|
||||||
|
peerInfos := make([]PeerInfo, len(fullStatus.Peers))
|
||||||
|
for n, p := range fullStatus.Peers {
|
||||||
|
pi := PeerInfo{
|
||||||
|
p.IP,
|
||||||
|
p.FQDN,
|
||||||
|
p.ConnStatus.String(),
|
||||||
|
p.Direct,
|
||||||
|
}
|
||||||
|
peerInfos[n] = pi
|
||||||
|
}
|
||||||
|
|
||||||
|
return &PeerInfoArray{items: peerInfos}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddConnectionListener add new network connection listener
|
||||||
|
func (c *Client) AddConnectionListener(listener ConnectionListener) {
|
||||||
|
c.recorder.AddConnectionListener(listener)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveConnectionListener remove connection listener
|
||||||
|
func (c *Client) RemoveConnectionListener(listener ConnectionListener) {
|
||||||
|
c.recorder.RemoveConnectionListener(listener)
|
||||||
|
}
|
173
client/android/login.go
Normal file
173
client/android/login.go
Normal file
@ -0,0 +1,173 @@
|
|||||||
|
package android
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"github.com/cenkalti/backoff/v4"
|
||||||
|
"github.com/netbirdio/netbird/client/cmd"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
// URLOpener it is a callback interface. The Open function will be triggered if
|
||||||
|
// the backend want to show an url for the user
|
||||||
|
type URLOpener interface {
|
||||||
|
Open(string)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Auth can register or login new client
|
||||||
|
type Auth struct {
|
||||||
|
ctx context.Context
|
||||||
|
config *internal.Config
|
||||||
|
cfgPath string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAuth instantiate Auth struct and validate the management URL
|
||||||
|
func NewAuth(cfgPath string, mgmURL string) (*Auth, error) {
|
||||||
|
inputCfg := internal.ConfigInput{
|
||||||
|
ManagementURL: mgmURL,
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := internal.CreateInMemoryConfig(inputCfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Auth{
|
||||||
|
ctx: context.Background(),
|
||||||
|
config: cfg,
|
||||||
|
cfgPath: cfgPath,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAuthWithConfig instantiate Auth based on existing config
|
||||||
|
func NewAuthWithConfig(ctx context.Context, config *internal.Config) *Auth {
|
||||||
|
return &Auth{
|
||||||
|
ctx: ctx,
|
||||||
|
config: config,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginAndSaveConfigIfSSOSupported test the connectivity with the management server.
|
||||||
|
// If the SSO is supported than save the configuration. Return with the SSO login is supported or not.
|
||||||
|
func (a *Auth) LoginAndSaveConfigIfSSOSupported() (bool, error) {
|
||||||
|
var needsLogin bool
|
||||||
|
err := a.withBackOff(a.ctx, func() (err error) {
|
||||||
|
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config.SSHKey)
|
||||||
|
return
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("backoff cycle failed: %v", err)
|
||||||
|
}
|
||||||
|
if !needsLogin {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
err = internal.WriteOutConfig(a.cfgPath, a.config)
|
||||||
|
return needsLogin, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginWithSetupKeyAndSaveConfig test the connectivity with the management server with the setup key.
|
||||||
|
func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string) error {
|
||||||
|
err := a.withBackOff(a.ctx, func() error {
|
||||||
|
err := internal.Login(a.ctx, a.config, setupKey, "")
|
||||||
|
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return internal.WriteOutConfig(a.cfgPath, a.config)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login try register the client on the server
|
||||||
|
func (a *Auth) Login(urlOpener URLOpener) error {
|
||||||
|
var needsLogin bool
|
||||||
|
|
||||||
|
// check if we need to generate JWT token
|
||||||
|
err := a.withBackOff(a.ctx, func() (err error) {
|
||||||
|
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config.SSHKey)
|
||||||
|
return
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
jwtToken := ""
|
||||||
|
if needsLogin {
|
||||||
|
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||||
|
}
|
||||||
|
jwtToken = tokenInfo.AccessToken
|
||||||
|
}
|
||||||
|
|
||||||
|
err = a.withBackOff(a.ctx, func() error {
|
||||||
|
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
||||||
|
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*internal.TokenInfo, error) {
|
||||||
|
providerConfig, err := internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||||
|
if err != nil {
|
||||||
|
s, ok := gstatus.FromError(err)
|
||||||
|
if ok && s.Code() == codes.NotFound {
|
||||||
|
return nil, fmt.Errorf("no SSO provider returned from management. " +
|
||||||
|
"If you are using hosting Netbird see documentation at " +
|
||||||
|
"https://github.com/netbirdio/netbird/tree/main/management for details")
|
||||||
|
} else if ok && s.Code() == codes.Unimplemented {
|
||||||
|
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
|
||||||
|
"please update your servver or use Setup Keys to login", a.config.ManagementURL)
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("getting device authorization flow info failed with error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
hostedClient := internal.NewHostedDeviceFlow(
|
||||||
|
providerConfig.ProviderConfig.Audience,
|
||||||
|
providerConfig.ProviderConfig.ClientID,
|
||||||
|
providerConfig.ProviderConfig.TokenEndpoint,
|
||||||
|
providerConfig.ProviderConfig.DeviceAuthEndpoint,
|
||||||
|
)
|
||||||
|
|
||||||
|
flowInfo, err := hostedClient.RequestDeviceCode(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("getting a request device code failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
go urlOpener.Open(flowInfo.VerificationURIComplete)
|
||||||
|
|
||||||
|
waitTimeout := time.Duration(flowInfo.ExpiresIn)
|
||||||
|
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
tokenInfo, err := hostedClient.WaitToken(waitCTX, flowInfo)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &tokenInfo, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
|
||||||
|
return backoff.RetryNotify(
|
||||||
|
bf,
|
||||||
|
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
|
||||||
|
func(err error, duration time.Duration) {
|
||||||
|
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
|
||||||
|
})
|
||||||
|
}
|
37
client/android/peer_notifier.go
Normal file
37
client/android/peer_notifier.go
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
package android
|
||||||
|
|
||||||
|
// PeerInfo describe information about the peers. It designed for the UI usage
|
||||||
|
type PeerInfo struct {
|
||||||
|
IP string
|
||||||
|
FQDN string
|
||||||
|
ConnStatus string // Todo replace to enum
|
||||||
|
Direct bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// PeerInfoCollection made for Java layer to get non default types as collection
|
||||||
|
type PeerInfoCollection interface {
|
||||||
|
Add(s string) PeerInfoCollection
|
||||||
|
Get(i int) string
|
||||||
|
Size() int
|
||||||
|
}
|
||||||
|
|
||||||
|
// PeerInfoArray is the implementation of the PeerInfoCollection
|
||||||
|
type PeerInfoArray struct {
|
||||||
|
items []PeerInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add new PeerInfo to the collection
|
||||||
|
func (array PeerInfoArray) Add(s PeerInfo) PeerInfoArray {
|
||||||
|
array.items = append(array.items, s)
|
||||||
|
return array
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get return an element of the collection
|
||||||
|
func (array PeerInfoArray) Get(i int) *PeerInfo {
|
||||||
|
return &array.items[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Size return with the size of the collection
|
||||||
|
func (array PeerInfoArray) Size() int {
|
||||||
|
return len(array.items)
|
||||||
|
}
|
78
client/android/preferences.go
Normal file
78
client/android/preferences.go
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
package android
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Preferences export a subset of the internal config for gomobile
|
||||||
|
type Preferences struct {
|
||||||
|
configInput internal.ConfigInput
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPreferences create new Preferences instance
|
||||||
|
func NewPreferences(configPath string) *Preferences {
|
||||||
|
ci := internal.ConfigInput{
|
||||||
|
ConfigPath: configPath,
|
||||||
|
}
|
||||||
|
return &Preferences{ci}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetManagementURL read url from config file
|
||||||
|
func (p *Preferences) GetManagementURL() (string, error) {
|
||||||
|
if p.configInput.ManagementURL != "" {
|
||||||
|
return p.configInput.ManagementURL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return cfg.ManagementURL.String(), err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetManagementURL store the given url and wait for commit
|
||||||
|
func (p *Preferences) SetManagementURL(url string) {
|
||||||
|
p.configInput.ManagementURL = url
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAdminURL read url from config file
|
||||||
|
func (p *Preferences) GetAdminURL() (string, error) {
|
||||||
|
if p.configInput.AdminURL != "" {
|
||||||
|
return p.configInput.AdminURL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return cfg.AdminURL.String(), err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAdminURL store the given url and wait for commit
|
||||||
|
func (p *Preferences) SetAdminURL(url string) {
|
||||||
|
p.configInput.AdminURL = url
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPreSharedKey read preshared key from config file
|
||||||
|
func (p *Preferences) GetPreSharedKey() (string, error) {
|
||||||
|
if p.configInput.PreSharedKey != nil {
|
||||||
|
return *p.configInput.PreSharedKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return cfg.PreSharedKey, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetPreSharedKey store the given key and wait for commit
|
||||||
|
func (p *Preferences) SetPreSharedKey(key string) {
|
||||||
|
p.configInput.PreSharedKey = &key
|
||||||
|
}
|
||||||
|
|
||||||
|
// Commit write out the changes into config file
|
||||||
|
func (p *Preferences) Commit() error {
|
||||||
|
_, err := internal.UpdateOrCreateConfig(p.configInput)
|
||||||
|
return err
|
||||||
|
}
|
120
client/android/preferences_test.go
Normal file
120
client/android/preferences_test.go
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
package android
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPreferences_DefaultValues(t *testing.T) {
|
||||||
|
cfgFile := filepath.Join(t.TempDir(), "netbird.json")
|
||||||
|
p := NewPreferences(cfgFile)
|
||||||
|
defaultVar, err := p.GetAdminURL()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read default value: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if defaultVar != internal.DefaultAdminURL {
|
||||||
|
t.Errorf("invalid default admin url: %s", defaultVar)
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultVar, err = p.GetManagementURL()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read default management URL: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if defaultVar != internal.DefaultManagementURL {
|
||||||
|
t.Errorf("invalid default management url: %s", defaultVar)
|
||||||
|
}
|
||||||
|
|
||||||
|
var preSharedKey string
|
||||||
|
preSharedKey, err = p.GetPreSharedKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read default preshared key: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if preSharedKey != "" {
|
||||||
|
t.Errorf("invalid preshared key: %s", preSharedKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPreferences_ReadUncommitedValues(t *testing.T) {
|
||||||
|
exampleString := "exampleString"
|
||||||
|
cfgFile := filepath.Join(t.TempDir(), "netbird.json")
|
||||||
|
p := NewPreferences(cfgFile)
|
||||||
|
|
||||||
|
p.SetAdminURL(exampleString)
|
||||||
|
resp, err := p.GetAdminURL()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read admin url: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp != exampleString {
|
||||||
|
t.Errorf("unexpected admin url: %s", resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
p.SetManagementURL(exampleString)
|
||||||
|
resp, err = p.GetManagementURL()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read managmenet url: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp != exampleString {
|
||||||
|
t.Errorf("unexpected managemenet url: %s", resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
p.SetPreSharedKey(exampleString)
|
||||||
|
resp, err = p.GetPreSharedKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read preshared key: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp != exampleString {
|
||||||
|
t.Errorf("unexpected preshared key: %s", resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPreferences_Commit(t *testing.T) {
|
||||||
|
exampleURL := "https://myurl.com:443"
|
||||||
|
examplePresharedKey := "topsecret"
|
||||||
|
cfgFile := filepath.Join(t.TempDir(), "netbird.json")
|
||||||
|
p := NewPreferences(cfgFile)
|
||||||
|
|
||||||
|
p.SetAdminURL(exampleURL)
|
||||||
|
p.SetManagementURL(exampleURL)
|
||||||
|
p.SetPreSharedKey(examplePresharedKey)
|
||||||
|
|
||||||
|
err := p.Commit()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to save changes: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
p = NewPreferences(cfgFile)
|
||||||
|
resp, err := p.GetAdminURL()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read admin url: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp != exampleURL {
|
||||||
|
t.Errorf("unexpected admin url: %s", resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err = p.GetManagementURL()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read managmenet url: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp != exampleURL {
|
||||||
|
t.Errorf("unexpected managemenet url: %s", resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err = p.GetPreSharedKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read preshared key: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp != examplePresharedKey {
|
||||||
|
t.Errorf("unexpected preshared key: %s", resp)
|
||||||
|
}
|
||||||
|
}
|
@ -94,7 +94,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
var cancel context.CancelFunc
|
var cancel context.CancelFunc
|
||||||
ctx, cancel = context.WithCancel(ctx)
|
ctx, cancel = context.WithCancel(ctx)
|
||||||
SetupCloseHandler(ctx, cancel)
|
SetupCloseHandler(ctx, cancel)
|
||||||
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
|
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||||
|
@ -73,6 +73,25 @@ type Config struct {
|
|||||||
CustomDNSAddress string
|
CustomDNSAddress string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||||
|
func ReadConfig(configPath string) (*Config, error) {
|
||||||
|
if configFileIsExists(configPath) {
|
||||||
|
config := &Config{}
|
||||||
|
if _, err := util.ReadJson(configPath, config); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return config, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := createNewConfig(ConfigInput{ConfigPath: configPath})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = WriteOutConfig(configPath, cfg)
|
||||||
|
return cfg, err
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateConfig update existing configuration according to input configuration and return with the configuration
|
// UpdateConfig update existing configuration according to input configuration and return with the configuration
|
||||||
func UpdateConfig(input ConfigInput) (*Config, error) {
|
func UpdateConfig(input ConfigInput) (*Config, error) {
|
||||||
if !configFileIsExists(input.ConfigPath) {
|
if !configFileIsExists(input.ConfigPath) {
|
||||||
@ -86,7 +105,12 @@ func UpdateConfig(input ConfigInput) (*Config, error) {
|
|||||||
func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
||||||
if !configFileIsExists(input.ConfigPath) {
|
if !configFileIsExists(input.ConfigPath) {
|
||||||
log.Infof("generating new config %s", input.ConfigPath)
|
log.Infof("generating new config %s", input.ConfigPath)
|
||||||
return createNewConfig(input)
|
cfg, err := createNewConfig(input)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
err = WriteOutConfig(input.ConfigPath, cfg)
|
||||||
|
return cfg, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if isPreSharedKeyHidden(input.PreSharedKey) {
|
if isPreSharedKeyHidden(input.PreSharedKey) {
|
||||||
@ -95,6 +119,16 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
|||||||
return update(input)
|
return update(input)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateInMemoryConfig generate a new config but do not write out it to the store
|
||||||
|
func CreateInMemoryConfig(input ConfigInput) (*Config, error) {
|
||||||
|
return createNewConfig(input)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteOutConfig write put the prepared config to the given path
|
||||||
|
func WriteOutConfig(path string, config *Config) error {
|
||||||
|
return util.WriteJson(path, config)
|
||||||
|
}
|
||||||
|
|
||||||
// createNewConfig creates a new config generating a new Wireguard key and saving to file
|
// createNewConfig creates a new config generating a new Wireguard key and saving to file
|
||||||
func createNewConfig(input ConfigInput) (*Config, error) {
|
func createNewConfig(input ConfigInput) (*Config, error) {
|
||||||
wgKey := generateKey()
|
wgKey := generateKey()
|
||||||
@ -146,12 +180,6 @@ func createNewConfig(input ConfigInput) (*Config, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
config.IFaceBlackList = defaultInterfaceBlacklist
|
config.IFaceBlackList = defaultInterfaceBlacklist
|
||||||
|
|
||||||
err = util.WriteJson(input.ConfigPath, config)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return config, nil
|
return config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -22,7 +22,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// RunClient with main logic.
|
// RunClient with main logic.
|
||||||
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error {
|
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter) error {
|
||||||
backOff := &backoff.ExponentialBackOff{
|
backOff := &backoff.ExponentialBackOff{
|
||||||
InitialInterval: time.Second,
|
InitialInterval: time.Second,
|
||||||
RandomizationFactor: 1,
|
RandomizationFactor: 1,
|
||||||
@ -60,6 +60,8 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status)
|
|||||||
|
|
||||||
statusRecorder.MarkManagementDisconnected()
|
statusRecorder.MarkManagementDisconnected()
|
||||||
|
|
||||||
|
statusRecorder.ClientStart()
|
||||||
|
defer statusRecorder.ClientStop()
|
||||||
operation := func() error {
|
operation := func() error {
|
||||||
// if context cancelled we not start new backoff cycle
|
// if context cancelled we not start new backoff cycle
|
||||||
select {
|
select {
|
||||||
@ -144,7 +146,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status)
|
|||||||
|
|
||||||
peerConfig := loginResp.GetPeerConfig()
|
peerConfig := loginResp.GetPeerConfig()
|
||||||
|
|
||||||
engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig)
|
engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig, tunAdapter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
@ -191,11 +193,12 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// createEngineConfig converts configuration received from Management Service to EngineConfig
|
// createEngineConfig converts configuration received from Management Service to EngineConfig
|
||||||
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
|
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig, tunAdapter iface.TunAdapter) (*EngineConfig, error) {
|
||||||
|
|
||||||
engineConf := &EngineConfig{
|
engineConf := &EngineConfig{
|
||||||
WgIfaceName: config.WgIface,
|
WgIfaceName: config.WgIface,
|
||||||
WgAddr: peerConfig.Address,
|
WgAddr: peerConfig.Address,
|
||||||
|
TunAdapter: tunAdapter,
|
||||||
IFaceBlackList: config.IFaceBlackList,
|
IFaceBlackList: config.IFaceBlackList,
|
||||||
DisableIPv6Discovery: config.DisableIPv6Discovery,
|
DisableIPv6Discovery: config.DisableIPv6Discovery,
|
||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
|
@ -57,6 +57,7 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmU
|
|||||||
return DeviceAuthorizationFlow{}, err
|
return DeviceAuthorizationFlow{}, err
|
||||||
}
|
}
|
||||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err = mgmClient.Close()
|
err = mgmClient.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -8,6 +8,8 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type registrationMap map[string]struct{}
|
||||||
|
|
||||||
type localResolver struct {
|
type localResolver struct {
|
||||||
registeredMap registrationMap
|
registeredMap registrationMap
|
||||||
records sync.Map
|
records sync.Map
|
||||||
|
@ -1,27 +1,6 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"runtime"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
"github.com/mitchellh/hashstructure/v2"
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
defaultPort = 53
|
|
||||||
customPort = 5053
|
|
||||||
defaultIP = "127.0.0.1"
|
|
||||||
customIP = "127.0.0.153"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Server is a dns server interface
|
// Server is a dns server interface
|
||||||
type Server interface {
|
type Server interface {
|
||||||
@ -29,444 +8,3 @@ type Server interface {
|
|||||||
Stop()
|
Stop()
|
||||||
UpdateDNSServer(serial uint64, update nbdns.Config) error
|
UpdateDNSServer(serial uint64, update nbdns.Config) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultServer dns server object
|
|
||||||
type DefaultServer struct {
|
|
||||||
ctx context.Context
|
|
||||||
ctxCancel context.CancelFunc
|
|
||||||
upstreamCtxCancel context.CancelFunc
|
|
||||||
mux sync.Mutex
|
|
||||||
server *dns.Server
|
|
||||||
dnsMux *dns.ServeMux
|
|
||||||
dnsMuxMap registrationMap
|
|
||||||
localResolver *localResolver
|
|
||||||
wgInterface *iface.WGIface
|
|
||||||
hostManager hostManager
|
|
||||||
updateSerial uint64
|
|
||||||
listenerIsRunning bool
|
|
||||||
runtimePort int
|
|
||||||
runtimeIP string
|
|
||||||
previousConfigHash uint64
|
|
||||||
currentConfig hostDNSConfig
|
|
||||||
customAddress *netip.AddrPort
|
|
||||||
}
|
|
||||||
|
|
||||||
type registrationMap map[string]struct{}
|
|
||||||
|
|
||||||
type muxUpdate struct {
|
|
||||||
domain string
|
|
||||||
handler dns.Handler
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewDefaultServer returns a new dns server
|
|
||||||
func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string) (*DefaultServer, error) {
|
|
||||||
mux := dns.NewServeMux()
|
|
||||||
|
|
||||||
dnsServer := &dns.Server{
|
|
||||||
Net: "udp",
|
|
||||||
Handler: mux,
|
|
||||||
UDPSize: 65535,
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, stop := context.WithCancel(ctx)
|
|
||||||
|
|
||||||
var addrPort *netip.AddrPort
|
|
||||||
if customAddress != "" {
|
|
||||||
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
|
||||||
if err != nil {
|
|
||||||
stop()
|
|
||||||
return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err)
|
|
||||||
}
|
|
||||||
addrPort = &parsedAddrPort
|
|
||||||
}
|
|
||||||
|
|
||||||
defaultServer := &DefaultServer{
|
|
||||||
ctx: ctx,
|
|
||||||
ctxCancel: stop,
|
|
||||||
server: dnsServer,
|
|
||||||
dnsMux: mux,
|
|
||||||
dnsMuxMap: make(registrationMap),
|
|
||||||
localResolver: &localResolver{
|
|
||||||
registeredMap: make(registrationMap),
|
|
||||||
},
|
|
||||||
wgInterface: wgInterface,
|
|
||||||
runtimePort: defaultPort,
|
|
||||||
customAddress: addrPort,
|
|
||||||
}
|
|
||||||
|
|
||||||
hostmanager, err := newHostManager(wgInterface)
|
|
||||||
if err != nil {
|
|
||||||
stop()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defaultServer.hostManager = hostmanager
|
|
||||||
return defaultServer, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start runs the listener in a go routine
|
|
||||||
func (s *DefaultServer) Start() {
|
|
||||||
if s.customAddress != nil {
|
|
||||||
s.runtimeIP = s.customAddress.Addr().String()
|
|
||||||
s.runtimePort = int(s.customAddress.Port())
|
|
||||||
} else {
|
|
||||||
ip, port, err := s.getFirstListenerAvailable()
|
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.runtimeIP = ip
|
|
||||||
s.runtimePort = port
|
|
||||||
}
|
|
||||||
|
|
||||||
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
|
||||||
|
|
||||||
log.Debugf("starting dns on %s", s.server.Addr)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
s.setListenerStatus(true)
|
|
||||||
defer s.setListenerStatus(false)
|
|
||||||
|
|
||||||
err := s.server.ListenAndServe()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) getFirstListenerAvailable() (string, int, error) {
|
|
||||||
ips := []string{defaultIP, customIP}
|
|
||||||
if runtime.GOOS != "darwin" && s.wgInterface != nil {
|
|
||||||
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
|
|
||||||
}
|
|
||||||
ports := []int{defaultPort, customPort}
|
|
||||||
for _, port := range ports {
|
|
||||||
for _, ip := range ips {
|
|
||||||
addrString := fmt.Sprintf("%s:%d", ip, port)
|
|
||||||
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
|
||||||
probeListener, err := net.ListenUDP("udp", udpAddr)
|
|
||||||
if err == nil {
|
|
||||||
err = probeListener.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("got an error closing the probe listener, error: %s", err)
|
|
||||||
}
|
|
||||||
return ip, port, nil
|
|
||||||
}
|
|
||||||
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) setListenerStatus(running bool) {
|
|
||||||
s.listenerIsRunning = running
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop stops the server
|
|
||||||
func (s *DefaultServer) Stop() {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
s.ctxCancel()
|
|
||||||
|
|
||||||
err := s.hostManager.restoreHostDNS()
|
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = s.stopListener()
|
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) stopListener() error {
|
|
||||||
if !s.listenerIsRunning {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
err := s.server.ShutdownContext(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("stopping dns server listener returned an error: %v", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateDNSServer processes an update received from the management service
|
|
||||||
func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
|
||||||
select {
|
|
||||||
case <-s.ctx.Done():
|
|
||||||
log.Infof("not updating DNS server as context is closed")
|
|
||||||
return s.ctx.Err()
|
|
||||||
default:
|
|
||||||
if serial < s.updateSerial {
|
|
||||||
return fmt.Errorf("not applying dns update, error: "+
|
|
||||||
"network update is %d behind the last applied update", s.updateSerial-serial)
|
|
||||||
}
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
hash, err := hashstructure.Hash(update, hashstructure.FormatV2, &hashstructure.HashOptions{
|
|
||||||
ZeroNil: true,
|
|
||||||
IgnoreZeroValue: true,
|
|
||||||
SlicesAsSets: true,
|
|
||||||
UseStringer: true,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("unable to hash the dns configuration update, got error: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.previousConfigHash == hash {
|
|
||||||
log.Debugf("not applying the dns configuration update as there is nothing new")
|
|
||||||
s.updateSerial = serial
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.applyConfiguration(update); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
s.updateSerial = serial
|
|
||||||
s.previousConfigHash = hash
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|
||||||
// is the service should be disabled, we stop the listener
|
|
||||||
// and proceed with a regular update to clean up the handlers and records
|
|
||||||
if !update.ServiceEnable {
|
|
||||||
err := s.stopListener()
|
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
}
|
|
||||||
} else if !s.listenerIsRunning {
|
|
||||||
s.Start()
|
|
||||||
}
|
|
||||||
|
|
||||||
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
|
||||||
}
|
|
||||||
upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
|
|
||||||
|
|
||||||
s.updateMux(muxUpdates)
|
|
||||||
s.updateLocalResolver(localRecords)
|
|
||||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort)
|
|
||||||
|
|
||||||
if err = s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) {
|
|
||||||
var muxUpdates []muxUpdate
|
|
||||||
localRecords := make(map[string]nbdns.SimpleRecord, 0)
|
|
||||||
|
|
||||||
for _, customZone := range customZones {
|
|
||||||
|
|
||||||
if len(customZone.Records) == 0 {
|
|
||||||
return nil, nil, fmt.Errorf("received an empty list of records")
|
|
||||||
}
|
|
||||||
|
|
||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
|
||||||
domain: customZone.Domain,
|
|
||||||
handler: s.localResolver,
|
|
||||||
})
|
|
||||||
|
|
||||||
for _, record := range customZone.Records {
|
|
||||||
var class uint16 = dns.ClassINET
|
|
||||||
if record.Class != nbdns.DefaultClass {
|
|
||||||
return nil, nil, fmt.Errorf("received an invalid class type: %s", record.Class)
|
|
||||||
}
|
|
||||||
key := buildRecordKey(record.Name, class, uint16(record.Type))
|
|
||||||
localRecords[key] = record
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return muxUpdates, localRecords, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) {
|
|
||||||
// clean up the previous upstream resolver
|
|
||||||
if s.upstreamCtxCancel != nil {
|
|
||||||
s.upstreamCtxCancel()
|
|
||||||
}
|
|
||||||
|
|
||||||
var muxUpdates []muxUpdate
|
|
||||||
for _, nsGroup := range nameServerGroups {
|
|
||||||
if len(nsGroup.NameServers) == 0 {
|
|
||||||
log.Warn("received a nameserver group with empty nameserver list")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
var ctx context.Context
|
|
||||||
ctx, s.upstreamCtxCancel = context.WithCancel(s.ctx)
|
|
||||||
|
|
||||||
handler := newUpstreamResolver(ctx)
|
|
||||||
for _, ns := range nsGroup.NameServers {
|
|
||||||
if ns.NSType != nbdns.UDPNameServerType {
|
|
||||||
log.Warnf("skiping nameserver %s with type %s, this peer supports only %s",
|
|
||||||
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
handler.upstreamServers = append(handler.upstreamServers, getNSHostPort(ns))
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(handler.upstreamServers) == 0 {
|
|
||||||
log.Errorf("received a nameserver group with an invalid nameserver list")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// when upstream fails to resolve domain several times over all it servers
|
|
||||||
// it will calls this hook to exclude self from the configuration and
|
|
||||||
// reapply DNS settings, but it not touch the original configuration and serial number
|
|
||||||
// because it is temporal deactivation until next try
|
|
||||||
//
|
|
||||||
// after some period defined by upstream it trys to reactivate self by calling this hook
|
|
||||||
// everything we need here is just to re-apply current configuration because it already
|
|
||||||
// contains this upstream settings (temporal deactivation not removed it)
|
|
||||||
handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler)
|
|
||||||
|
|
||||||
if nsGroup.Primary {
|
|
||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
|
||||||
domain: nbdns.RootZone,
|
|
||||||
handler: handler,
|
|
||||||
})
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(nsGroup.Domains) == 0 {
|
|
||||||
return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list")
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, domain := range nsGroup.Domains {
|
|
||||||
if domain == "" {
|
|
||||||
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
|
|
||||||
}
|
|
||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
|
||||||
domain: domain,
|
|
||||||
handler: handler,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return muxUpdates, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
|
||||||
muxUpdateMap := make(registrationMap)
|
|
||||||
|
|
||||||
for _, update := range muxUpdates {
|
|
||||||
s.registerMux(update.domain, update.handler)
|
|
||||||
muxUpdateMap[update.domain] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
for key := range s.dnsMuxMap {
|
|
||||||
_, found := muxUpdateMap[key]
|
|
||||||
if !found {
|
|
||||||
s.deregisterMux(key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
s.dnsMuxMap = muxUpdateMap
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
|
|
||||||
for key := range s.localResolver.registeredMap {
|
|
||||||
_, found := update[key]
|
|
||||||
if !found {
|
|
||||||
s.localResolver.deleteRecord(key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
updatedMap := make(registrationMap)
|
|
||||||
for key, record := range update {
|
|
||||||
err := s.localResolver.registerRecord(record)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("got an error while registering the record (%s), error: %v", record.String(), err)
|
|
||||||
}
|
|
||||||
updatedMap[key] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
s.localResolver.registeredMap = updatedMap
|
|
||||||
}
|
|
||||||
|
|
||||||
func getNSHostPort(ns nbdns.NameServer) string {
|
|
||||||
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) {
|
|
||||||
s.dnsMux.Handle(pattern, handler)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) deregisterMux(pattern string) {
|
|
||||||
s.dnsMux.HandleRemove(pattern)
|
|
||||||
}
|
|
||||||
|
|
||||||
// upstreamCallbacks returns two functions, the first one is used to deactivate
|
|
||||||
// the upstream resolver from the configuration, the second one is used to
|
|
||||||
// reactivate it. Not allowed to call reactivate before deactivate.
|
|
||||||
func (s *DefaultServer) upstreamCallbacks(
|
|
||||||
nsGroup *nbdns.NameServerGroup,
|
|
||||||
handler dns.Handler,
|
|
||||||
) (deactivate func(), reactivate func()) {
|
|
||||||
var removeIndex map[string]int
|
|
||||||
deactivate = func() {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
|
||||||
l.Info("temporary deactivate nameservers group due timeout")
|
|
||||||
|
|
||||||
removeIndex = make(map[string]int)
|
|
||||||
for _, domain := range nsGroup.Domains {
|
|
||||||
removeIndex[domain] = -1
|
|
||||||
}
|
|
||||||
if nsGroup.Primary {
|
|
||||||
removeIndex[nbdns.RootZone] = -1
|
|
||||||
s.currentConfig.routeAll = false
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, item := range s.currentConfig.domains {
|
|
||||||
if _, found := removeIndex[item.domain]; found {
|
|
||||||
s.currentConfig.domains[i].disabled = true
|
|
||||||
s.deregisterMux(item.domain)
|
|
||||||
removeIndex[item.domain] = i
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
|
||||||
l.WithError(err).Error("fail to apply nameserver deactivation on the host")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
reactivate = func() {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
for domain, i := range removeIndex {
|
|
||||||
if i == -1 || i >= len(s.currentConfig.domains) || s.currentConfig.domains[i].domain != domain {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
s.currentConfig.domains[i].disabled = false
|
|
||||||
s.registerMux(domain, handler)
|
|
||||||
}
|
|
||||||
|
|
||||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
|
||||||
l.Debug("reactivate temporary disabled nameserver group")
|
|
||||||
|
|
||||||
if nsGroup.Primary {
|
|
||||||
s.currentConfig.routeAll = true
|
|
||||||
}
|
|
||||||
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
|
||||||
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
32
client/internal/dns/server_android.go
Normal file
32
client/internal/dns/server_android.go
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultServer dummy dns server
|
||||||
|
type DefaultServer struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDefaultServer On Android the DNS feature is not supported yet
|
||||||
|
func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string) (*DefaultServer, error) {
|
||||||
|
return &DefaultServer{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start dummy implementation
|
||||||
|
func (s DefaultServer) Start() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop dummy implementation
|
||||||
|
func (s DefaultServer) Stop() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateDNSServer dummy implementation
|
||||||
|
func (s DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
||||||
|
return nil
|
||||||
|
}
|
465
client/internal/dns/server_nonandroid.go
Normal file
465
client/internal/dns/server_nonandroid.go
Normal file
@ -0,0 +1,465 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/mitchellh/hashstructure/v2"
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultPort = 53
|
||||||
|
customPort = 5053
|
||||||
|
defaultIP = "127.0.0.1"
|
||||||
|
customIP = "127.0.0.153"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultServer dns server object
|
||||||
|
type DefaultServer struct {
|
||||||
|
ctx context.Context
|
||||||
|
ctxCancel context.CancelFunc
|
||||||
|
upstreamCtxCancel context.CancelFunc
|
||||||
|
mux sync.Mutex
|
||||||
|
server *dns.Server
|
||||||
|
dnsMux *dns.ServeMux
|
||||||
|
dnsMuxMap registrationMap
|
||||||
|
localResolver *localResolver
|
||||||
|
wgInterface *iface.WGIface
|
||||||
|
hostManager hostManager
|
||||||
|
updateSerial uint64
|
||||||
|
listenerIsRunning bool
|
||||||
|
runtimePort int
|
||||||
|
runtimeIP string
|
||||||
|
previousConfigHash uint64
|
||||||
|
currentConfig hostDNSConfig
|
||||||
|
customAddress *netip.AddrPort
|
||||||
|
}
|
||||||
|
|
||||||
|
type muxUpdate struct {
|
||||||
|
domain string
|
||||||
|
handler dns.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDefaultServer returns a new dns server
|
||||||
|
func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string) (*DefaultServer, error) {
|
||||||
|
mux := dns.NewServeMux()
|
||||||
|
|
||||||
|
dnsServer := &dns.Server{
|
||||||
|
Net: "udp",
|
||||||
|
Handler: mux,
|
||||||
|
UDPSize: 65535,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, stop := context.WithCancel(ctx)
|
||||||
|
|
||||||
|
var addrPort *netip.AddrPort
|
||||||
|
if customAddress != "" {
|
||||||
|
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
||||||
|
if err != nil {
|
||||||
|
stop()
|
||||||
|
return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err)
|
||||||
|
}
|
||||||
|
addrPort = &parsedAddrPort
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultServer := &DefaultServer{
|
||||||
|
ctx: ctx,
|
||||||
|
ctxCancel: stop,
|
||||||
|
server: dnsServer,
|
||||||
|
dnsMux: mux,
|
||||||
|
dnsMuxMap: make(registrationMap),
|
||||||
|
localResolver: &localResolver{
|
||||||
|
registeredMap: make(registrationMap),
|
||||||
|
},
|
||||||
|
wgInterface: wgInterface,
|
||||||
|
runtimePort: defaultPort,
|
||||||
|
customAddress: addrPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
hostmanager, err := newHostManager(wgInterface)
|
||||||
|
if err != nil {
|
||||||
|
stop()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defaultServer.hostManager = hostmanager
|
||||||
|
return defaultServer, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start runs the listener in a go routine
|
||||||
|
func (s *DefaultServer) Start() {
|
||||||
|
if s.customAddress != nil {
|
||||||
|
s.runtimeIP = s.customAddress.Addr().String()
|
||||||
|
s.runtimePort = int(s.customAddress.Port())
|
||||||
|
} else {
|
||||||
|
ip, port, err := s.getFirstListenerAvailable()
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.runtimeIP = ip
|
||||||
|
s.runtimePort = port
|
||||||
|
}
|
||||||
|
|
||||||
|
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
||||||
|
|
||||||
|
log.Debugf("starting dns on %s", s.server.Addr)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
s.setListenerStatus(true)
|
||||||
|
defer s.setListenerStatus(false)
|
||||||
|
|
||||||
|
err := s.server.ListenAndServe()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) getFirstListenerAvailable() (string, int, error) {
|
||||||
|
ips := []string{defaultIP, customIP}
|
||||||
|
if runtime.GOOS != "darwin" && s.wgInterface != nil {
|
||||||
|
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
|
||||||
|
}
|
||||||
|
ports := []int{defaultPort, customPort}
|
||||||
|
for _, port := range ports {
|
||||||
|
for _, ip := range ips {
|
||||||
|
addrString := fmt.Sprintf("%s:%d", ip, port)
|
||||||
|
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
||||||
|
probeListener, err := net.ListenUDP("udp", udpAddr)
|
||||||
|
if err == nil {
|
||||||
|
err = probeListener.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("got an error closing the probe listener, error: %s", err)
|
||||||
|
}
|
||||||
|
return ip, port, nil
|
||||||
|
}
|
||||||
|
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) setListenerStatus(running bool) {
|
||||||
|
s.listenerIsRunning = running
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the server
|
||||||
|
func (s *DefaultServer) Stop() {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
s.ctxCancel()
|
||||||
|
|
||||||
|
err := s.hostManager.restoreHostDNS()
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.stopListener()
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) stopListener() error {
|
||||||
|
if !s.listenerIsRunning {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err := s.server.ShutdownContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("stopping dns server listener returned an error: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateDNSServer processes an update received from the management service
|
||||||
|
func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
||||||
|
select {
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
log.Infof("not updating DNS server as context is closed")
|
||||||
|
return s.ctx.Err()
|
||||||
|
default:
|
||||||
|
if serial < s.updateSerial {
|
||||||
|
return fmt.Errorf("not applying dns update, error: "+
|
||||||
|
"network update is %d behind the last applied update", s.updateSerial-serial)
|
||||||
|
}
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
hash, err := hashstructure.Hash(update, hashstructure.FormatV2, &hashstructure.HashOptions{
|
||||||
|
ZeroNil: true,
|
||||||
|
IgnoreZeroValue: true,
|
||||||
|
SlicesAsSets: true,
|
||||||
|
UseStringer: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("unable to hash the dns configuration update, got error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.previousConfigHash == hash {
|
||||||
|
log.Debugf("not applying the dns configuration update as there is nothing new")
|
||||||
|
s.updateSerial = serial
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.applyConfiguration(update); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.updateSerial = serial
|
||||||
|
s.previousConfigHash = hash
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||||
|
// is the service should be disabled, we stop the listener
|
||||||
|
// and proceed with a regular update to clean up the handlers and records
|
||||||
|
if !update.ServiceEnable {
|
||||||
|
err := s.stopListener()
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
}
|
||||||
|
} else if !s.listenerIsRunning {
|
||||||
|
s.Start()
|
||||||
|
}
|
||||||
|
|
||||||
|
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||||
|
}
|
||||||
|
upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
|
||||||
|
|
||||||
|
s.updateMux(muxUpdates)
|
||||||
|
s.updateLocalResolver(localRecords)
|
||||||
|
s.currentConfig = dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort)
|
||||||
|
|
||||||
|
if err = s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) {
|
||||||
|
var muxUpdates []muxUpdate
|
||||||
|
localRecords := make(map[string]nbdns.SimpleRecord, 0)
|
||||||
|
|
||||||
|
for _, customZone := range customZones {
|
||||||
|
|
||||||
|
if len(customZone.Records) == 0 {
|
||||||
|
return nil, nil, fmt.Errorf("received an empty list of records")
|
||||||
|
}
|
||||||
|
|
||||||
|
muxUpdates = append(muxUpdates, muxUpdate{
|
||||||
|
domain: customZone.Domain,
|
||||||
|
handler: s.localResolver,
|
||||||
|
})
|
||||||
|
|
||||||
|
for _, record := range customZone.Records {
|
||||||
|
var class uint16 = dns.ClassINET
|
||||||
|
if record.Class != nbdns.DefaultClass {
|
||||||
|
return nil, nil, fmt.Errorf("received an invalid class type: %s", record.Class)
|
||||||
|
}
|
||||||
|
key := buildRecordKey(record.Name, class, uint16(record.Type))
|
||||||
|
localRecords[key] = record
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return muxUpdates, localRecords, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) {
|
||||||
|
// clean up the previous upstream resolver
|
||||||
|
if s.upstreamCtxCancel != nil {
|
||||||
|
s.upstreamCtxCancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
var muxUpdates []muxUpdate
|
||||||
|
for _, nsGroup := range nameServerGroups {
|
||||||
|
if len(nsGroup.NameServers) == 0 {
|
||||||
|
log.Warn("received a nameserver group with empty nameserver list")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var ctx context.Context
|
||||||
|
ctx, s.upstreamCtxCancel = context.WithCancel(s.ctx)
|
||||||
|
|
||||||
|
handler := newUpstreamResolver(ctx)
|
||||||
|
for _, ns := range nsGroup.NameServers {
|
||||||
|
if ns.NSType != nbdns.UDPNameServerType {
|
||||||
|
log.Warnf("skiping nameserver %s with type %s, this peer supports only %s",
|
||||||
|
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
handler.upstreamServers = append(handler.upstreamServers, getNSHostPort(ns))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(handler.upstreamServers) == 0 {
|
||||||
|
log.Errorf("received a nameserver group with an invalid nameserver list")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// when upstream fails to resolve domain several times over all it servers
|
||||||
|
// it will calls this hook to exclude self from the configuration and
|
||||||
|
// reapply DNS settings, but it not touch the original configuration and serial number
|
||||||
|
// because it is temporal deactivation until next try
|
||||||
|
//
|
||||||
|
// after some period defined by upstream it trys to reactivate self by calling this hook
|
||||||
|
// everything we need here is just to re-apply current configuration because it already
|
||||||
|
// contains this upstream settings (temporal deactivation not removed it)
|
||||||
|
handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler)
|
||||||
|
|
||||||
|
if nsGroup.Primary {
|
||||||
|
muxUpdates = append(muxUpdates, muxUpdate{
|
||||||
|
domain: nbdns.RootZone,
|
||||||
|
handler: handler,
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(nsGroup.Domains) == 0 {
|
||||||
|
return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, domain := range nsGroup.Domains {
|
||||||
|
if domain == "" {
|
||||||
|
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
|
||||||
|
}
|
||||||
|
muxUpdates = append(muxUpdates, muxUpdate{
|
||||||
|
domain: domain,
|
||||||
|
handler: handler,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return muxUpdates, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
||||||
|
muxUpdateMap := make(registrationMap)
|
||||||
|
|
||||||
|
for _, update := range muxUpdates {
|
||||||
|
s.registerMux(update.domain, update.handler)
|
||||||
|
muxUpdateMap[update.domain] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
for key := range s.dnsMuxMap {
|
||||||
|
_, found := muxUpdateMap[key]
|
||||||
|
if !found {
|
||||||
|
s.deregisterMux(key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.dnsMuxMap = muxUpdateMap
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
|
||||||
|
for key := range s.localResolver.registeredMap {
|
||||||
|
_, found := update[key]
|
||||||
|
if !found {
|
||||||
|
s.localResolver.deleteRecord(key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedMap := make(registrationMap)
|
||||||
|
for key, record := range update {
|
||||||
|
err := s.localResolver.registerRecord(record)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("got an error while registering the record (%s), error: %v", record.String(), err)
|
||||||
|
}
|
||||||
|
updatedMap[key] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.localResolver.registeredMap = updatedMap
|
||||||
|
}
|
||||||
|
|
||||||
|
func getNSHostPort(ns nbdns.NameServer) string {
|
||||||
|
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) {
|
||||||
|
s.dnsMux.Handle(pattern, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) deregisterMux(pattern string) {
|
||||||
|
s.dnsMux.HandleRemove(pattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
// upstreamCallbacks returns two functions, the first one is used to deactivate
|
||||||
|
// the upstream resolver from the configuration, the second one is used to
|
||||||
|
// reactivate it. Not allowed to call reactivate before deactivate.
|
||||||
|
func (s *DefaultServer) upstreamCallbacks(
|
||||||
|
nsGroup *nbdns.NameServerGroup,
|
||||||
|
handler dns.Handler,
|
||||||
|
) (deactivate func(), reactivate func()) {
|
||||||
|
var removeIndex map[string]int
|
||||||
|
deactivate = func() {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||||
|
l.Info("temporary deactivate nameservers group due timeout")
|
||||||
|
|
||||||
|
removeIndex = make(map[string]int)
|
||||||
|
for _, domain := range nsGroup.Domains {
|
||||||
|
removeIndex[domain] = -1
|
||||||
|
}
|
||||||
|
if nsGroup.Primary {
|
||||||
|
removeIndex[nbdns.RootZone] = -1
|
||||||
|
s.currentConfig.routeAll = false
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, item := range s.currentConfig.domains {
|
||||||
|
if _, found := removeIndex[item.domain]; found {
|
||||||
|
s.currentConfig.domains[i].disabled = true
|
||||||
|
s.deregisterMux(item.domain)
|
||||||
|
removeIndex[item.domain] = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||||
|
l.WithError(err).Error("fail to apply nameserver deactivation on the host")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
reactivate = func() {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
for domain, i := range removeIndex {
|
||||||
|
if i == -1 || i >= len(s.currentConfig.domains) || s.currentConfig.domains[i].domain != domain {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.currentConfig.domains[i].disabled = false
|
||||||
|
s.registerMux(domain, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||||
|
l.Debug("reactivate temporary disabled nameserver group")
|
||||||
|
|
||||||
|
if nsGroup.Primary {
|
||||||
|
s.currentConfig.routeAll = true
|
||||||
|
}
|
||||||
|
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||||
|
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
@ -199,7 +199,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
|
|
||||||
for n, testCase := range testCases {
|
for n, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU)
|
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -46,6 +46,8 @@ var ErrResetConnection = fmt.Errorf("reset connection")
|
|||||||
type EngineConfig struct {
|
type EngineConfig struct {
|
||||||
WgPort int
|
WgPort int
|
||||||
WgIfaceName string
|
WgIfaceName string
|
||||||
|
// TunAdapter is option. It is necessary for mobile version.
|
||||||
|
TunAdapter iface.TunAdapter
|
||||||
|
|
||||||
// WgAddr is a Wireguard local address (Netbird Network IP)
|
// WgAddr is a Wireguard local address (Netbird Network IP)
|
||||||
WgAddr string
|
WgAddr string
|
||||||
@ -173,7 +175,7 @@ func (e *Engine) Start() error {
|
|||||||
myPrivateKey := e.config.WgPrivateKey
|
myPrivateKey := e.config.WgPrivateKey
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
e.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU)
|
e.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, e.config.TunAdapter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed creating wireguard interface instance %s: [%s]", wgIfaceName, err.Error())
|
log.Errorf("failed creating wireguard interface instance %s: [%s]", wgIfaceName, err.Error())
|
||||||
return err
|
return err
|
||||||
@ -614,6 +616,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
if protoDNSConfig == nil {
|
if protoDNSConfig == nil {
|
||||||
protoDNSConfig = &mgmProto.DNSConfig{}
|
protoDNSConfig = &mgmProto.DNSConfig{}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig))
|
err = e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to update dns server, err: %v", err)
|
log.Errorf("failed to update dns server, err: %v", err)
|
||||||
|
@ -207,7 +207,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
}, peer.NewRecorder("https://mgm"))
|
}, peer.NewRecorder("https://mgm"))
|
||||||
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU)
|
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU, nil)
|
||||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder)
|
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder)
|
||||||
engine.dnsServer = &dns.MockServer{
|
engine.dnsServer = &dns.MockServer{
|
||||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||||
@ -549,7 +549,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
|||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
}, peer.NewRecorder("https://mgm"))
|
}, peer.NewRecorder("https://mgm"))
|
||||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU)
|
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil)
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
input := struct {
|
input := struct {
|
||||||
inputSerial uint64
|
inputSerial uint64
|
||||||
@ -714,7 +714,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
|||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
}, peer.NewRecorder("https://mgm"))
|
}, peer.NewRecorder("https://mgm"))
|
||||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU)
|
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil)
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
mockRouteManager := &routemanager.MockManager{
|
mockRouteManager := &routemanager.MockManager{
|
||||||
|
@ -2,37 +2,26 @@ package internal
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
|
||||||
mgm "github.com/netbirdio/netbird/management/client"
|
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
|
"github.com/netbirdio/netbird/client/system"
|
||||||
|
mgm "github.com/netbirdio/netbird/management/client"
|
||||||
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Login(ctx context.Context, config *Config, setupKey string, jwtToken string) error {
|
// IsLoginRequired check that the server is support SSO or not
|
||||||
// validate our peer's Wireguard PRIVATE key
|
func IsLoginRequired(ctx context.Context, privateKey string, mgmURL *url.URL, sshKey string) (bool, error) {
|
||||||
myPrivateKey, err := wgtypes.ParseKey(config.PrivateKey)
|
mgmClient, err := getMgmClient(ctx, privateKey, mgmURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed parsing Wireguard key %s: [%s]", config.PrivateKey, err.Error())
|
return false, err
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var mgmTlsEnabled bool
|
|
||||||
if config.ManagementURL.Scheme == "https" {
|
|
||||||
mgmTlsEnabled = true
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("connecting to the Management service %s", config.ManagementURL.String())
|
|
||||||
mgmClient, err := mgm.NewClient(ctx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed connecting to the Management service %s %v", config.ManagementURL.String(), err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
log.Debugf("connected to the Management service %s", config.ManagementURL.String())
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err = mgmClient.Close()
|
err = mgmClient.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -42,40 +31,84 @@ func Login(ctx context.Context, config *Config, setupKey string, jwtToken string
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||||
|
|
||||||
serverKey, err := mgmClient.GetServerPublicKey()
|
pubSSHKey, err := ssh.GeneratePublicKey([]byte(sshKey))
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = doMgmLogin(ctx, mgmClient, pubSSHKey)
|
||||||
|
if isLoginNeeded(err) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login or register the client
|
||||||
|
func Login(ctx context.Context, config *Config, setupKey string, jwtToken string) error {
|
||||||
|
mgmClient, err := getMgmClient(ctx, config.PrivateKey, config.ManagementURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
defer func() {
|
||||||
|
err = mgmClient.Close()
|
||||||
|
if err != nil {
|
||||||
|
cStatus, ok := status.FromError(err)
|
||||||
|
if !ok || ok && cStatus.Code() != codes.Canceled {
|
||||||
|
log.Warnf("failed to close the Management service client, err: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
log.Debugf("connected to the Management service %s", config.ManagementURL.String())
|
||||||
|
|
||||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
|
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = loginPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey)
|
|
||||||
if err != nil {
|
serverKey, err := doMgmLogin(ctx, mgmClient, pubSSHKey)
|
||||||
log.Errorf("failed logging-in peer on Management Service : %v", err)
|
if isRegistrationNeeded(err) {
|
||||||
|
log.Debugf("peer registration required")
|
||||||
|
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
log.Infof("peer has successfully logged-in to the Management service %s", config.ManagementURL.String())
|
|
||||||
return nil
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// loginPeer attempts to login to Management Service. If peer wasn't registered, tries the registration flow.
|
func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) {
|
||||||
func loginPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
// validate our peer's Wireguard PRIVATE key
|
||||||
sysInfo := system.GetInfo(ctx)
|
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||||
loginResp, err := client.Login(serverPublicKey, sysInfo, pubSSHKey)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if s, ok := status.FromError(err); ok && s.Code() == codes.PermissionDenied {
|
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
||||||
log.Debugf("peer registration required")
|
return nil, err
|
||||||
return registerPeer(ctx, serverPublicKey, client, setupKey, jwtToken, pubSSHKey)
|
|
||||||
} else {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return loginResp, nil
|
var mgmTlsEnabled bool
|
||||||
|
if mgmURL.Scheme == "https" {
|
||||||
|
mgmTlsEnabled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("connecting to the Management service %s", mgmURL.String())
|
||||||
|
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTlsEnabled)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed connecting to the Management service %s %v", mgmURL.String(), err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return mgmClient, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte) (*wgtypes.Key, error) {
|
||||||
|
serverKey, err := mgmClient.GetServerPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sysInfo := system.GetInfo(ctx)
|
||||||
|
_, err = mgmClient.Login(*serverKey, sysInfo, pubSSHKey)
|
||||||
|
return serverKey, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
||||||
@ -98,3 +131,31 @@ func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.
|
|||||||
|
|
||||||
return loginResp, nil
|
return loginResp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isLoginNeeded(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
s, ok := status.FromError(err)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func isRegistrationNeeded(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
s, ok := status.FromError(err)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if s.Code() == codes.PermissionDenied {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
@ -9,6 +9,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pion/ice/v2"
|
"github.com/pion/ice/v2"
|
||||||
|
"github.com/pion/transport/v2/stdnet"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl"
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||||||
|
|
||||||
@ -161,7 +162,10 @@ func (conn *Conn) reCreateAgent() error {
|
|||||||
defer conn.mu.Unlock()
|
defer conn.mu.Unlock()
|
||||||
|
|
||||||
failedTimeout := 6 * time.Second
|
failedTimeout := 6 * time.Second
|
||||||
var err error
|
transportNet, err := stdnet.NewNet()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to create pion's stdnet: %s", err)
|
||||||
|
}
|
||||||
agentConfig := &ice.AgentConfig{
|
agentConfig := &ice.AgentConfig{
|
||||||
MulticastDNSMode: ice.MulticastDNSModeDisabled,
|
MulticastDNSMode: ice.MulticastDNSModeDisabled,
|
||||||
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
|
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
|
||||||
@ -172,6 +176,7 @@ func (conn *Conn) reCreateAgent() error {
|
|||||||
UDPMux: conn.config.UDPMux,
|
UDPMux: conn.config.UDPMux,
|
||||||
UDPMuxSrflx: conn.config.UDPMuxSrflx,
|
UDPMuxSrflx: conn.config.UDPMuxSrflx,
|
||||||
NAT1To1IPs: conn.config.NATExternalIPs,
|
NAT1To1IPs: conn.config.NATExternalIPs,
|
||||||
|
Net: transportNet,
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.config.DisableIPv6Discovery {
|
if conn.config.DisableIPv6Discovery {
|
||||||
|
9
client/internal/peer/listener.go
Normal file
9
client/internal/peer/listener.go
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
// Listener is a callback type about the NetBird network connection state
|
||||||
|
type Listener interface {
|
||||||
|
OnConnected()
|
||||||
|
OnDisconnected()
|
||||||
|
OnConnecting()
|
||||||
|
OnPeersListChanged(int)
|
||||||
|
}
|
124
client/internal/peer/notifier.go
Normal file
124
client/internal/peer/notifier.go
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
stateDisconnected = iota
|
||||||
|
stateConnected
|
||||||
|
stateConnecting
|
||||||
|
)
|
||||||
|
|
||||||
|
type notifier struct {
|
||||||
|
serverStateLock sync.Mutex
|
||||||
|
listenersLock sync.Mutex
|
||||||
|
listeners map[Listener]struct{}
|
||||||
|
currentServerState bool
|
||||||
|
currentClientState bool
|
||||||
|
lastNotification int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newNotifier() *notifier {
|
||||||
|
return ¬ifier{
|
||||||
|
listeners: make(map[Listener]struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *notifier) addListener(listener Listener) {
|
||||||
|
n.listenersLock.Lock()
|
||||||
|
defer n.listenersLock.Unlock()
|
||||||
|
|
||||||
|
n.serverStateLock.Lock()
|
||||||
|
go n.notifyListener(listener, n.lastNotification)
|
||||||
|
n.serverStateLock.Unlock()
|
||||||
|
n.listeners[listener] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *notifier) removeListener(listener Listener) {
|
||||||
|
n.listenersLock.Lock()
|
||||||
|
defer n.listenersLock.Unlock()
|
||||||
|
delete(n.listeners, listener)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *notifier) updateServerStates(mgmState bool, signalState bool) {
|
||||||
|
n.serverStateLock.Lock()
|
||||||
|
defer n.serverStateLock.Unlock()
|
||||||
|
|
||||||
|
var newState bool
|
||||||
|
if mgmState && signalState {
|
||||||
|
newState = true
|
||||||
|
} else {
|
||||||
|
newState = false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !n.isServerStateChanged(newState) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
n.currentServerState = newState
|
||||||
|
n.lastNotification = n.calculateState(newState, n.currentClientState)
|
||||||
|
|
||||||
|
go n.notifyAll(n.lastNotification)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *notifier) clientStart() {
|
||||||
|
n.serverStateLock.Lock()
|
||||||
|
defer n.serverStateLock.Unlock()
|
||||||
|
n.currentClientState = true
|
||||||
|
n.lastNotification = n.calculateState(n.currentServerState, true)
|
||||||
|
go n.notifyAll(n.lastNotification)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *notifier) clientStop() {
|
||||||
|
n.serverStateLock.Lock()
|
||||||
|
defer n.serverStateLock.Unlock()
|
||||||
|
n.currentClientState = false
|
||||||
|
n.lastNotification = n.calculateState(n.currentServerState, false)
|
||||||
|
go n.notifyAll(n.lastNotification)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *notifier) isServerStateChanged(newState bool) bool {
|
||||||
|
return n.currentServerState != newState
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *notifier) notifyAll(state int) {
|
||||||
|
n.listenersLock.Lock()
|
||||||
|
defer n.listenersLock.Unlock()
|
||||||
|
|
||||||
|
for l := range n.listeners {
|
||||||
|
n.notifyListener(l, state)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *notifier) notifyListener(l Listener, state int) {
|
||||||
|
switch state {
|
||||||
|
case stateDisconnected:
|
||||||
|
l.OnDisconnected()
|
||||||
|
case stateConnected:
|
||||||
|
l.OnConnected()
|
||||||
|
case stateConnecting:
|
||||||
|
l.OnConnecting()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *notifier) calculateState(serverState bool, clientState bool) int {
|
||||||
|
if serverState && clientState {
|
||||||
|
return stateConnected
|
||||||
|
}
|
||||||
|
|
||||||
|
if !clientState {
|
||||||
|
return stateDisconnected
|
||||||
|
}
|
||||||
|
|
||||||
|
return stateConnecting
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *notifier) peerListChanged(numOfPeers int) {
|
||||||
|
n.listenersLock.Lock()
|
||||||
|
defer n.listenersLock.Unlock()
|
||||||
|
|
||||||
|
for l := range n.listeners {
|
||||||
|
l.OnPeersListChanged(numOfPeers)
|
||||||
|
}
|
||||||
|
}
|
32
client/internal/peer/notifier_test.go
Normal file
32
client/internal/peer/notifier_test.go
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_notifier_serverState(t *testing.T) {
|
||||||
|
|
||||||
|
type scenario struct {
|
||||||
|
name string
|
||||||
|
expected bool
|
||||||
|
mgmState bool
|
||||||
|
signalState bool
|
||||||
|
}
|
||||||
|
scenarios := []scenario{
|
||||||
|
{"connected", true, true, true},
|
||||||
|
{"mgm down", false, false, true},
|
||||||
|
{"signal down", false, true, false},
|
||||||
|
{"disconnected", false, false, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range scenarios {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
n := newNotifier()
|
||||||
|
n.updateServerStates(tt.mgmState, tt.signalState)
|
||||||
|
if n.currentServerState != tt.expected {
|
||||||
|
t.Errorf("invalid serverstate: %t, expected: %t", n.currentServerState, tt.expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -58,6 +58,7 @@ type Status struct {
|
|||||||
offlinePeers []State
|
offlinePeers []State
|
||||||
mgmAddress string
|
mgmAddress string
|
||||||
signalAddress string
|
signalAddress string
|
||||||
|
notifier *notifier
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRecorder returns a new Status instance
|
// NewRecorder returns a new Status instance
|
||||||
@ -66,6 +67,7 @@ func NewRecorder(mgmAddress string) *Status {
|
|||||||
peers: make(map[string]State),
|
peers: make(map[string]State),
|
||||||
changeNotify: make(map[string]chan struct{}),
|
changeNotify: make(map[string]chan struct{}),
|
||||||
offlinePeers: make([]State, 0),
|
offlinePeers: make([]State, 0),
|
||||||
|
notifier: newNotifier(),
|
||||||
mgmAddress: mgmAddress,
|
mgmAddress: mgmAddress,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -114,6 +116,7 @@ func (d *Status) RemovePeer(peerPubKey string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
d.notifyPeerListChanged()
|
||||||
return errors.New("no peer with to remove")
|
return errors.New("no peer with to remove")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -148,6 +151,7 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
|||||||
d.changeNotify[receivedState.PubKey] = nil
|
d.changeNotify[receivedState.PubKey] = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
d.notifyPeerListChanged()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -164,6 +168,7 @@ func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error {
|
|||||||
peerState.FQDN = fqdn
|
peerState.FQDN = fqdn
|
||||||
d.peers[peerPubKey] = peerState
|
d.peers[peerPubKey] = peerState
|
||||||
|
|
||||||
|
d.notifyPeerListChanged()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -199,6 +204,8 @@ func (d *Status) CleanLocalPeerState() {
|
|||||||
func (d *Status) MarkManagementDisconnected() {
|
func (d *Status) MarkManagementDisconnected() {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
|
defer d.onConnectionChanged()
|
||||||
|
|
||||||
d.managementState = false
|
d.managementState = false
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -206,7 +213,9 @@ func (d *Status) MarkManagementDisconnected() {
|
|||||||
func (d *Status) MarkManagementConnected() {
|
func (d *Status) MarkManagementConnected() {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
d.managementState = true
|
defer d.onConnectionChanged()
|
||||||
|
|
||||||
|
d.managementState = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateSignalAddress update the address of the signal server
|
// UpdateSignalAddress update the address of the signal server
|
||||||
@ -227,13 +236,17 @@ func (d *Status) UpdateManagementAddress(mgmAddress string) {
|
|||||||
func (d *Status) MarkSignalDisconnected() {
|
func (d *Status) MarkSignalDisconnected() {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
d.signalState = false
|
defer d.onConnectionChanged()
|
||||||
|
|
||||||
|
d.signalState = false
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarkSignalConnected sets SignalState to connected
|
// MarkSignalConnected sets SignalState to connected
|
||||||
func (d *Status) MarkSignalConnected() {
|
func (d *Status) MarkSignalConnected() {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
|
defer d.onConnectionChanged()
|
||||||
|
|
||||||
d.signalState = true
|
d.signalState = true
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -262,3 +275,31 @@ func (d *Status) GetFullStatus() FullStatus {
|
|||||||
|
|
||||||
return fullStatus
|
return fullStatus
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ClientStart will notify all listeners about the new service state
|
||||||
|
func (d *Status) ClientStart() {
|
||||||
|
d.notifier.clientStart()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientStop will notify all listeners about the new service state
|
||||||
|
func (d *Status) ClientStop() {
|
||||||
|
d.notifier.clientStop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddConnectionListener add a listener to the notifier
|
||||||
|
func (d *Status) AddConnectionListener(listener Listener) {
|
||||||
|
d.notifier.addListener(listener)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveConnectionListener remove a listener from the notifier
|
||||||
|
func (d *Status) RemoveConnectionListener(listener Listener) {
|
||||||
|
d.notifier.removeListener(listener)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Status) onConnectionChanged() {
|
||||||
|
d.notifier.updateServerStates(d.managementState, d.signalState)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Status) notifyPeerListChanged() {
|
||||||
|
d.notifier.peerListChanged(len(d.peers))
|
||||||
|
}
|
||||||
|
@ -1,190 +1,9 @@
|
|||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import "github.com/netbirdio/netbird/route"
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"runtime"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
|
||||||
"github.com/netbirdio/netbird/version"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Manager is a route manager interface
|
// Manager is a route manager interface
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
|
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
|
||||||
Stop()
|
Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultManager is the default instance of a route manager
|
|
||||||
type DefaultManager struct {
|
|
||||||
ctx context.Context
|
|
||||||
stop context.CancelFunc
|
|
||||||
mux sync.Mutex
|
|
||||||
clientNetworks map[string]*clientNetwork
|
|
||||||
serverRoutes map[string]*route.Route
|
|
||||||
serverRouter *serverRouter
|
|
||||||
statusRecorder *peer.Status
|
|
||||||
wgInterface *iface.WGIface
|
|
||||||
pubKey string
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewManager returns a new route manager
|
|
||||||
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status) *DefaultManager {
|
|
||||||
mCTX, cancel := context.WithCancel(ctx)
|
|
||||||
return &DefaultManager{
|
|
||||||
ctx: mCTX,
|
|
||||||
stop: cancel,
|
|
||||||
clientNetworks: make(map[string]*clientNetwork),
|
|
||||||
serverRoutes: make(map[string]*route.Route),
|
|
||||||
serverRouter: &serverRouter{
|
|
||||||
routes: make(map[string]*route.Route),
|
|
||||||
netForwardHistoryEnabled: isNetForwardHistoryEnabled(),
|
|
||||||
firewall: NewFirewall(ctx),
|
|
||||||
},
|
|
||||||
statusRecorder: statusRecorder,
|
|
||||||
wgInterface: wgInterface,
|
|
||||||
pubKey: pubKey,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop stops the manager watchers and clean firewall rules
|
|
||||||
func (m *DefaultManager) Stop() {
|
|
||||||
m.stop()
|
|
||||||
m.serverRouter.firewall.CleanRoutingRules()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
|
|
||||||
// removing routes that do not exist as per the update from the Management service.
|
|
||||||
for id, client := range m.clientNetworks {
|
|
||||||
_, found := networks[id]
|
|
||||||
if !found {
|
|
||||||
log.Debugf("stopping client network watcher, %s", id)
|
|
||||||
client.stop()
|
|
||||||
delete(m.clientNetworks, id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for id, routes := range networks {
|
|
||||||
clientNetworkWatcher, found := m.clientNetworks[id]
|
|
||||||
if !found {
|
|
||||||
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
|
|
||||||
m.clientNetworks[id] = clientNetworkWatcher
|
|
||||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
|
||||||
}
|
|
||||||
update := routesUpdate{
|
|
||||||
updateSerial: updateSerial,
|
|
||||||
routes: routes,
|
|
||||||
}
|
|
||||||
|
|
||||||
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *DefaultManager) updateServerRoutes(routesMap map[string]*route.Route) error {
|
|
||||||
serverRoutesToRemove := make([]string, 0)
|
|
||||||
|
|
||||||
if len(routesMap) > 0 {
|
|
||||||
err := m.serverRouter.firewall.RestoreOrCreateContainers()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("couldn't initialize firewall containers, got err: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for routeID := range m.serverRoutes {
|
|
||||||
update, found := routesMap[routeID]
|
|
||||||
if !found || !update.IsEqual(m.serverRoutes[routeID]) {
|
|
||||||
serverRoutesToRemove = append(serverRoutesToRemove, routeID)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, routeID := range serverRoutesToRemove {
|
|
||||||
oldRoute := m.serverRoutes[routeID]
|
|
||||||
err := m.removeFromServerNetwork(oldRoute)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("unable to remove route id: %s, network %s, from server, got: %v",
|
|
||||||
oldRoute.ID, oldRoute.Network, err)
|
|
||||||
}
|
|
||||||
delete(m.serverRoutes, routeID)
|
|
||||||
}
|
|
||||||
|
|
||||||
for id, newRoute := range routesMap {
|
|
||||||
_, found := m.serverRoutes[id]
|
|
||||||
if found {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
err := m.addToServerNetwork(newRoute)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
m.serverRoutes[id] = newRoute
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(m.serverRoutes) > 0 {
|
|
||||||
err := enableIPForwarding()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps
|
|
||||||
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
|
||||||
select {
|
|
||||||
case <-m.ctx.Done():
|
|
||||||
log.Infof("not updating routes as context is closed")
|
|
||||||
return m.ctx.Err()
|
|
||||||
default:
|
|
||||||
m.mux.Lock()
|
|
||||||
defer m.mux.Unlock()
|
|
||||||
|
|
||||||
newClientRoutesIDMap := make(map[string][]*route.Route)
|
|
||||||
newServerRoutesMap := make(map[string]*route.Route)
|
|
||||||
ownNetworkIDs := make(map[string]bool)
|
|
||||||
|
|
||||||
for _, newRoute := range newRoutes {
|
|
||||||
networkID := route.GetHAUniqueID(newRoute)
|
|
||||||
if newRoute.Peer == m.pubKey {
|
|
||||||
ownNetworkIDs[networkID] = true
|
|
||||||
// only linux is supported for now
|
|
||||||
if runtime.GOOS != "linux" {
|
|
||||||
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
newServerRoutesMap[newRoute.ID] = newRoute
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, newRoute := range newRoutes {
|
|
||||||
networkID := route.GetHAUniqueID(newRoute)
|
|
||||||
if !ownNetworkIDs[networkID] {
|
|
||||||
// if prefix is too small, lets assume is a possible default route which is not yet supported
|
|
||||||
// we skip this route management
|
|
||||||
if newRoute.Network.Bits() < 7 {
|
|
||||||
log.Errorf("this agent version: %s, doesn't support default routes, received %s, skiping this route",
|
|
||||||
version.NetbirdVersion(), newRoute.Network)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
|
|
||||||
|
|
||||||
err := m.updateServerRoutes(newServerRoutesMap)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
31
client/internal/routemanager/manager_android.go
Normal file
31
client/internal/routemanager/manager_android.go
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
package routemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultManager dummy router manager for Android
|
||||||
|
type DefaultManager struct {
|
||||||
|
ctx context.Context
|
||||||
|
serverRouter *serverRouter
|
||||||
|
wgInterface *iface.WGIface
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager returns a new dummy route manager what doing nothing
|
||||||
|
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status) *DefaultManager {
|
||||||
|
return &DefaultManager{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateRoutes ...
|
||||||
|
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop ...
|
||||||
|
func (m *DefaultManager) Stop() {
|
||||||
|
|
||||||
|
}
|
186
client/internal/routemanager/manager_nonandroid.go
Normal file
186
client/internal/routemanager/manager_nonandroid.go
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package routemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/version"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultManager is the default instance of a route manager
|
||||||
|
type DefaultManager struct {
|
||||||
|
ctx context.Context
|
||||||
|
stop context.CancelFunc
|
||||||
|
mux sync.Mutex
|
||||||
|
clientNetworks map[string]*clientNetwork
|
||||||
|
serverRoutes map[string]*route.Route
|
||||||
|
serverRouter *serverRouter
|
||||||
|
statusRecorder *peer.Status
|
||||||
|
wgInterface *iface.WGIface
|
||||||
|
pubKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager returns a new route manager
|
||||||
|
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status) *DefaultManager {
|
||||||
|
mCTX, cancel := context.WithCancel(ctx)
|
||||||
|
return &DefaultManager{
|
||||||
|
ctx: mCTX,
|
||||||
|
stop: cancel,
|
||||||
|
clientNetworks: make(map[string]*clientNetwork),
|
||||||
|
serverRoutes: make(map[string]*route.Route),
|
||||||
|
serverRouter: &serverRouter{
|
||||||
|
routes: make(map[string]*route.Route),
|
||||||
|
netForwardHistoryEnabled: isNetForwardHistoryEnabled(),
|
||||||
|
firewall: NewFirewall(ctx),
|
||||||
|
},
|
||||||
|
statusRecorder: statusRecorder,
|
||||||
|
wgInterface: wgInterface,
|
||||||
|
pubKey: pubKey,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the manager watchers and clean firewall rules
|
||||||
|
func (m *DefaultManager) Stop() {
|
||||||
|
m.stop()
|
||||||
|
m.serverRouter.firewall.CleanRoutingRules()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
|
||||||
|
// removing routes that do not exist as per the update from the Management service.
|
||||||
|
for id, client := range m.clientNetworks {
|
||||||
|
_, found := networks[id]
|
||||||
|
if !found {
|
||||||
|
log.Debugf("stopping client network watcher, %s", id)
|
||||||
|
client.stop()
|
||||||
|
delete(m.clientNetworks, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for id, routes := range networks {
|
||||||
|
clientNetworkWatcher, found := m.clientNetworks[id]
|
||||||
|
if !found {
|
||||||
|
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
|
||||||
|
m.clientNetworks[id] = clientNetworkWatcher
|
||||||
|
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||||
|
}
|
||||||
|
update := routesUpdate{
|
||||||
|
updateSerial: updateSerial,
|
||||||
|
routes: routes,
|
||||||
|
}
|
||||||
|
|
||||||
|
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *DefaultManager) updateServerRoutes(routesMap map[string]*route.Route) error {
|
||||||
|
serverRoutesToRemove := make([]string, 0)
|
||||||
|
|
||||||
|
if len(routesMap) > 0 {
|
||||||
|
err := m.serverRouter.firewall.RestoreOrCreateContainers()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("couldn't initialize firewall containers, got err: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for routeID := range m.serverRoutes {
|
||||||
|
update, found := routesMap[routeID]
|
||||||
|
if !found || !update.IsEqual(m.serverRoutes[routeID]) {
|
||||||
|
serverRoutesToRemove = append(serverRoutesToRemove, routeID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, routeID := range serverRoutesToRemove {
|
||||||
|
oldRoute := m.serverRoutes[routeID]
|
||||||
|
err := m.removeFromServerNetwork(oldRoute)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("unable to remove route id: %s, network %s, from server, got: %v",
|
||||||
|
oldRoute.ID, oldRoute.Network, err)
|
||||||
|
}
|
||||||
|
delete(m.serverRoutes, routeID)
|
||||||
|
}
|
||||||
|
|
||||||
|
for id, newRoute := range routesMap {
|
||||||
|
_, found := m.serverRoutes[id]
|
||||||
|
if found {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err := m.addToServerNetwork(newRoute)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
m.serverRoutes[id] = newRoute
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(m.serverRoutes) > 0 {
|
||||||
|
err := enableIPForwarding()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps
|
||||||
|
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
||||||
|
select {
|
||||||
|
case <-m.ctx.Done():
|
||||||
|
log.Infof("not updating routes as context is closed")
|
||||||
|
return m.ctx.Err()
|
||||||
|
default:
|
||||||
|
m.mux.Lock()
|
||||||
|
defer m.mux.Unlock()
|
||||||
|
|
||||||
|
newClientRoutesIDMap := make(map[string][]*route.Route)
|
||||||
|
newServerRoutesMap := make(map[string]*route.Route)
|
||||||
|
ownNetworkIDs := make(map[string]bool)
|
||||||
|
|
||||||
|
for _, newRoute := range newRoutes {
|
||||||
|
networkID := route.GetHAUniqueID(newRoute)
|
||||||
|
if newRoute.Peer == m.pubKey {
|
||||||
|
ownNetworkIDs[networkID] = true
|
||||||
|
// only linux is supported for now
|
||||||
|
if runtime.GOOS != "linux" {
|
||||||
|
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newServerRoutesMap[newRoute.ID] = newRoute
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, newRoute := range newRoutes {
|
||||||
|
networkID := route.GetHAUniqueID(newRoute)
|
||||||
|
if !ownNetworkIDs[networkID] {
|
||||||
|
// if prefix is too small, lets assume is a possible default route which is not yet supported
|
||||||
|
// we skip this route management
|
||||||
|
if newRoute.Network.Bits() < 7 {
|
||||||
|
log.Errorf("this agent version: %s, doesn't support default routes, received %s, skiping this route",
|
||||||
|
version.NetbirdVersion(), newRoute.Network)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
|
||||||
|
|
||||||
|
err := m.updateServerRoutes(newServerRoutesMap)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
@ -391,7 +391,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
|
|
||||||
for n, testCase := range testCases {
|
for n, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU)
|
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU, nil)
|
||||||
require.NoError(t, err, "should create testing WGIface interface")
|
require.NoError(t, err, "should create testing WGIface interface")
|
||||||
defer wgInterface.Close()
|
defer wgInterface.Close()
|
||||||
|
|
||||||
|
@ -32,7 +32,7 @@ func TestAddRemoveRoutes(t *testing.T) {
|
|||||||
|
|
||||||
for n, testCase := range testCases {
|
for n, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU)
|
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU, nil)
|
||||||
require.NoError(t, err, "should create testing WGIface interface")
|
require.NoError(t, err, "should create testing WGIface interface")
|
||||||
defer wgInterface.Close()
|
defer wgInterface.Close()
|
||||||
|
|
||||||
|
@ -102,7 +102,7 @@ func (s *Server) Start() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if err := internal.RunClient(ctx, config, s.statusRecorder); err != nil {
|
if err := internal.RunClient(ctx, config, s.statusRecorder, nil); err != nil {
|
||||||
log.Errorf("init connections: %v", err)
|
log.Errorf("init connections: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@ -394,7 +394,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
|
|||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if err := internal.RunClient(ctx, s.config, s.statusRecorder); err != nil {
|
if err := internal.RunClient(ctx, s.config, s.statusRecorder, nil); err != nil {
|
||||||
log.Errorf("run client connection: %v", err)
|
log.Errorf("run client connection: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -9,6 +9,9 @@ import (
|
|||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// DeviceNameCtxKey context key for device name
|
||||||
|
const DeviceNameCtxKey = "deviceName"
|
||||||
|
|
||||||
// Info is an object that contains machine information
|
// Info is an object that contains machine information
|
||||||
// Most of the code is taken from https://github.com/matishsiao/goInfo
|
// Most of the code is taken from https://github.com/matishsiao/goInfo
|
||||||
type Info struct {
|
type Info struct {
|
||||||
|
63
client/system/info_android.go
Normal file
63
client/system/info_android.go
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
//go:build android
|
||||||
|
// +build android
|
||||||
|
|
||||||
|
package system
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"os/exec"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/version"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetInfo retrieves and parses the system information
|
||||||
|
func GetInfo(ctx context.Context) *Info {
|
||||||
|
kernel := "android"
|
||||||
|
osInfo := uname()
|
||||||
|
if len(osInfo) == 2 {
|
||||||
|
kernel = osInfo[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
gio := &Info{Kernel: kernel, Core: osVersion(), Platform: "unknown", OS: "android", OSVersion: osVersion(), GoOS: runtime.GOOS, CPUs: runtime.NumCPU()}
|
||||||
|
gio.Hostname = extractDeviceName(ctx)
|
||||||
|
gio.WiretrusteeVersion = version.NetbirdVersion()
|
||||||
|
gio.UIVersion = extractUserAgent(ctx)
|
||||||
|
|
||||||
|
return gio
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractDeviceName(ctx context.Context) string {
|
||||||
|
v, ok := ctx.Value(DeviceNameCtxKey).(string)
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
func uname() []string {
|
||||||
|
res := run("/system/bin/uname", "-a")
|
||||||
|
return strings.Split(res, " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func osVersion() string {
|
||||||
|
return run("/system/bin/getprop", "ro.build.version.release")
|
||||||
|
}
|
||||||
|
|
||||||
|
func run(name string, arg ...string) string {
|
||||||
|
cmd := exec.Command(name, arg...)
|
||||||
|
cmd.Stdin = strings.NewReader("some")
|
||||||
|
var out bytes.Buffer
|
||||||
|
var stderr bytes.Buffer
|
||||||
|
cmd.Stdout = &out
|
||||||
|
cmd.Stderr = &stderr
|
||||||
|
err := cmd.Run()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("getInfo: %s", err)
|
||||||
|
}
|
||||||
|
return out.String()
|
||||||
|
}
|
@ -1,3 +1,6 @@
|
|||||||
|
//go:build !android
|
||||||
|
// +build !android
|
||||||
|
|
||||||
package system
|
package system
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
@ -10,15 +10,15 @@ import (
|
|||||||
|
|
||||||
// TextFormatter formats logs into text with included source code's path
|
// TextFormatter formats logs into text with included source code's path
|
||||||
type TextFormatter struct {
|
type TextFormatter struct {
|
||||||
TimestampFormat string
|
timestampFormat string
|
||||||
LevelDesc []string
|
levelDesc []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTextFormatter create new MyTextFormatter instance
|
// NewTextFormatter create new MyTextFormatter instance
|
||||||
func NewTextFormatter() *TextFormatter {
|
func NewTextFormatter() *TextFormatter {
|
||||||
return &TextFormatter{
|
return &TextFormatter{
|
||||||
LevelDesc: []string{"PANC", "FATL", "ERRO", "WARN", "INFO", "DEBG", "TRAC"},
|
levelDesc: []string{"PANC", "FATL", "ERRO", "WARN", "INFO", "DEBG", "TRAC"},
|
||||||
TimestampFormat: time.RFC3339, // or RFC3339
|
timestampFormat: time.RFC3339, // or RFC3339
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -39,13 +39,13 @@ func (f *TextFormatter) Format(entry *logrus.Entry) ([]byte, error) {
|
|||||||
|
|
||||||
level := f.parseLevel(entry.Level)
|
level := f.parseLevel(entry.Level)
|
||||||
|
|
||||||
return []byte(fmt.Sprintf("%s %s %s%s: %s\n", entry.Time.Format(f.TimestampFormat), level, fields, entry.Data["source"], entry.Message)), nil
|
return []byte(fmt.Sprintf("%s %s %s%s: %s\n", entry.Time.Format(f.timestampFormat), level, fields, entry.Data["source"], entry.Message)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *TextFormatter) parseLevel(level logrus.Level) string {
|
func (f *TextFormatter) parseLevel(level logrus.Level) string {
|
||||||
if len(f.LevelDesc) < int(level) {
|
if len(f.levelDesc) < int(level) {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
return f.LevelDesc[level]
|
return f.levelDesc[level]
|
||||||
}
|
}
|
||||||
|
48
formatter/logcat.go
Normal file
48
formatter/logcat.go
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
package formatter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LogcatFormatter formats logs into text what is fit for logcat
|
||||||
|
type LogcatFormatter struct {
|
||||||
|
levelDesc []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLogcatFormatter create new LogcatFormatter instance
|
||||||
|
func NewLogcatFormatter() *LogcatFormatter {
|
||||||
|
return &LogcatFormatter{
|
||||||
|
levelDesc: []string{"PANC", "FATL", "ERRO", "WARN", "INFO", "DEBG", "TRAC"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Format renders a single log entry
|
||||||
|
func (f *LogcatFormatter) Format(entry *logrus.Entry) ([]byte, error) {
|
||||||
|
var fields string
|
||||||
|
keys := make([]string, 0, len(entry.Data))
|
||||||
|
for k, v := range entry.Data {
|
||||||
|
if k == "source" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
keys = append(keys, fmt.Sprintf("%s: %v", k, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(keys) > 0 {
|
||||||
|
fields = fmt.Sprintf("[%s] ", strings.Join(keys, ", "))
|
||||||
|
}
|
||||||
|
|
||||||
|
level := f.parseLevel(entry.Level)
|
||||||
|
|
||||||
|
return []byte(fmt.Sprintf("[%s] %s%s %s\n", level, fields, entry.Data["source"], entry.Message)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *LogcatFormatter) parseLevel(level logrus.Level) string {
|
||||||
|
if len(f.levelDesc) < int(level) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return f.levelDesc[level]
|
||||||
|
}
|
28
formatter/logcat_test.go
Normal file
28
formatter/logcat_test.go
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
package formatter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLogcatMessageFormat(t *testing.T) {
|
||||||
|
|
||||||
|
someEntry := &logrus.Entry{
|
||||||
|
Data: logrus.Fields{"att1": 1, "att2": 2, "source": "some/fancy/path.go:46"},
|
||||||
|
Time: time.Date(2021, time.Month(2), 21, 1, 10, 30, 0, time.UTC),
|
||||||
|
Level: 3,
|
||||||
|
Message: "Some Message",
|
||||||
|
}
|
||||||
|
|
||||||
|
formatter := NewLogcatFormatter()
|
||||||
|
result, _ := formatter.Format(someEntry)
|
||||||
|
|
||||||
|
expectedString := "[WARN] [att1: 1, att2: 2] some/fancy/path.go:46 Some Message\n"
|
||||||
|
expectedStringVariant := "[WARN] [att2: 2, att1: 1] some/fancy/path.go:46 Some Message\n"
|
||||||
|
parsedString := string(result)
|
||||||
|
if parsedString != expectedString && parsedString != expectedStringVariant {
|
||||||
|
t.Errorf("The log messages don't match. Expected: '%s', got: '%s'", expectedString, parsedString)
|
||||||
|
}
|
||||||
|
}
|
@ -2,9 +2,16 @@ package formatter
|
|||||||
|
|
||||||
import "github.com/sirupsen/logrus"
|
import "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
// SetTextFormatter set the formatter for given logger.
|
// SetTextFormatter set the text formatter for given logger.
|
||||||
func SetTextFormatter(logger *logrus.Logger) {
|
func SetTextFormatter(logger *logrus.Logger) {
|
||||||
logger.Formatter = NewTextFormatter()
|
logger.Formatter = NewTextFormatter()
|
||||||
logger.ReportCaller = true
|
logger.ReportCaller = true
|
||||||
logger.AddHook(NewContextHook())
|
logger.AddHook(NewContextHook())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetLogcatFormatter set the logcat formatter for given logger.
|
||||||
|
func SetLogcatFormatter(logger *logrus.Logger) {
|
||||||
|
logger.Formatter = NewLogcatFormatter()
|
||||||
|
logger.ReportCaller = true
|
||||||
|
logger.AddHook(NewContextHook())
|
||||||
|
}
|
||||||
|
2
go.mod
2
go.mod
@ -47,6 +47,7 @@ require (
|
|||||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||||
github.com/open-policy-agent/opa v0.49.0
|
github.com/open-policy-agent/opa v0.49.0
|
||||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||||
|
github.com/pion/transport/v2 v2.0.2
|
||||||
github.com/prometheus/client_golang v1.14.0
|
github.com/prometheus/client_golang v1.14.0
|
||||||
github.com/rs/xid v1.3.0
|
github.com/rs/xid v1.3.0
|
||||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
||||||
@ -105,7 +106,6 @@ require (
|
|||||||
github.com/pion/mdns v0.0.7 // indirect
|
github.com/pion/mdns v0.0.7 // indirect
|
||||||
github.com/pion/randutil v0.1.0 // indirect
|
github.com/pion/randutil v0.1.0 // indirect
|
||||||
github.com/pion/stun v0.4.0 // indirect
|
github.com/pion/stun v0.4.0 // indirect
|
||||||
github.com/pion/transport/v2 v2.0.2 // indirect
|
|
||||||
github.com/pion/turn/v2 v2.1.0 // indirect
|
github.com/pion/turn/v2 v2.1.0 // indirect
|
||||||
github.com/pion/udp/v2 v2.0.1 // indirect
|
github.com/pion/udp/v2 v2.0.1 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
260
iface/iface.go
260
iface/iface.go
@ -1,13 +1,11 @@
|
|||||||
package iface
|
package iface
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -16,46 +14,30 @@ const (
|
|||||||
DefaultWgPort = 51820
|
DefaultWgPort = 51820
|
||||||
)
|
)
|
||||||
|
|
||||||
// NetInterface represents a generic network tunnel interface
|
|
||||||
type NetInterface interface {
|
|
||||||
Close() error
|
|
||||||
}
|
|
||||||
|
|
||||||
// WGIface represents a interface instance
|
// WGIface represents a interface instance
|
||||||
type WGIface struct {
|
type WGIface struct {
|
||||||
name string
|
tun *tunDevice
|
||||||
address WGAddress
|
configurer wGConfigurer
|
||||||
mtu int
|
mu sync.Mutex
|
||||||
netInterface NetInterface
|
|
||||||
mu sync.Mutex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewWGIFace Creates a new Wireguard interface instance
|
// Create creates a new Wireguard interface, sets a given IP and brings it up.
|
||||||
func NewWGIFace(iface string, address string, mtu int) (*WGIface, error) {
|
// Will reuse an existing one.
|
||||||
wgIface := &WGIface{
|
func (w *WGIface) Create() error {
|
||||||
name: iface,
|
w.mu.Lock()
|
||||||
mtu: mtu,
|
defer w.mu.Unlock()
|
||||||
mu: sync.Mutex{},
|
log.Debugf("create Wireguard interface %s", w.tun.DeviceName())
|
||||||
}
|
return w.tun.Create()
|
||||||
|
|
||||||
wgAddress, err := parseWGAddress(address)
|
|
||||||
if err != nil {
|
|
||||||
return wgIface, err
|
|
||||||
}
|
|
||||||
|
|
||||||
wgIface.address = wgAddress
|
|
||||||
|
|
||||||
return wgIface, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Name returns the interface name
|
// Name returns the interface name
|
||||||
func (w *WGIface) Name() string {
|
func (w *WGIface) Name() string {
|
||||||
return w.name
|
return w.tun.DeviceName()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Address returns the interface address
|
// Address returns the interface address
|
||||||
func (w *WGIface) Address() WGAddress {
|
func (w *WGIface) Address() WGAddress {
|
||||||
return w.address
|
return w.tun.WgAddress()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Configure configures a Wireguard interface
|
// Configure configures a Wireguard interface
|
||||||
@ -63,27 +45,8 @@ func (w *WGIface) Address() WGAddress {
|
|||||||
func (w *WGIface) Configure(privateKey string, port int) error {
|
func (w *WGIface) Configure(privateKey string, port int) error {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
|
log.Debugf("configuring Wireguard interface %s", w.tun.DeviceName())
|
||||||
log.Debugf("configuring Wireguard interface %s", w.name)
|
return w.configurer.configureInterface(privateKey, port)
|
||||||
|
|
||||||
log.Debugf("adding Wireguard private key")
|
|
||||||
key, err := wgtypes.ParseKey(privateKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
fwmark := 0
|
|
||||||
config := wgtypes.Config{
|
|
||||||
PrivateKey: &key,
|
|
||||||
ReplacePeers: true,
|
|
||||||
FirewallMark: &fwmark,
|
|
||||||
ListenPort: &port,
|
|
||||||
}
|
|
||||||
|
|
||||||
err = w.configureDevice(config)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf(`received error "%w" while configuring interface %s with port %d`, err, w.name, port)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateAddr updates address of the interface
|
// UpdateAddr updates address of the interface
|
||||||
@ -96,8 +59,7 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
w.address = addr
|
return w.tun.UpdateAddr(addr)
|
||||||
return w.assignAddr()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
|
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
|
||||||
@ -106,119 +68,8 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.D
|
|||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
log.Debugf("updating interface %s peer %s: endpoint %s ", w.name, peerKey, endpoint)
|
log.Debugf("updating interface %s peer %s: endpoint %s ", w.tun.DeviceName(), peerKey, endpoint)
|
||||||
|
return w.configurer.updatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
|
||||||
//parse allowed ips
|
|
||||||
_, ipNet, err := net.ParseCIDR(allowedIps)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
peer := wgtypes.PeerConfig{
|
|
||||||
PublicKey: peerKeyParsed,
|
|
||||||
ReplaceAllowedIPs: true,
|
|
||||||
AllowedIPs: []net.IPNet{*ipNet},
|
|
||||||
PersistentKeepaliveInterval: &keepAlive,
|
|
||||||
PresharedKey: preSharedKey,
|
|
||||||
Endpoint: endpoint,
|
|
||||||
}
|
|
||||||
|
|
||||||
config := wgtypes.Config{
|
|
||||||
Peers: []wgtypes.PeerConfig{peer},
|
|
||||||
}
|
|
||||||
err = w.configureDevice(config)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf(`received error "%w" while updating peer on interface %s with settings: allowed ips %s, endpoint %s`, err, w.name, allowedIps, endpoint.String())
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddAllowedIP adds a prefix to the allowed IPs list of peer
|
|
||||||
func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error {
|
|
||||||
w.mu.Lock()
|
|
||||||
defer w.mu.Unlock()
|
|
||||||
|
|
||||||
log.Debugf("adding allowed IP to interface %s and peer %s: allowed IP %s ", w.name, peerKey, allowedIP)
|
|
||||||
|
|
||||||
_, ipNet, err := net.ParseCIDR(allowedIP)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
peer := wgtypes.PeerConfig{
|
|
||||||
PublicKey: peerKeyParsed,
|
|
||||||
UpdateOnly: true,
|
|
||||||
ReplaceAllowedIPs: false,
|
|
||||||
AllowedIPs: []net.IPNet{*ipNet},
|
|
||||||
}
|
|
||||||
|
|
||||||
config := wgtypes.Config{
|
|
||||||
Peers: []wgtypes.PeerConfig{peer},
|
|
||||||
}
|
|
||||||
err = w.configureDevice(config)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf(`received error "%w" while adding allowed Ip to peer on interface %s with settings: allowed ips %s`, err, w.name, allowedIP)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveAllowedIP removes a prefix from the allowed IPs list of peer
|
|
||||||
func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
|
|
||||||
w.mu.Lock()
|
|
||||||
defer w.mu.Unlock()
|
|
||||||
|
|
||||||
log.Debugf("removing allowed IP from interface %s and peer %s: allowed IP %s ", w.name, peerKey, allowedIP)
|
|
||||||
|
|
||||||
_, ipNet, err := net.ParseCIDR(allowedIP)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
existingPeer, err := getPeer(w.name, peerKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
newAllowedIPs := existingPeer.AllowedIPs
|
|
||||||
|
|
||||||
for i, existingAllowedIP := range existingPeer.AllowedIPs {
|
|
||||||
if existingAllowedIP.String() == ipNet.String() {
|
|
||||||
newAllowedIPs = append(existingPeer.AllowedIPs[:i], existingPeer.AllowedIPs[i+1:]...)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
peer := wgtypes.PeerConfig{
|
|
||||||
PublicKey: peerKeyParsed,
|
|
||||||
UpdateOnly: true,
|
|
||||||
ReplaceAllowedIPs: true,
|
|
||||||
AllowedIPs: newAllowedIPs,
|
|
||||||
}
|
|
||||||
|
|
||||||
config := wgtypes.Config{
|
|
||||||
Peers: []wgtypes.PeerConfig{peer},
|
|
||||||
}
|
|
||||||
err = w.configureDevice(config)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf(`received error "%w" while removing allowed IP from peer on interface %s with settings: allowed ips %s`, err, w.name, allowedIP)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemovePeer removes a Wireguard Peer from the interface iface
|
// RemovePeer removes a Wireguard Peer from the interface iface
|
||||||
@ -226,66 +77,31 @@ func (w *WGIface) RemovePeer(peerKey string) error {
|
|||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
log.Debugf("Removing peer %s from interface %s ", peerKey, w.name)
|
log.Debugf("Removing peer %s from interface %s ", peerKey, w.tun.DeviceName())
|
||||||
|
return w.configurer.removePeer(peerKey)
|
||||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
peer := wgtypes.PeerConfig{
|
|
||||||
PublicKey: peerKeyParsed,
|
|
||||||
Remove: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
config := wgtypes.Config{
|
|
||||||
Peers: []wgtypes.PeerConfig{peer},
|
|
||||||
}
|
|
||||||
err = w.configureDevice(config)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf(`received error "%w" while removing peer %s from interface %s`, err, peerKey, w.name)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) {
|
// AddAllowedIP adds a prefix to the allowed IPs list of peer
|
||||||
wg, err := wgctrl.New()
|
func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error {
|
||||||
if err != nil {
|
w.mu.Lock()
|
||||||
return wgtypes.Peer{}, err
|
defer w.mu.Unlock()
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
err = wg.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("got error while closing wgctl: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
wgDevice, err := wg.Device(ifaceName)
|
log.Debugf("adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
|
||||||
if err != nil {
|
return w.configurer.addAllowedIP(peerKey, allowedIP)
|
||||||
return wgtypes.Peer{}, err
|
|
||||||
}
|
|
||||||
for _, peer := range wgDevice.Peers {
|
|
||||||
if peer.PublicKey.String() == peerPubKey {
|
|
||||||
return peer, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return wgtypes.Peer{}, fmt.Errorf("peer not found")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// configureDevice configures the wireguard device
|
// RemoveAllowedIP removes a prefix from the allowed IPs list of peer
|
||||||
func (w *WGIface) configureDevice(config wgtypes.Config) error {
|
func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
|
||||||
wg, err := wgctrl.New()
|
w.mu.Lock()
|
||||||
if err != nil {
|
defer w.mu.Unlock()
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer wg.Close()
|
|
||||||
|
|
||||||
// validate if device with name exists
|
log.Debugf("removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
|
||||||
_, err = wg.Device(w.name)
|
return w.configurer.removeAllowedIP(peerKey, allowedIP)
|
||||||
if err != nil {
|
}
|
||||||
return err
|
|
||||||
}
|
// Close closes the tunnel interface
|
||||||
log.Debugf("got Wireguard device %s", w.name)
|
func (w *WGIface) Close() error {
|
||||||
|
w.mu.Lock()
|
||||||
return wg.ConfigureDevice(w.name, config)
|
defer w.mu.Unlock()
|
||||||
|
return w.tun.Close()
|
||||||
}
|
}
|
||||||
|
22
iface/iface_android.go
Normal file
22
iface/iface_android.go
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
package iface
|
||||||
|
|
||||||
|
import "sync"
|
||||||
|
|
||||||
|
// NewWGIFace Creates a new Wireguard interface instance
|
||||||
|
func NewWGIFace(ifaceName string, address string, mtu int, tunAdapter TunAdapter) (*WGIface, error) {
|
||||||
|
wgIface := &WGIface{
|
||||||
|
mu: sync.Mutex{},
|
||||||
|
}
|
||||||
|
|
||||||
|
wgAddress, err := parseWGAddress(address)
|
||||||
|
if err != nil {
|
||||||
|
return wgIface, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tun := newTunDevice(wgAddress, mtu, tunAdapter)
|
||||||
|
wgIface.tun = tun
|
||||||
|
|
||||||
|
wgIface.configurer = newWGConfigurer(tun)
|
||||||
|
|
||||||
|
return wgIface, nil
|
||||||
|
}
|
22
iface/iface_nonandroid.go
Normal file
22
iface/iface_nonandroid.go
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package iface
|
||||||
|
|
||||||
|
import "sync"
|
||||||
|
|
||||||
|
// NewWGIFace Creates a new Wireguard interface instance
|
||||||
|
func NewWGIFace(ifaceName string, address string, mtu int, tunAdapter TunAdapter) (*WGIface, error) {
|
||||||
|
wgIface := &WGIface{
|
||||||
|
mu: sync.Mutex{},
|
||||||
|
}
|
||||||
|
|
||||||
|
wgAddress, err := parseWGAddress(address)
|
||||||
|
if err != nil {
|
||||||
|
return wgIface, err
|
||||||
|
}
|
||||||
|
|
||||||
|
wgIface.tun = newTunDevice(ifaceName, wgAddress, mtu)
|
||||||
|
|
||||||
|
wgIface.configurer = newWGConfigurer(ifaceName)
|
||||||
|
return wgIface, nil
|
||||||
|
}
|
@ -32,7 +32,7 @@ func init() {
|
|||||||
func TestWGIface_UpdateAddr(t *testing.T) {
|
func TestWGIface_UpdateAddr(t *testing.T) {
|
||||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
||||||
addr := "100.64.0.1/8"
|
addr := "100.64.0.1/8"
|
||||||
iface, err := NewWGIFace(ifaceName, addr, DefaultMTU)
|
iface, err := NewWGIFace(ifaceName, addr, DefaultMTU, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -92,7 +92,7 @@ func getIfaceAddrs(ifaceName string) ([]net.Addr, error) {
|
|||||||
func Test_CreateInterface(t *testing.T) {
|
func Test_CreateInterface(t *testing.T) {
|
||||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1)
|
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1)
|
||||||
wgIP := "10.99.99.1/32"
|
wgIP := "10.99.99.1/32"
|
||||||
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU)
|
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -121,7 +121,7 @@ func Test_CreateInterface(t *testing.T) {
|
|||||||
func Test_Close(t *testing.T) {
|
func Test_Close(t *testing.T) {
|
||||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
|
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
|
||||||
wgIP := "10.99.99.2/32"
|
wgIP := "10.99.99.2/32"
|
||||||
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU)
|
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -149,7 +149,7 @@ func Test_Close(t *testing.T) {
|
|||||||
func Test_ConfigureInterface(t *testing.T) {
|
func Test_ConfigureInterface(t *testing.T) {
|
||||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3)
|
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3)
|
||||||
wgIP := "10.99.99.5/30"
|
wgIP := "10.99.99.5/30"
|
||||||
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU)
|
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -196,7 +196,7 @@ func Test_ConfigureInterface(t *testing.T) {
|
|||||||
func Test_UpdatePeer(t *testing.T) {
|
func Test_UpdatePeer(t *testing.T) {
|
||||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
||||||
wgIP := "10.99.99.9/30"
|
wgIP := "10.99.99.9/30"
|
||||||
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU)
|
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -228,7 +228,7 @@ func Test_UpdatePeer(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
peer, err := getPeer(ifaceName, peerPubKey)
|
peer, err := iface.configurer.getPeer(ifaceName, peerPubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -255,7 +255,7 @@ func Test_UpdatePeer(t *testing.T) {
|
|||||||
func Test_RemovePeer(t *testing.T) {
|
func Test_RemovePeer(t *testing.T) {
|
||||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
||||||
wgIP := "10.99.99.13/30"
|
wgIP := "10.99.99.13/30"
|
||||||
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU)
|
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -288,7 +288,7 @@ func Test_RemovePeer(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
_, err = getPeer(ifaceName, peerPubKey)
|
_, err = iface.configurer.getPeer(ifaceName, peerPubKey)
|
||||||
if err.Error() != "peer not found" {
|
if err.Error() != "peer not found" {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -305,7 +305,7 @@ func Test_ConnectPeers(t *testing.T) {
|
|||||||
|
|
||||||
keepAlive := 1 * time.Second
|
keepAlive := 1 * time.Second
|
||||||
|
|
||||||
iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, DefaultMTU)
|
iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, DefaultMTU, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -322,7 +322,7 @@ func Test_ConnectPeers(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, DefaultMTU)
|
iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, DefaultMTU, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -375,7 +375,7 @@ func Test_ConnectPeers(t *testing.T) {
|
|||||||
t.Fatalf("waiting for peer handshake timeout after %s", timeout.String())
|
t.Fatalf("waiting for peer handshake timeout after %s", timeout.String())
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
peer, gpErr := getPeer(peer1ifaceName, peer2Key.PublicKey().String())
|
peer, gpErr := iface1.configurer.getPeer(peer1ifaceName, peer2Key.PublicKey().String())
|
||||||
if gpErr != nil {
|
if gpErr != nil {
|
||||||
t.Fatal(gpErr)
|
t.Fatal(gpErr)
|
||||||
}
|
}
|
||||||
|
@ -1,69 +1,6 @@
|
|||||||
package iface
|
package iface
|
||||||
|
|
||||||
import (
|
// GetInterfaceGUIDString returns an interface GUID. This is useful on Windows only
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
"golang.zx2c4.com/wireguard/windows/driver"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Create Creates a new Wireguard interface, sets a given IP and brings it up.
|
|
||||||
func (w *WGIface) Create() error {
|
|
||||||
w.mu.Lock()
|
|
||||||
defer w.mu.Unlock()
|
|
||||||
|
|
||||||
WintunStaticRequestedGUID, _ := windows.GenerateGUID()
|
|
||||||
adapter, err := driver.CreateAdapter(w.name, "WireGuard", &WintunStaticRequestedGUID)
|
|
||||||
if err != nil {
|
|
||||||
err = fmt.Errorf("error creating adapter: %w", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
w.netInterface = adapter
|
|
||||||
err = adapter.SetAdapterState(driver.AdapterStateUp)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
state, _ := adapter.LUID().GUID()
|
|
||||||
log.Debugln("device guid: ", state.String())
|
|
||||||
return w.assignAddr()
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetInterfaceGUIDString returns an interface GUID string
|
|
||||||
func (w *WGIface) GetInterfaceGUIDString() (string, error) {
|
func (w *WGIface) GetInterfaceGUIDString() (string, error) {
|
||||||
if w.netInterface == nil {
|
return w.tun.getInterfaceGUIDString()
|
||||||
return "", fmt.Errorf("interface has not been initialized yet")
|
|
||||||
}
|
|
||||||
windowsDevice := w.netInterface.(*driver.Adapter)
|
|
||||||
luid := windowsDevice.LUID()
|
|
||||||
guid, err := luid.GUID()
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return guid.String(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close closes the tunnel interface
|
|
||||||
func (w *WGIface) Close() error {
|
|
||||||
w.mu.Lock()
|
|
||||||
defer w.mu.Unlock()
|
|
||||||
if w.netInterface == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return w.netInterface.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
|
|
||||||
func (w *WGIface) assignAddr() error {
|
|
||||||
luid := w.netInterface.(*driver.Adapter).LUID()
|
|
||||||
|
|
||||||
log.Debugf("adding address %s to interface: %s", w.address.IP, w.name)
|
|
||||||
err := luid.SetIPAddresses([]net.IPNet{{w.address.IP, w.address.Network.Mask}})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
60
iface/ipc_parser_android.go
Normal file
60
iface/ipc_parser_android.go
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
package iface
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
||||||
|
var sb strings.Builder
|
||||||
|
if wgCfg.PrivateKey != nil {
|
||||||
|
hexKey := hex.EncodeToString(wgCfg.PrivateKey[:])
|
||||||
|
sb.WriteString(fmt.Sprintf("private_key=%s\n", hexKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
if wgCfg.ListenPort != nil {
|
||||||
|
sb.WriteString(fmt.Sprintf("listen_port=%d\n", *wgCfg.ListenPort))
|
||||||
|
}
|
||||||
|
|
||||||
|
if wgCfg.ReplacePeers {
|
||||||
|
sb.WriteString("replace_peers=true\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
if wgCfg.FirewallMark != nil {
|
||||||
|
sb.WriteString(fmt.Sprintf("fwmark=%d\n", *wgCfg.FirewallMark))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range wgCfg.Peers {
|
||||||
|
hexKey := hex.EncodeToString(p.PublicKey[:])
|
||||||
|
sb.WriteString(fmt.Sprintf("public_key=%s\n", hexKey))
|
||||||
|
|
||||||
|
if p.PresharedKey != nil {
|
||||||
|
preSharedHexKey := hex.EncodeToString(p.PresharedKey[:])
|
||||||
|
sb.WriteString(fmt.Sprintf("public_key=%s\n", preSharedHexKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.Remove {
|
||||||
|
sb.WriteString("remove=true")
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.ReplaceAllowedIPs {
|
||||||
|
sb.WriteString("replace_allowed_ips=true\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, aip := range p.AllowedIPs {
|
||||||
|
sb.WriteString(fmt.Sprintf("allowed_ip=%s\n", aip.String()))
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.Endpoint != nil {
|
||||||
|
sb.WriteString(fmt.Sprintf("endpoint=%s\n", p.Endpoint.String()))
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.PersistentKeepaliveInterval != nil {
|
||||||
|
sb.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", int(p.PersistentKeepaliveInterval.Seconds())))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sb.String()
|
||||||
|
}
|
@ -1,5 +1,5 @@
|
|||||||
//go:build !linux
|
//go:build !linux || android
|
||||||
// +build !linux
|
// +build !linux android
|
||||||
|
|
||||||
package iface
|
package iface
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
//go:build linux && !android
|
||||||
|
|
||||||
// Package iface provides wireguard network interface creation and management
|
// Package iface provides wireguard network interface creation and management
|
||||||
package iface
|
package iface
|
||||||
|
|
||||||
|
6
iface/tun.go
Normal file
6
iface/tun.go
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
package iface
|
||||||
|
|
||||||
|
// NetInterface represents a generic network tunnel interface
|
||||||
|
type NetInterface interface {
|
||||||
|
Close() error
|
||||||
|
}
|
7
iface/tun_adapter.go
Normal file
7
iface/tun_adapter.go
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
package iface
|
||||||
|
|
||||||
|
// TunAdapter is an interface for create tun device from externel service
|
||||||
|
type TunAdapter interface {
|
||||||
|
ConfigureInterface(address string, mtu int) (int, error)
|
||||||
|
UpdateAddr(address string) error
|
||||||
|
}
|
112
iface/tun_android.go
Normal file
112
iface/tun_android.go
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
package iface
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
|
"golang.zx2c4.com/wireguard/device"
|
||||||
|
"golang.zx2c4.com/wireguard/ipc"
|
||||||
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
|
)
|
||||||
|
|
||||||
|
type tunDevice struct {
|
||||||
|
address WGAddress
|
||||||
|
mtu int
|
||||||
|
tunAdapter TunAdapter
|
||||||
|
|
||||||
|
fd int
|
||||||
|
name string
|
||||||
|
device *device.Device
|
||||||
|
uapi net.Listener
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTunDevice(address WGAddress, mtu int, tunAdapter TunAdapter) *tunDevice {
|
||||||
|
return &tunDevice{
|
||||||
|
address: address,
|
||||||
|
mtu: mtu,
|
||||||
|
tunAdapter: tunAdapter,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tunDevice) Create() error {
|
||||||
|
var err error
|
||||||
|
t.fd, err = t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to create Android interface: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
tunDevice, name, err := tun.CreateUnmonitoredTUNFromFD(t.fd)
|
||||||
|
if err != nil {
|
||||||
|
unix.Close(t.fd)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
t.name = name
|
||||||
|
|
||||||
|
log.Debugf("attaching to interface %v", name)
|
||||||
|
t.device = device.NewDevice(tunDevice, conn.NewStdNetBind(), device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
|
||||||
|
t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
||||||
|
|
||||||
|
log.Debugf("create uapi")
|
||||||
|
tunSock, err := ipc.UAPIOpen(name)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
t.uapi, err = ipc.UAPIListen(name, tunSock)
|
||||||
|
if err != nil {
|
||||||
|
tunSock.Close()
|
||||||
|
unix.Close(t.fd)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
uapiConn, err := t.uapi.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go t.device.IpcHandle(uapiConn)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = t.device.Up()
|
||||||
|
if err != nil {
|
||||||
|
tunSock.Close()
|
||||||
|
t.device.Close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Debugf("device is ready to use: %s", name)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tunDevice) Device() *device.Device {
|
||||||
|
return t.device
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tunDevice) DeviceName() string {
|
||||||
|
return t.name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tunDevice) WgAddress() WGAddress {
|
||||||
|
return t.address
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tunDevice) UpdateAddr(addr WGAddress) error {
|
||||||
|
// todo implement
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tunDevice) Close() (err error) {
|
||||||
|
if t.uapi != nil {
|
||||||
|
err = t.uapi.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if t.device != nil {
|
||||||
|
t.device.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
@ -6,23 +6,25 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Create Creates a new Wireguard interface, sets a given IP and brings it up.
|
func (c *tunDevice) Create() error {
|
||||||
func (w *WGIface) Create() error {
|
var err error
|
||||||
w.mu.Lock()
|
c.netInterface, err = c.createWithUserspace()
|
||||||
defer w.mu.Unlock()
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return w.createWithUserspace()
|
return c.assignAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
|
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
|
||||||
func (w *WGIface) assignAddr() error {
|
func (c *tunDevice) assignAddr() error {
|
||||||
cmd := exec.Command("ifconfig", w.name, "inet", w.address.IP.String(), w.address.IP.String())
|
cmd := exec.Command("ifconfig", c.name, "inet", c.address.IP.String(), c.address.IP.String())
|
||||||
if out, err := cmd.CombinedOutput(); err != nil {
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
log.Infof(`adding addreess command "%v" failed with output %s and error: `, cmd.String(), out)
|
log.Infof(`adding addreess command "%v" failed with output %s and error: `, cmd.String(), out)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
routeCmd := exec.Command("route", "add", "-net", w.address.Network.String(), "-interface", w.name)
|
routeCmd := exec.Command("route", "add", "-net", c.address.Network.String(), "-interface", c.name)
|
||||||
if out, err := routeCmd.CombinedOutput(); err != nil {
|
if out, err := routeCmd.CombinedOutput(); err != nil {
|
||||||
log.Printf(`adding route command "%v" failed with output %s and error: `, routeCmd.String(), out)
|
log.Printf(`adding route command "%v" failed with output %s and error: `, routeCmd.String(), out)
|
||||||
return err
|
return err
|
@ -1,3 +1,5 @@
|
|||||||
|
//go:build linux && !android
|
||||||
|
|
||||||
package iface
|
package iface
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@ -8,32 +10,34 @@ import (
|
|||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Create creates a new Wireguard interface, sets a given IP and brings it up.
|
func (c *tunDevice) Create() error {
|
||||||
// Will reuse an existing one.
|
|
||||||
func (w *WGIface) Create() error {
|
|
||||||
w.mu.Lock()
|
|
||||||
defer w.mu.Unlock()
|
|
||||||
|
|
||||||
if WireguardModuleIsLoaded() {
|
if WireguardModuleIsLoaded() {
|
||||||
log.Info("using kernel WireGuard")
|
log.Info("using kernel WireGuard")
|
||||||
return w.createWithKernel()
|
return c.createWithKernel()
|
||||||
} else {
|
|
||||||
if !tunModuleIsLoaded() {
|
|
||||||
return fmt.Errorf("couldn't check or load tun module")
|
|
||||||
}
|
|
||||||
log.Info("using userspace WireGuard")
|
|
||||||
return w.createWithUserspace()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !tunModuleIsLoaded() {
|
||||||
|
return fmt.Errorf("couldn't check or load tun module")
|
||||||
|
}
|
||||||
|
log.Info("using userspace WireGuard")
|
||||||
|
var err error
|
||||||
|
c.netInterface, err = c.createWithUserspace()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.assignAddr()
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// createWithKernel Creates a new Wireguard interface using kernel Wireguard module.
|
// createWithKernel Creates a new Wireguard interface using kernel Wireguard module.
|
||||||
// Works for Linux and offers much better network performance
|
// Works for Linux and offers much better network performance
|
||||||
func (w *WGIface) createWithKernel() error {
|
func (c *tunDevice) createWithKernel() error {
|
||||||
|
|
||||||
link := newWGLink(w.name)
|
link := newWGLink(c.name)
|
||||||
|
|
||||||
// check if interface exists
|
// check if interface exists
|
||||||
l, err := netlink.LinkByName(w.name)
|
l, err := netlink.LinkByName(c.name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch err.(type) {
|
switch err.(type) {
|
||||||
case netlink.LinkNotFoundError:
|
case netlink.LinkNotFoundError:
|
||||||
@ -51,33 +55,33 @@ func (w *WGIface) createWithKernel() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("adding device: %s", w.name)
|
log.Debugf("adding device: %s", c.name)
|
||||||
err = netlink.LinkAdd(link)
|
err = netlink.LinkAdd(link)
|
||||||
if os.IsExist(err) {
|
if os.IsExist(err) {
|
||||||
log.Infof("interface %s already exists. Will reuse.", w.name)
|
log.Infof("interface %s already exists. Will reuse.", c.name)
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
w.netInterface = link
|
c.netInterface = link
|
||||||
|
|
||||||
err = w.assignAddr()
|
err = c.assignAddr()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo do a discovery
|
// todo do a discovery
|
||||||
log.Debugf("setting MTU: %d interface: %s", w.mtu, w.name)
|
log.Debugf("setting MTU: %d interface: %s", c.mtu, c.name)
|
||||||
err = netlink.LinkSetMTU(link, w.mtu)
|
err = netlink.LinkSetMTU(link, c.mtu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("error setting MTU on interface: %s", w.name)
|
log.Errorf("error setting MTU on interface: %s", c.name)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("bringing up interface: %s", w.name)
|
log.Debugf("bringing up interface: %s", c.name)
|
||||||
err = netlink.LinkSetUp(link)
|
err = netlink.LinkSetUp(link)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("error bringing up interface: %s", w.name)
|
log.Errorf("error bringing up interface: %s", c.name)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -85,8 +89,8 @@ func (w *WGIface) createWithKernel() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// assignAddr Adds IP address to the tunnel interface
|
// assignAddr Adds IP address to the tunnel interface
|
||||||
func (w *WGIface) assignAddr() error {
|
func (c *tunDevice) assignAddr() error {
|
||||||
link := newWGLink(w.name)
|
link := newWGLink(c.name)
|
||||||
|
|
||||||
//delete existing addresses
|
//delete existing addresses
|
||||||
list, err := netlink.AddrList(link, 0)
|
list, err := netlink.AddrList(link, 0)
|
||||||
@ -102,11 +106,11 @@ func (w *WGIface) assignAddr() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("adding address %s to interface: %s", w.address.String(), w.name)
|
log.Debugf("adding address %s to interface: %s", c.address.String(), c.name)
|
||||||
addr, _ := netlink.ParseAddr(w.address.String())
|
addr, _ := netlink.ParseAddr(c.address.String())
|
||||||
err = netlink.AddrAdd(link, addr)
|
err = netlink.AddrAdd(link, addr)
|
||||||
if os.IsExist(err) {
|
if os.IsExist(err) {
|
||||||
log.Infof("interface %s already has the address: %s", w.name, w.address.String())
|
log.Infof("interface %s already has the address: %s", c.name, c.address.String())
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
@ -1,5 +1,4 @@
|
|||||||
//go:build linux || darwin
|
//go:build (linux || darwin) && !android
|
||||||
// +build linux darwin
|
|
||||||
|
|
||||||
package iface
|
package iface
|
||||||
|
|
||||||
@ -14,24 +13,44 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetInterfaceGUIDString returns an interface GUID. This is useful on Windows only
|
type tunDevice struct {
|
||||||
func (w *WGIface) GetInterfaceGUIDString() (string, error) {
|
name string
|
||||||
return "", nil
|
address WGAddress
|
||||||
|
mtu int
|
||||||
|
netInterface NetInterface
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the tunnel interface
|
func newTunDevice(name string, address WGAddress, mtu int) *tunDevice {
|
||||||
func (w *WGIface) Close() error {
|
return &tunDevice{
|
||||||
w.mu.Lock()
|
name: name,
|
||||||
defer w.mu.Unlock()
|
address: address,
|
||||||
if w.netInterface == nil {
|
mtu: mtu,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tunDevice) UpdateAddr(address WGAddress) error {
|
||||||
|
c.address = address
|
||||||
|
return c.assignAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tunDevice) WgAddress() WGAddress {
|
||||||
|
return c.address
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tunDevice) DeviceName() string {
|
||||||
|
return c.name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tunDevice) Close() error {
|
||||||
|
if c.netInterface == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
err := w.netInterface.Close()
|
err := c.netInterface.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
sockPath := "/var/run/wireguard/" + w.name + ".sock"
|
sockPath := "/var/run/wireguard/" + c.name + ".sock"
|
||||||
if _, statErr := os.Stat(sockPath); statErr == nil {
|
if _, statErr := os.Stat(sockPath); statErr == nil {
|
||||||
statErr = os.Remove(sockPath)
|
statErr = os.Remove(sockPath)
|
||||||
if statErr != nil {
|
if statErr != nil {
|
||||||
@ -43,24 +62,23 @@ func (w *WGIface) Close() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// createWithUserspace Creates a new Wireguard interface, using wireguard-go userspace implementation
|
// createWithUserspace Creates a new Wireguard interface, using wireguard-go userspace implementation
|
||||||
func (w *WGIface) createWithUserspace() error {
|
func (c *tunDevice) createWithUserspace() (NetInterface, error) {
|
||||||
|
tunIface, err := tun.CreateTUN(c.name, c.mtu)
|
||||||
tunIface, err := tun.CreateTUN(w.name, w.mtu)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
w.netInterface = tunIface
|
|
||||||
|
|
||||||
// We need to create a wireguard-go device and listen to configuration requests
|
// We need to create a wireguard-go device and listen to configuration requests
|
||||||
tunDevice := device.NewDevice(tunIface, conn.NewDefaultBind(), device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
|
tunDevice := device.NewDevice(tunIface, conn.NewDefaultBind(), device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
|
||||||
err = tunDevice.Up()
|
err = tunDevice.Up()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return tunIface, err
|
||||||
}
|
}
|
||||||
uapi, err := getUAPI(w.name)
|
|
||||||
|
// todo: after this line in case of error close the tunSock
|
||||||
|
uapi, err := c.getUAPI(c.name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return tunIface, err
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
@ -75,16 +93,11 @@ func (w *WGIface) createWithUserspace() error {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
log.Debugln("UAPI listener started")
|
log.Debugln("UAPI listener started")
|
||||||
|
return tunIface, nil
|
||||||
err = w.assignAddr()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getUAPI returns a Listener
|
// getUAPI returns a Listener
|
||||||
func getUAPI(iface string) (net.Listener, error) {
|
func (c *tunDevice) getUAPI(iface string) (net.Listener, error) {
|
||||||
tunSock, err := ipc.UAPIOpen(iface)
|
tunSock, err := ipc.UAPIOpen(iface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
93
iface/tun_windows.go
Normal file
93
iface/tun_windows.go
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
package iface
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
"golang.zx2c4.com/wireguard/windows/driver"
|
||||||
|
)
|
||||||
|
|
||||||
|
type tunDevice struct {
|
||||||
|
name string
|
||||||
|
address WGAddress
|
||||||
|
netInterface NetInterface
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTunDevice(name string, address WGAddress, mtu int) *tunDevice {
|
||||||
|
return &tunDevice{name: name, address: address}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tunDevice) Create() error {
|
||||||
|
var err error
|
||||||
|
c.netInterface, err = c.createAdapter()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.assignAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tunDevice) UpdateAddr(address WGAddress) error {
|
||||||
|
c.address = address
|
||||||
|
return c.assignAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tunDevice) WgAddress() WGAddress {
|
||||||
|
return c.address
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tunDevice) DeviceName() string {
|
||||||
|
return c.name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tunDevice) Close() error {
|
||||||
|
if c.netInterface == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.netInterface.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tunDevice) getInterfaceGUIDString() (string, error) {
|
||||||
|
if c.netInterface == nil {
|
||||||
|
return "", fmt.Errorf("interface has not been initialized yet")
|
||||||
|
}
|
||||||
|
windowsDevice := c.netInterface.(*driver.Adapter)
|
||||||
|
luid := windowsDevice.LUID()
|
||||||
|
guid, err := luid.GUID()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return guid.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tunDevice) createAdapter() (NetInterface, error) {
|
||||||
|
WintunStaticRequestedGUID, _ := windows.GenerateGUID()
|
||||||
|
adapter, err := driver.CreateAdapter(c.name, "WireGuard", &WintunStaticRequestedGUID)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("error creating adapter: %w", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
err = adapter.SetAdapterState(driver.AdapterStateUp)
|
||||||
|
if err != nil {
|
||||||
|
return adapter, err
|
||||||
|
}
|
||||||
|
state, _ := adapter.LUID().GUID()
|
||||||
|
log.Debugln("device guid: ", state.String())
|
||||||
|
return adapter, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
|
||||||
|
func (c *tunDevice) assignAddr() error {
|
||||||
|
luid := c.netInterface.(*driver.Adapter).LUID()
|
||||||
|
|
||||||
|
log.Debugf("adding address %s to interface: %s", c.address.IP, c.name)
|
||||||
|
err := luid.SetIPAddresses([]net.IPNet{{c.address.IP, c.address.Network.Mask}})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
114
iface/wg_configurer_android.go
Normal file
114
iface/wg_configurer_android.go
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
package iface
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errFuncNotImplemented = errors.New("function not implemented")
|
||||||
|
)
|
||||||
|
|
||||||
|
type wGConfigurer struct {
|
||||||
|
tunDevice *tunDevice
|
||||||
|
}
|
||||||
|
|
||||||
|
func newWGConfigurer(tunDevice *tunDevice) wGConfigurer {
|
||||||
|
return wGConfigurer{
|
||||||
|
tunDevice: tunDevice,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wGConfigurer) configureInterface(privateKey string, port int) error {
|
||||||
|
log.Debugf("adding Wireguard private key")
|
||||||
|
key, err := wgtypes.ParseKey(privateKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fwmark := 0
|
||||||
|
config := wgtypes.Config{
|
||||||
|
PrivateKey: &key,
|
||||||
|
ReplacePeers: true,
|
||||||
|
FirewallMark: &fwmark,
|
||||||
|
ListenPort: &port,
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.tunDevice.Device().IpcSet(toWgUserspaceString(config))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wGConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||||
|
//parse allowed ips
|
||||||
|
_, ipNet, err := net.ParseCIDR(allowedIps)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
peer := wgtypes.PeerConfig{
|
||||||
|
PublicKey: peerKeyParsed,
|
||||||
|
ReplaceAllowedIPs: true,
|
||||||
|
AllowedIPs: []net.IPNet{*ipNet},
|
||||||
|
PersistentKeepaliveInterval: &keepAlive,
|
||||||
|
PresharedKey: preSharedKey,
|
||||||
|
Endpoint: endpoint,
|
||||||
|
}
|
||||||
|
|
||||||
|
config := wgtypes.Config{
|
||||||
|
Peers: []wgtypes.PeerConfig{peer},
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.tunDevice.Device().IpcSet(toWgUserspaceString(config))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wGConfigurer) removePeer(peerKey string) error {
|
||||||
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
peer := wgtypes.PeerConfig{
|
||||||
|
PublicKey: peerKeyParsed,
|
||||||
|
Remove: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
config := wgtypes.Config{
|
||||||
|
Peers: []wgtypes.PeerConfig{peer},
|
||||||
|
}
|
||||||
|
return c.tunDevice.Device().IpcSet(toWgUserspaceString(config))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wGConfigurer) addAllowedIP(peerKey string, allowedIP string) error {
|
||||||
|
_, ipNet, err := net.ParseCIDR(allowedIP)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
peer := wgtypes.PeerConfig{
|
||||||
|
PublicKey: peerKeyParsed,
|
||||||
|
UpdateOnly: true,
|
||||||
|
ReplaceAllowedIPs: false,
|
||||||
|
AllowedIPs: []net.IPNet{*ipNet},
|
||||||
|
}
|
||||||
|
|
||||||
|
config := wgtypes.Config{
|
||||||
|
Peers: []wgtypes.PeerConfig{peer},
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.tunDevice.Device().IpcSet(toWgUserspaceString(config))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wGConfigurer) removeAllowedIP(peerKey string, allowedIP string) error {
|
||||||
|
return errFuncNotImplemented
|
||||||
|
}
|
208
iface/wg_configurer_nonandroid.go
Normal file
208
iface/wg_configurer_nonandroid.go
Normal file
@ -0,0 +1,208 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package iface
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
type wGConfigurer struct {
|
||||||
|
deviceName string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newWGConfigurer(deviceName string) wGConfigurer {
|
||||||
|
return wGConfigurer{
|
||||||
|
deviceName: deviceName,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wGConfigurer) configureInterface(privateKey string, port int) error {
|
||||||
|
log.Debugf("adding Wireguard private key")
|
||||||
|
key, err := wgtypes.ParseKey(privateKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fwmark := 0
|
||||||
|
config := wgtypes.Config{
|
||||||
|
PrivateKey: &key,
|
||||||
|
ReplacePeers: true,
|
||||||
|
FirewallMark: &fwmark,
|
||||||
|
ListenPort: &port,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.configure(config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf(`received error "%w" while configuring interface %s with port %d`, err, c.deviceName, port)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wGConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||||
|
//parse allowed ips
|
||||||
|
_, ipNet, err := net.ParseCIDR(allowedIps)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
peer := wgtypes.PeerConfig{
|
||||||
|
PublicKey: peerKeyParsed,
|
||||||
|
ReplaceAllowedIPs: true,
|
||||||
|
AllowedIPs: []net.IPNet{*ipNet},
|
||||||
|
PersistentKeepaliveInterval: &keepAlive,
|
||||||
|
PresharedKey: preSharedKey,
|
||||||
|
Endpoint: endpoint,
|
||||||
|
}
|
||||||
|
|
||||||
|
config := wgtypes.Config{
|
||||||
|
Peers: []wgtypes.PeerConfig{peer},
|
||||||
|
}
|
||||||
|
err = c.configure(config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf(`received error "%w" while updating peer on interface %s with settings: allowed ips %s, endpoint %s`, err, c.deviceName, allowedIps, endpoint.String())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wGConfigurer) removePeer(peerKey string) error {
|
||||||
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
peer := wgtypes.PeerConfig{
|
||||||
|
PublicKey: peerKeyParsed,
|
||||||
|
Remove: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
config := wgtypes.Config{
|
||||||
|
Peers: []wgtypes.PeerConfig{peer},
|
||||||
|
}
|
||||||
|
err = c.configure(config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf(`received error "%w" while removing peer %s from interface %s`, err, peerKey, c.deviceName)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wGConfigurer) addAllowedIP(peerKey string, allowedIP string) error {
|
||||||
|
_, ipNet, err := net.ParseCIDR(allowedIP)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
peer := wgtypes.PeerConfig{
|
||||||
|
PublicKey: peerKeyParsed,
|
||||||
|
UpdateOnly: true,
|
||||||
|
ReplaceAllowedIPs: false,
|
||||||
|
AllowedIPs: []net.IPNet{*ipNet},
|
||||||
|
}
|
||||||
|
|
||||||
|
config := wgtypes.Config{
|
||||||
|
Peers: []wgtypes.PeerConfig{peer},
|
||||||
|
}
|
||||||
|
err = c.configure(config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf(`received error "%w" while adding allowed Ip to peer on interface %s with settings: allowed ips %s`, err, c.deviceName, allowedIP)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wGConfigurer) removeAllowedIP(peerKey string, allowedIP string) error {
|
||||||
|
_, ipNet, err := net.ParseCIDR(allowedIP)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
existingPeer, err := c.getPeer(c.deviceName, peerKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
newAllowedIPs := existingPeer.AllowedIPs
|
||||||
|
|
||||||
|
for i, existingAllowedIP := range existingPeer.AllowedIPs {
|
||||||
|
if existingAllowedIP.String() == ipNet.String() {
|
||||||
|
newAllowedIPs = append(existingPeer.AllowedIPs[:i], existingPeer.AllowedIPs[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
peer := wgtypes.PeerConfig{
|
||||||
|
PublicKey: peerKeyParsed,
|
||||||
|
UpdateOnly: true,
|
||||||
|
ReplaceAllowedIPs: true,
|
||||||
|
AllowedIPs: newAllowedIPs,
|
||||||
|
}
|
||||||
|
|
||||||
|
config := wgtypes.Config{
|
||||||
|
Peers: []wgtypes.PeerConfig{peer},
|
||||||
|
}
|
||||||
|
err = c.configure(config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf(`received error "%w" while removing allowed IP from peer on interface %s with settings: allowed ips %s`, err, c.deviceName, allowedIP)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wGConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) {
|
||||||
|
wg, err := wgctrl.New()
|
||||||
|
if err != nil {
|
||||||
|
return wgtypes.Peer{}, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err = wg.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("got error while closing wgctl: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
wgDevice, err := wg.Device(ifaceName)
|
||||||
|
if err != nil {
|
||||||
|
return wgtypes.Peer{}, err
|
||||||
|
}
|
||||||
|
for _, peer := range wgDevice.Peers {
|
||||||
|
if peer.PublicKey.String() == peerPubKey {
|
||||||
|
return peer, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return wgtypes.Peer{}, fmt.Errorf("peer not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wGConfigurer) configure(config wgtypes.Config) error {
|
||||||
|
wg, err := wgctrl.New()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer wg.Close()
|
||||||
|
|
||||||
|
// validate if device with name exists
|
||||||
|
_, err = wg.Device(c.deviceName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Debugf("got Wireguard device %s", c.deviceName)
|
||||||
|
|
||||||
|
return wg.ConfigureDevice(c.deviceName, config)
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user