mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-12 09:50:47 +01:00
Merge branch 'main' into debug-bundle-with-networkmap
This commit is contained in:
commit
d1e655bd82
@ -164,9 +164,9 @@ func (a *Anonymizer) AnonymizeString(str string) string {
|
||||
return str
|
||||
}
|
||||
|
||||
// AnonymizeSchemeURI finds and anonymizes URIs with stun, stuns, turn, and turns schemes.
|
||||
// AnonymizeSchemeURI finds and anonymizes URIs with ws, wss, rel, rels, stun, stuns, turn, and turns schemes.
|
||||
func (a *Anonymizer) AnonymizeSchemeURI(text string) string {
|
||||
re := regexp.MustCompile(`(?i)\b(stuns?:|turns?:|https?://)\S+\b`)
|
||||
re := regexp.MustCompile(`(?i)\b(wss?://|rels?://|stuns?:|turns?:|https?://)\S+\b`)
|
||||
|
||||
return re.ReplaceAllStringFunc(text, a.AnonymizeURI)
|
||||
}
|
||||
|
@ -158,8 +158,16 @@ func TestAnonymizeSchemeURI(t *testing.T) {
|
||||
expect string
|
||||
}{
|
||||
{"STUN URI in text", "Connection made via stun:example.com", `Connection made via stun:anon-[a-zA-Z0-9]+\.domain`},
|
||||
{"STUNS URI in message", "Secure connection to stuns:example.com:443", `Secure connection to stuns:anon-[a-zA-Z0-9]+\.domain:443`},
|
||||
{"TURN URI in log", "Failed attempt turn:some.example.com:3478?transport=tcp: retrying", `Failed attempt turn:some.anon-[a-zA-Z0-9]+\.domain:3478\?transport=tcp: retrying`},
|
||||
{"TURNS URI in message", "Secure connection to turns:example.com:5349", `Secure connection to turns:anon-[a-zA-Z0-9]+\.domain:5349`},
|
||||
{"HTTP URI in text", "Visit http://example.com for more", `Visit http://anon-[a-zA-Z0-9]+\.domain for more`},
|
||||
{"HTTPS URI in CAPS", "Visit HTTPS://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`},
|
||||
{"HTTPS URI in message", "Visit https://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`},
|
||||
{"WS URI in log", "Connection established to ws://example.com:8080", `Connection established to ws://anon-[a-zA-Z0-9]+\.domain:8080`},
|
||||
{"WSS URI in message", "Secure connection to wss://example.com", `Secure connection to wss://anon-[a-zA-Z0-9]+\.domain`},
|
||||
{"Rel URI in text", "Relaying to rel://example.com", `Relaying to rel://anon-[a-zA-Z0-9]+\.domain`},
|
||||
{"Rels URI in message", "Relaying to rels://example.com", `Relaying to rels://anon-[a-zA-Z0-9]+\.domain`},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
|
33
client/cmd/pprof.go
Normal file
33
client/cmd/pprof.go
Normal file
@ -0,0 +1,33 @@
|
||||
//go:build pprof
|
||||
// +build pprof
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func init() {
|
||||
addr := pprofAddr()
|
||||
go pprof(addr)
|
||||
}
|
||||
|
||||
func pprofAddr() string {
|
||||
listenAddr := os.Getenv("NB_PPROF_ADDR")
|
||||
if listenAddr == "" {
|
||||
return "localhost:6969"
|
||||
}
|
||||
|
||||
return listenAddr
|
||||
}
|
||||
|
||||
func pprof(listenAddr string) {
|
||||
log.Infof("listening pprof on: %s\n", listenAddr)
|
||||
if err := http.ListenAndServe(listenAddr, nil); err != nil {
|
||||
log.Fatalf("Failed to start pprof: %v", err)
|
||||
}
|
||||
}
|
@ -37,6 +37,11 @@ func (s *ipList) UnmarshalJSON(data []byte) error {
|
||||
return err
|
||||
}
|
||||
s.ips = temp.IPs
|
||||
|
||||
if temp.IPs == nil {
|
||||
temp.IPs = make(map[string]struct{})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -89,5 +94,10 @@ func (s *ipsetStore) UnmarshalJSON(data []byte) error {
|
||||
return err
|
||||
}
|
||||
s.ipsets = temp.IPSets
|
||||
|
||||
if temp.IPSets == nil {
|
||||
temp.IPSets = make(map[string]*ipList)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -162,12 +162,13 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
||||
params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
|
||||
var networks []ice.NetworkType
|
||||
switch {
|
||||
case addr.IP.To4() != nil:
|
||||
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
|
||||
|
||||
case addr.IP.To16() != nil:
|
||||
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
|
||||
|
||||
case addr.IP.To4() != nil:
|
||||
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
|
||||
|
||||
default:
|
||||
params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr())
|
||||
}
|
||||
|
@ -354,8 +354,17 @@ func (e *Engine) Start() error {
|
||||
}
|
||||
e.dnsServer = dnsServer
|
||||
|
||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes)
|
||||
beforePeerHook, afterPeerHook, err := e.routeManager.Init(e.stateManager)
|
||||
e.routeManager = routemanager.NewManager(
|
||||
e.ctx,
|
||||
e.config.WgPrivateKey.PublicKey().String(),
|
||||
e.config.DNSRouteInterval,
|
||||
e.wgInterface,
|
||||
e.statusRecorder,
|
||||
e.relayManager,
|
||||
initialRoutes,
|
||||
e.stateManager,
|
||||
)
|
||||
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to initialize route manager: %s", err)
|
||||
} else {
|
||||
|
@ -245,12 +245,15 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
nil)
|
||||
|
||||
wgIface := &iface.MockWGIface{
|
||||
NameFunc: func() string { return "utun102" },
|
||||
RemovePeerFunc: func(peerKey string) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
engine.wgInterface = wgIface
|
||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil)
|
||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil)
|
||||
_, _, err = engine.routeManager.Init()
|
||||
require.NoError(t, err)
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
}
|
||||
|
@ -83,7 +83,6 @@ type Conn struct {
|
||||
signaler *Signaler
|
||||
relayManager *relayClient.Manager
|
||||
allowedIP net.IP
|
||||
allowedNet string
|
||||
handshaker *Handshaker
|
||||
|
||||
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
|
||||
@ -111,7 +110,7 @@ type Conn struct {
|
||||
// NewConn creates a new not opened Conn to the remote peer.
|
||||
// To establish a connection run Conn.Open
|
||||
func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher) (*Conn, error) {
|
||||
allowedIP, allowedNet, err := net.ParseCIDR(config.WgConfig.AllowedIps)
|
||||
allowedIP, _, err := net.ParseCIDR(config.WgConfig.AllowedIps)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse allowedIPS: %v", err)
|
||||
return nil, err
|
||||
@ -129,7 +128,6 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
|
||||
signaler: signaler,
|
||||
relayManager: relayManager,
|
||||
allowedIP: allowedIP,
|
||||
allowedNet: allowedNet.String(),
|
||||
statusRelay: NewAtomicConnStatus(),
|
||||
statusICE: NewAtomicConnStatus(),
|
||||
}
|
||||
@ -594,7 +592,7 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd
|
||||
}
|
||||
|
||||
if conn.onConnected != nil {
|
||||
conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedNet, remoteRosenpassAddr)
|
||||
conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedIP.String(), remoteRosenpassAddr)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -32,7 +32,7 @@ import (
|
||||
|
||||
// Manager is a route manager interface
|
||||
type Manager interface {
|
||||
Init(*statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
|
||||
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
|
||||
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
|
||||
TriggerSelection(route.HAMap)
|
||||
GetRouteSelector() *routeselector.RouteSelector
|
||||
@ -59,6 +59,7 @@ type DefaultManager struct {
|
||||
routeRefCounter *refcounter.RouteRefCounter
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
|
||||
dnsRouteInterval time.Duration
|
||||
stateManager *statemanager.Manager
|
||||
}
|
||||
|
||||
func NewManager(
|
||||
@ -69,6 +70,7 @@ func NewManager(
|
||||
statusRecorder *peer.Status,
|
||||
relayMgr *relayClient.Manager,
|
||||
initialRoutes []*route.Route,
|
||||
stateManager *statemanager.Manager,
|
||||
) *DefaultManager {
|
||||
mCTX, cancel := context.WithCancel(ctx)
|
||||
notifier := notifier.NewNotifier()
|
||||
@ -80,12 +82,12 @@ func NewManager(
|
||||
dnsRouteInterval: dnsRouteInterval,
|
||||
clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
|
||||
relayMgr: relayMgr,
|
||||
routeSelector: routeselector.NewRouteSelector(),
|
||||
sysOps: sysOps,
|
||||
statusRecorder: statusRecorder,
|
||||
wgInterface: wgInterface,
|
||||
pubKey: pubKey,
|
||||
notifier: notifier,
|
||||
stateManager: stateManager,
|
||||
}
|
||||
|
||||
dm.routeRefCounter = refcounter.New(
|
||||
@ -121,7 +123,7 @@ func NewManager(
|
||||
}
|
||||
|
||||
// Init sets up the routing
|
||||
func (m *DefaultManager) Init(stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
if nbnet.CustomRoutingDisabled() {
|
||||
return nil, nil, nil
|
||||
}
|
||||
@ -137,14 +139,38 @@ func (m *DefaultManager) Init(stateManager *statemanager.Manager) (nbnet.AddHook
|
||||
|
||||
ips := resolveURLsToIPs(initialAddresses)
|
||||
|
||||
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, stateManager)
|
||||
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, m.stateManager)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("setup routing: %w", err)
|
||||
}
|
||||
|
||||
m.routeSelector = m.initSelector()
|
||||
|
||||
log.Info("Routing setup complete")
|
||||
return beforePeerHook, afterPeerHook, nil
|
||||
}
|
||||
|
||||
func (m *DefaultManager) initSelector() *routeselector.RouteSelector {
|
||||
var state *SelectorState
|
||||
m.stateManager.RegisterState(state)
|
||||
|
||||
// restore selector state if it exists
|
||||
if err := m.stateManager.LoadState(state); err != nil {
|
||||
log.Warnf("failed to load state: %v", err)
|
||||
return routeselector.NewRouteSelector()
|
||||
}
|
||||
|
||||
if state := m.stateManager.GetState(state); state != nil {
|
||||
if selector, ok := state.(*SelectorState); ok {
|
||||
return (*routeselector.RouteSelector)(selector)
|
||||
}
|
||||
|
||||
log.Warnf("failed to convert state with type %T to SelectorState", state)
|
||||
}
|
||||
|
||||
return routeselector.NewRouteSelector()
|
||||
}
|
||||
|
||||
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
|
||||
var err error
|
||||
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
|
||||
@ -252,6 +278,10 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
|
||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
|
||||
}
|
||||
|
||||
if err := m.stateManager.UpdateState((*SelectorState)(m.routeSelector)); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// stopObsoleteClients stops the client network watcher for the networks that are not in the new list
|
||||
|
@ -424,9 +424,9 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
|
||||
statusRecorder := peer.NewRecorder("https://mgm")
|
||||
ctx := context.TODO()
|
||||
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil)
|
||||
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil, nil)
|
||||
|
||||
_, _, err = routeManager.Init(nil)
|
||||
_, _, err = routeManager.Init()
|
||||
|
||||
require.NoError(t, err, "should init route manager")
|
||||
defer routeManager.Stop(nil)
|
||||
|
@ -21,7 +21,7 @@ type MockManager struct {
|
||||
StopFunc func(manager *statemanager.Manager)
|
||||
}
|
||||
|
||||
func (m *MockManager) Init(*statemanager.Manager) (net.AddHookFunc, net.RemoveHookFunc, error) {
|
||||
func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
|
@ -71,11 +71,14 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key
|
||||
}
|
||||
|
||||
// LoadData loads the data from the existing counter
|
||||
// The passed counter should not be used any longer after calling this function.
|
||||
func (rm *Counter[Key, I, O]) LoadData(
|
||||
existingCounter *Counter[Key, I, O],
|
||||
) {
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
existingCounter.mu.Lock()
|
||||
defer existingCounter.mu.Unlock()
|
||||
|
||||
rm.refCountMap = existingCounter.refCountMap
|
||||
rm.idMap = existingCounter.idMap
|
||||
@ -231,6 +234,9 @@ func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
|
||||
|
||||
// UnmarshalJSON implements the json.Unmarshaler interface for Counter.
|
||||
func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error {
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
|
||||
var temp struct {
|
||||
RefCountMap map[Key]Ref[O] `json:"refCountMap"`
|
||||
IDMap map[string][]Key `json:"idMap"`
|
||||
@ -241,6 +247,13 @@ func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error {
|
||||
rm.refCountMap = temp.RefCountMap
|
||||
rm.idMap = temp.IDMap
|
||||
|
||||
if temp.RefCountMap == nil {
|
||||
temp.RefCountMap = map[Key]Ref[O]{}
|
||||
}
|
||||
if temp.IDMap == nil {
|
||||
temp.IDMap = map[string][]Key{}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
19
client/internal/routemanager/state.go
Normal file
19
client/internal/routemanager/state.go
Normal file
@ -0,0 +1,19 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
)
|
||||
|
||||
type SelectorState routeselector.RouteSelector
|
||||
|
||||
func (s *SelectorState) Name() string {
|
||||
return "routeselector_state"
|
||||
}
|
||||
|
||||
func (s *SelectorState) MarshalJSON() ([]byte, error) {
|
||||
return (*routeselector.RouteSelector)(s).MarshalJSON()
|
||||
}
|
||||
|
||||
func (s *SelectorState) UnmarshalJSON(data []byte) error {
|
||||
return (*routeselector.RouteSelector)(s).UnmarshalJSON(data)
|
||||
}
|
@ -1,8 +1,10 @@
|
||||
package routeselector
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"golang.org/x/exp/maps"
|
||||
@ -12,6 +14,7 @@ import (
|
||||
)
|
||||
|
||||
type RouteSelector struct {
|
||||
mu sync.RWMutex
|
||||
selectedRoutes map[route.NetID]struct{}
|
||||
selectAll bool
|
||||
}
|
||||
@ -26,6 +29,9 @@ func NewRouteSelector() *RouteSelector {
|
||||
|
||||
// SelectRoutes updates the selected routes based on the provided route IDs.
|
||||
func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, allRoutes []route.NetID) error {
|
||||
rs.mu.Lock()
|
||||
defer rs.mu.Unlock()
|
||||
|
||||
if !appendRoute {
|
||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||
}
|
||||
@ -46,6 +52,9 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
|
||||
|
||||
// SelectAllRoutes sets the selector to select all routes.
|
||||
func (rs *RouteSelector) SelectAllRoutes() {
|
||||
rs.mu.Lock()
|
||||
defer rs.mu.Unlock()
|
||||
|
||||
rs.selectAll = true
|
||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||
}
|
||||
@ -53,6 +62,9 @@ func (rs *RouteSelector) SelectAllRoutes() {
|
||||
// DeselectRoutes removes specific routes from the selection.
|
||||
// If the selector is in "select all" mode, it will transition to "select specific" mode.
|
||||
func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error {
|
||||
rs.mu.Lock()
|
||||
defer rs.mu.Unlock()
|
||||
|
||||
if rs.selectAll {
|
||||
rs.selectAll = false
|
||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||
@ -76,12 +88,18 @@ func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.
|
||||
|
||||
// DeselectAllRoutes deselects all routes, effectively disabling route selection.
|
||||
func (rs *RouteSelector) DeselectAllRoutes() {
|
||||
rs.mu.Lock()
|
||||
defer rs.mu.Unlock()
|
||||
|
||||
rs.selectAll = false
|
||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||
}
|
||||
|
||||
// IsSelected checks if a specific route is selected.
|
||||
func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
|
||||
rs.mu.RLock()
|
||||
defer rs.mu.RUnlock()
|
||||
|
||||
if rs.selectAll {
|
||||
return true
|
||||
}
|
||||
@ -91,6 +109,9 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
|
||||
|
||||
// FilterSelected removes unselected routes from the provided map.
|
||||
func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
|
||||
rs.mu.RLock()
|
||||
defer rs.mu.RUnlock()
|
||||
|
||||
if rs.selectAll {
|
||||
return maps.Clone(routes)
|
||||
}
|
||||
@ -103,3 +124,49 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface
|
||||
func (rs *RouteSelector) MarshalJSON() ([]byte, error) {
|
||||
rs.mu.RLock()
|
||||
defer rs.mu.RUnlock()
|
||||
|
||||
return json.Marshal(struct {
|
||||
SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
|
||||
SelectAll bool `json:"select_all"`
|
||||
}{
|
||||
SelectAll: rs.selectAll,
|
||||
SelectedRoutes: rs.selectedRoutes,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements the json.Unmarshaler interface
|
||||
// If the JSON is empty or null, it will initialize like a NewRouteSelector.
|
||||
func (rs *RouteSelector) UnmarshalJSON(data []byte) error {
|
||||
rs.mu.Lock()
|
||||
defer rs.mu.Unlock()
|
||||
|
||||
// Check for null or empty JSON
|
||||
if len(data) == 0 || string(data) == "null" {
|
||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||
rs.selectAll = true
|
||||
return nil
|
||||
}
|
||||
|
||||
var temp struct {
|
||||
SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
|
||||
SelectAll bool `json:"select_all"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, &temp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rs.selectedRoutes = temp.SelectedRoutes
|
||||
rs.selectAll = temp.SelectAll
|
||||
|
||||
if rs.selectedRoutes == nil {
|
||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -22,9 +22,28 @@ import (
|
||||
// State interface defines the methods that all state types must implement
|
||||
type State interface {
|
||||
Name() string
|
||||
}
|
||||
|
||||
// CleanableState interface extends State with cleanup capability
|
||||
type CleanableState interface {
|
||||
State
|
||||
Cleanup() error
|
||||
}
|
||||
|
||||
// RawState wraps raw JSON data for unregistered states
|
||||
type RawState struct {
|
||||
data json.RawMessage
|
||||
}
|
||||
|
||||
func (r *RawState) Name() string {
|
||||
return "" // This is a placeholder implementation
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler to preserve the original JSON
|
||||
func (r *RawState) MarshalJSON() ([]byte, error) {
|
||||
return r.data, nil
|
||||
}
|
||||
|
||||
// Manager handles the persistence and management of various states
|
||||
type Manager struct {
|
||||
mu sync.Mutex
|
||||
@ -209,15 +228,15 @@ func (m *Manager) PersistState(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadState loads the existing state from the state file
|
||||
func (m *Manager) loadState() error {
|
||||
// loadStateFile reads and unmarshals the state file into a map of raw JSON messages
|
||||
func (m *Manager) loadStateFile() (map[string]json.RawMessage, error) {
|
||||
data, err := os.ReadFile(m.filePath)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
log.Debug("state file does not exist")
|
||||
return nil
|
||||
return nil, nil // nolint:nilnil
|
||||
}
|
||||
return fmt.Errorf("read state file: %w", err)
|
||||
return nil, fmt.Errorf("read state file: %w", err)
|
||||
}
|
||||
|
||||
var rawStates map[string]json.RawMessage
|
||||
@ -228,37 +247,69 @@ func (m *Manager) loadState() error {
|
||||
} else {
|
||||
log.Info("State file deleted")
|
||||
}
|
||||
return fmt.Errorf("unmarshal states: %w", err)
|
||||
return nil, fmt.Errorf("unmarshal states: %w", err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
return rawStates, nil
|
||||
}
|
||||
|
||||
for name, rawState := range rawStates {
|
||||
// loadSingleRawState unmarshals a raw state into a concrete state object
|
||||
func (m *Manager) loadSingleRawState(name string, rawState json.RawMessage) (State, error) {
|
||||
stateType, ok := m.stateTypes[name]
|
||||
if !ok {
|
||||
merr = multierror.Append(merr, fmt.Errorf("unknown state type: %s", name))
|
||||
continue
|
||||
return nil, fmt.Errorf("state %s not registered", name)
|
||||
}
|
||||
|
||||
if string(rawState) == "null" {
|
||||
continue
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
statePtr := reflect.New(stateType).Interface().(State)
|
||||
if err := json.Unmarshal(rawState, statePtr); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("unmarshal state %s: %w", name, err))
|
||||
continue
|
||||
return nil, fmt.Errorf("unmarshal state %s: %w", name, err)
|
||||
}
|
||||
|
||||
m.states[name] = statePtr
|
||||
return statePtr, nil
|
||||
}
|
||||
|
||||
// LoadState loads a specific state from the state file
|
||||
func (m *Manager) LoadState(state State) error {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
rawStates, err := m.loadStateFile()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if rawStates == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
name := state.Name()
|
||||
rawState, exists := rawStates[name]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
loadedState, err := m.loadSingleRawState(name, rawState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.states[name] = loadedState
|
||||
if loadedState != nil {
|
||||
log.Debugf("loaded state: %s", name)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
return nil
|
||||
}
|
||||
|
||||
// PerformCleanup retrieves all states from the state file for the registered states and calls Cleanup on them.
|
||||
// If the cleanup is successful, the state is marked for deletion.
|
||||
// PerformCleanup retrieves all states from the state file and calls Cleanup on registered states that support it.
|
||||
// Unregistered states are preserved in their original state.
|
||||
func (m *Manager) PerformCleanup() error {
|
||||
if m == nil {
|
||||
return nil
|
||||
@ -267,22 +318,53 @@ func (m *Manager) PerformCleanup() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if err := m.loadState(); err != nil {
|
||||
// Load raw states from file
|
||||
rawStates, err := m.loadStateFile()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to load state during cleanup: %v", err)
|
||||
return err
|
||||
}
|
||||
if rawStates == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
for name, state := range m.states {
|
||||
if state == nil {
|
||||
// If no state was found in the state file, we don't mark the state dirty nor return an error
|
||||
|
||||
// Process each state in the file
|
||||
for name, rawState := range rawStates {
|
||||
// For unregistered states, preserve the raw JSON
|
||||
if _, registered := m.stateTypes[name]; !registered {
|
||||
m.states[name] = &RawState{data: rawState}
|
||||
continue
|
||||
}
|
||||
|
||||
// Load the registered state
|
||||
loadedState, err := m.loadSingleRawState(name, rawState)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if loadedState == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if state supports cleanup
|
||||
cleanableState, isCleanable := loadedState.(CleanableState)
|
||||
if !isCleanable {
|
||||
// If it doesn't support cleanup, keep it as-is
|
||||
m.states[name] = loadedState
|
||||
continue
|
||||
}
|
||||
|
||||
// Perform cleanup for cleanable states
|
||||
log.Infof("client was not shut down properly, cleaning up %s", name)
|
||||
if err := state.Cleanup(); err != nil {
|
||||
if err := cleanableState.Cleanup(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("cleanup state for %s: %w", name, err))
|
||||
// On cleanup error, preserve the state
|
||||
m.states[name] = loadedState
|
||||
} else {
|
||||
// mark for deletion on cleanup success
|
||||
// Successfully cleaned up - mark for deletion
|
||||
m.states[name] = nil
|
||||
m.dirty[name] = struct{}{}
|
||||
}
|
||||
|
@ -7,12 +7,15 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
@ -23,6 +26,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
@ -31,11 +35,14 @@ const readmeContent = `Netbird debug bundle
|
||||
This debug bundle contains the following files:
|
||||
|
||||
status.txt: Anonymized status information of the NetBird client.
|
||||
client.log: Most recent, anonymized log file of the NetBird client.
|
||||
client.log: Most recent, anonymized client log file of the NetBird client.
|
||||
netbird.err: Most recent, anonymized stderr log file of the NetBird client.
|
||||
netbird.out: Most recent, anonymized stdout log file of the NetBird client.
|
||||
routes.txt: Anonymized system routes, if --system-info flag was provided.
|
||||
interfaces.txt: Anonymized network interface information, if --system-info flag was provided.
|
||||
config.txt: Anonymized configuration information of the NetBird client.
|
||||
network_map.json: Anonymized network map containing peer configurations, routes, DNS settings, and firewall rules.
|
||||
state.json: Anonymized client state dump containing netbird states.
|
||||
|
||||
|
||||
Anonymization Process
|
||||
@ -65,6 +72,19 @@ The network_map.json file contains the following anonymized information:
|
||||
|
||||
SSH keys in the network map are replaced with a placeholder value. All IP addresses and domains in the network map follow the same anonymization rules as described above.
|
||||
|
||||
State File
|
||||
The state.json file contains anonymized internal state information of the NetBird client, including:
|
||||
- DNS settings and configuration
|
||||
- Firewall rules
|
||||
- Exclusion routes
|
||||
- Route selection
|
||||
- Other internal states that may be present
|
||||
|
||||
The state file follows the same anonymization rules as other files:
|
||||
- IP addresses (both individual and CIDR ranges) are anonymized while preserving their structure
|
||||
- Domain names are consistently anonymized
|
||||
- Technical identifiers and non-sensitive data remain unchanged
|
||||
|
||||
Routes
|
||||
For anonymized routes, the IP addresses are replaced as described above. The prefix length remains unchanged. Note that for prefixes, the anonymized IP might not be a network address, but the prefix length is still correct.
|
||||
|
||||
@ -88,6 +108,12 @@ The config.txt file contains anonymized configuration information of the NetBird
|
||||
Other non-sensitive configuration options are included without anonymization.
|
||||
`
|
||||
|
||||
const (
|
||||
clientLogFile = "client.log"
|
||||
errorLogFile = "netbird.err"
|
||||
stdoutLogFile = "netbird.out"
|
||||
)
|
||||
|
||||
// DebugBundle creates a debug bundle and returns the location.
|
||||
func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) {
|
||||
s.mutex.Lock()
|
||||
@ -152,6 +178,10 @@ func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleReques
|
||||
return fmt.Errorf("add network map: %w", err)
|
||||
}
|
||||
|
||||
if err := s.addStateFile(req, anonymizer, archive); err != nil {
|
||||
log.Errorf("Failed to add state file to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := s.addLogfile(req, anonymizer, archive); err != nil {
|
||||
return fmt.Errorf("add log file: %w", err)
|
||||
}
|
||||
@ -302,14 +332,73 @@ func (s *Server) addNetworkMap(req *proto.DebugBundleRequest, anonymizer *anonym
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) (err error) {
|
||||
logFile, err := os.Open(s.logFile)
|
||||
func (s *Server) addStateFile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
|
||||
path := statemanager.GetDefaultStatePath()
|
||||
if path == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open log file: %w", err)
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("read state file: %w", err)
|
||||
}
|
||||
|
||||
if req.GetAnonymize() {
|
||||
var rawStates map[string]json.RawMessage
|
||||
if err := json.Unmarshal(data, &rawStates); err != nil {
|
||||
return fmt.Errorf("unmarshal states: %w", err)
|
||||
}
|
||||
|
||||
if err := anonymizeStateFile(&rawStates, anonymizer); err != nil {
|
||||
return fmt.Errorf("anonymize state file: %w", err)
|
||||
}
|
||||
|
||||
bs, err := json.MarshalIndent(rawStates, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal states: %w", err)
|
||||
}
|
||||
data = bs
|
||||
}
|
||||
|
||||
if err := addFileToZip(archive, bytes.NewReader(data), "state.json"); err != nil {
|
||||
return fmt.Errorf("add state file to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
|
||||
logDir := filepath.Dir(s.logFile)
|
||||
|
||||
if err := s.addSingleLogfile(s.logFile, clientLogFile, req, anonymizer, archive); err != nil {
|
||||
return fmt.Errorf("add client log file to zip: %w", err)
|
||||
}
|
||||
|
||||
errLogPath := filepath.Join(logDir, errorLogFile)
|
||||
if err := s.addSingleLogfile(errLogPath, errorLogFile, req, anonymizer, archive); err != nil {
|
||||
log.Warnf("Failed to add %s to zip: %v", errorLogFile, err)
|
||||
}
|
||||
|
||||
stdoutLogPath := filepath.Join(logDir, stdoutLogFile)
|
||||
if err := s.addSingleLogfile(stdoutLogPath, stdoutLogFile, req, anonymizer, archive); err != nil {
|
||||
log.Warnf("Failed to add %s to zip: %v", stdoutLogFile, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addSingleLogfile adds a single log file to the archive
|
||||
func (s *Server) addSingleLogfile(logPath, targetName string, req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
|
||||
logFile, err := os.Open(logPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open log file %s: %w", targetName, err)
|
||||
}
|
||||
defer func() {
|
||||
if err := logFile.Close(); err != nil {
|
||||
log.Errorf("Failed to close original log file: %v", err)
|
||||
log.Errorf("Failed to close log file %s: %v", targetName, err)
|
||||
}
|
||||
}()
|
||||
|
||||
@ -318,12 +407,13 @@ func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize
|
||||
var writer *io.PipeWriter
|
||||
logReader, writer = io.Pipe()
|
||||
|
||||
go s.anonymize(logFile, writer, anonymizer)
|
||||
go anonymizeLog(logFile, writer, anonymizer)
|
||||
} else {
|
||||
logReader = logFile
|
||||
}
|
||||
if err := addFileToZip(archive, logReader, "client.log"); err != nil {
|
||||
return fmt.Errorf("add log file to zip: %w", err)
|
||||
|
||||
if err := addFileToZip(archive, logReader, targetName); err != nil {
|
||||
return fmt.Errorf("add %s to zip: %w", targetName, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -555,6 +645,26 @@ func formatInterfaces(interfaces []net.Interface, anonymize bool, anonymizer *an
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func anonymizeLog(reader io.Reader, writer *io.PipeWriter, anonymizer *anonymize.Anonymizer) {
|
||||
defer func() {
|
||||
// always nil
|
||||
_ = writer.Close()
|
||||
}()
|
||||
|
||||
scanner := bufio.NewScanner(reader)
|
||||
for scanner.Scan() {
|
||||
line := anonymizer.AnonymizeString(scanner.Text())
|
||||
if _, err := writer.Write([]byte(line + "\n")); err != nil {
|
||||
writer.CloseWithError(fmt.Errorf("anonymize write: %w", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
writer.CloseWithError(fmt.Errorf("anonymize scan: %w", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func anonymizeNATExternalIPs(ips []string, anonymizer *anonymize.Anonymizer) []string {
|
||||
anonymizedIPs := make([]string, len(ips))
|
||||
for i, ip := range ips {
|
||||
@ -752,3 +862,77 @@ func anonymizeRouteFirewallRule(rule *mgmProto.RouteFirewallRule, anonymizer *an
|
||||
rule.Destination = fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
|
||||
}
|
||||
}
|
||||
|
||||
func anonymizeStateFile(rawStates *map[string]json.RawMessage, anonymizer *anonymize.Anonymizer) error {
|
||||
for name, rawState := range *rawStates {
|
||||
if string(rawState) == "null" {
|
||||
continue
|
||||
}
|
||||
|
||||
var state map[string]any
|
||||
if err := json.Unmarshal(rawState, &state); err != nil {
|
||||
return fmt.Errorf("unmarshal state %s: %w", name, err)
|
||||
}
|
||||
|
||||
state = anonymizeValue(state, anonymizer).(map[string]any)
|
||||
|
||||
bs, err := json.Marshal(state)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal state %s: %w", name, err)
|
||||
}
|
||||
|
||||
(*rawStates)[name] = bs
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func anonymizeValue(value any, anonymizer *anonymize.Anonymizer) any {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return anonymizeString(v, anonymizer)
|
||||
case map[string]any:
|
||||
return anonymizeMap(v, anonymizer)
|
||||
case []any:
|
||||
return anonymizeSlice(v, anonymizer)
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func anonymizeString(v string, anonymizer *anonymize.Anonymizer) string {
|
||||
if prefix, err := netip.ParsePrefix(v); err == nil {
|
||||
anonIP := anonymizer.AnonymizeIP(prefix.Addr())
|
||||
return fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
|
||||
}
|
||||
if ip, err := netip.ParseAddr(v); err == nil {
|
||||
return anonymizer.AnonymizeIP(ip).String()
|
||||
}
|
||||
return anonymizer.AnonymizeString(v)
|
||||
}
|
||||
|
||||
func anonymizeMap(v map[string]any, anonymizer *anonymize.Anonymizer) map[string]any {
|
||||
result := make(map[string]any, len(v))
|
||||
for key, val := range v {
|
||||
newKey := anonymizeMapKey(key, anonymizer)
|
||||
result[newKey] = anonymizeValue(val, anonymizer)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func anonymizeMapKey(key string, anonymizer *anonymize.Anonymizer) string {
|
||||
if prefix, err := netip.ParsePrefix(key); err == nil {
|
||||
anonIP := anonymizer.AnonymizeIP(prefix.Addr())
|
||||
return fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
|
||||
}
|
||||
if ip, err := netip.ParseAddr(key); err == nil {
|
||||
return anonymizer.AnonymizeIP(ip).String()
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
func anonymizeSlice(v []any, anonymizer *anonymize.Anonymizer) []any {
|
||||
for i, val := range v {
|
||||
v[i] = anonymizeValue(val, anonymizer)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
@ -1,16 +1,271 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
func TestAnonymizeStateFile(t *testing.T) {
|
||||
testState := map[string]json.RawMessage{
|
||||
"null_state": json.RawMessage("null"),
|
||||
"test_state": mustMarshal(map[string]any{
|
||||
// Test simple fields
|
||||
"public_ip": "203.0.113.1",
|
||||
"private_ip": "192.168.1.1",
|
||||
"protected_ip": "100.64.0.1",
|
||||
"well_known_ip": "8.8.8.8",
|
||||
"ipv6_addr": "2001:db8::1",
|
||||
"private_ipv6": "fd00::1",
|
||||
"domain": "test.example.com",
|
||||
"uri": "stun:stun.example.com:3478",
|
||||
"uri_with_ip": "turn:203.0.113.1:3478",
|
||||
"netbird_domain": "device.netbird.cloud",
|
||||
|
||||
// Test CIDR ranges
|
||||
"public_cidr": "203.0.113.0/24",
|
||||
"private_cidr": "192.168.0.0/16",
|
||||
"protected_cidr": "100.64.0.0/10",
|
||||
"ipv6_cidr": "2001:db8::/32",
|
||||
"private_ipv6_cidr": "fd00::/8",
|
||||
|
||||
// Test nested structures
|
||||
"nested": map[string]any{
|
||||
"ip": "203.0.113.2",
|
||||
"domain": "nested.example.com",
|
||||
"more_nest": map[string]any{
|
||||
"ip": "203.0.113.3",
|
||||
"domain": "deep.example.com",
|
||||
},
|
||||
},
|
||||
|
||||
// Test arrays
|
||||
"string_array": []any{
|
||||
"203.0.113.4",
|
||||
"test1.example.com",
|
||||
"test2.example.com",
|
||||
},
|
||||
"object_array": []any{
|
||||
map[string]any{
|
||||
"ip": "203.0.113.5",
|
||||
"domain": "array1.example.com",
|
||||
},
|
||||
map[string]any{
|
||||
"ip": "203.0.113.6",
|
||||
"domain": "array2.example.com",
|
||||
},
|
||||
},
|
||||
|
||||
// Test multiple occurrences of same value
|
||||
"duplicate_ip": "203.0.113.1", // Same as public_ip
|
||||
"duplicate_domain": "test.example.com", // Same as domain
|
||||
|
||||
// Test URIs with various schemes
|
||||
"stun_uri": "stun:stun.example.com:3478",
|
||||
"turns_uri": "turns:turns.example.com:5349",
|
||||
"http_uri": "http://web.example.com:80",
|
||||
"https_uri": "https://secure.example.com:443",
|
||||
|
||||
// Test strings that might look like IPs but aren't
|
||||
"not_ip": "300.300.300.300",
|
||||
"partial_ip": "192.168",
|
||||
"ip_like_string": "1234.5678",
|
||||
|
||||
// Test mixed content strings
|
||||
"mixed_content": "Server at 203.0.113.1 (test.example.com) on port 80",
|
||||
|
||||
// Test empty and special values
|
||||
"empty_string": "",
|
||||
"null_value": nil,
|
||||
"numeric_value": 42,
|
||||
"boolean_value": true,
|
||||
}),
|
||||
"route_state": mustMarshal(map[string]any{
|
||||
"routes": []any{
|
||||
map[string]any{
|
||||
"network": "203.0.113.0/24",
|
||||
"gateway": "203.0.113.1",
|
||||
"domains": []any{
|
||||
"route1.example.com",
|
||||
"route2.example.com",
|
||||
},
|
||||
},
|
||||
map[string]any{
|
||||
"network": "2001:db8::/32",
|
||||
"gateway": "2001:db8::1",
|
||||
"domains": []any{
|
||||
"route3.example.com",
|
||||
"route4.example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
// Test map with IP/CIDR keys
|
||||
"refCountMap": map[string]any{
|
||||
"203.0.113.1/32": map[string]any{
|
||||
"Count": 1,
|
||||
"Out": map[string]any{
|
||||
"IP": "192.168.0.1",
|
||||
"Intf": map[string]any{
|
||||
"Name": "eth0",
|
||||
"Index": 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
"2001:db8::1/128": map[string]any{
|
||||
"Count": 1,
|
||||
"Out": map[string]any{
|
||||
"IP": "fe80::1",
|
||||
"Intf": map[string]any{
|
||||
"Name": "eth0",
|
||||
"Index": 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
"10.0.0.1/32": map[string]any{ // private IP should remain unchanged
|
||||
"Count": 1,
|
||||
"Out": map[string]any{
|
||||
"IP": "192.168.0.1",
|
||||
},
|
||||
},
|
||||
},
|
||||
}),
|
||||
}
|
||||
|
||||
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
||||
|
||||
// Pre-seed the domains we need to verify in the test assertions
|
||||
anonymizer.AnonymizeDomain("test.example.com")
|
||||
anonymizer.AnonymizeDomain("nested.example.com")
|
||||
anonymizer.AnonymizeDomain("deep.example.com")
|
||||
anonymizer.AnonymizeDomain("array1.example.com")
|
||||
|
||||
err := anonymizeStateFile(&testState, anonymizer)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Helper function to unmarshal and get nested values
|
||||
var state map[string]any
|
||||
err = json.Unmarshal(testState["test_state"], &state)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test null state remains unchanged
|
||||
require.Equal(t, "null", string(testState["null_state"]))
|
||||
|
||||
// Basic assertions
|
||||
assert.NotEqual(t, "203.0.113.1", state["public_ip"])
|
||||
assert.Equal(t, "192.168.1.1", state["private_ip"]) // Private IP unchanged
|
||||
assert.Equal(t, "100.64.0.1", state["protected_ip"]) // Protected IP unchanged
|
||||
assert.Equal(t, "8.8.8.8", state["well_known_ip"]) // Well-known IP unchanged
|
||||
assert.NotEqual(t, "2001:db8::1", state["ipv6_addr"])
|
||||
assert.Equal(t, "fd00::1", state["private_ipv6"]) // Private IPv6 unchanged
|
||||
assert.NotEqual(t, "test.example.com", state["domain"])
|
||||
assert.True(t, strings.HasSuffix(state["domain"].(string), ".domain"))
|
||||
assert.Equal(t, "device.netbird.cloud", state["netbird_domain"]) // Netbird domain unchanged
|
||||
|
||||
// CIDR ranges
|
||||
assert.NotEqual(t, "203.0.113.0/24", state["public_cidr"])
|
||||
assert.Contains(t, state["public_cidr"], "/24") // Prefix preserved
|
||||
assert.Equal(t, "192.168.0.0/16", state["private_cidr"]) // Private CIDR unchanged
|
||||
assert.Equal(t, "100.64.0.0/10", state["protected_cidr"]) // Protected CIDR unchanged
|
||||
assert.NotEqual(t, "2001:db8::/32", state["ipv6_cidr"])
|
||||
assert.Contains(t, state["ipv6_cidr"], "/32") // IPv6 prefix preserved
|
||||
|
||||
// Nested structures
|
||||
nested := state["nested"].(map[string]any)
|
||||
assert.NotEqual(t, "203.0.113.2", nested["ip"])
|
||||
assert.NotEqual(t, "nested.example.com", nested["domain"])
|
||||
moreNest := nested["more_nest"].(map[string]any)
|
||||
assert.NotEqual(t, "203.0.113.3", moreNest["ip"])
|
||||
assert.NotEqual(t, "deep.example.com", moreNest["domain"])
|
||||
|
||||
// Arrays
|
||||
strArray := state["string_array"].([]any)
|
||||
assert.NotEqual(t, "203.0.113.4", strArray[0])
|
||||
assert.NotEqual(t, "test1.example.com", strArray[1])
|
||||
assert.True(t, strings.HasSuffix(strArray[1].(string), ".domain"))
|
||||
|
||||
objArray := state["object_array"].([]any)
|
||||
firstObj := objArray[0].(map[string]any)
|
||||
assert.NotEqual(t, "203.0.113.5", firstObj["ip"])
|
||||
assert.NotEqual(t, "array1.example.com", firstObj["domain"])
|
||||
|
||||
// Duplicate values should be anonymized consistently
|
||||
assert.Equal(t, state["public_ip"], state["duplicate_ip"])
|
||||
assert.Equal(t, state["domain"], state["duplicate_domain"])
|
||||
|
||||
// URIs
|
||||
assert.NotContains(t, state["stun_uri"], "stun.example.com")
|
||||
assert.NotContains(t, state["turns_uri"], "turns.example.com")
|
||||
assert.NotContains(t, state["http_uri"], "web.example.com")
|
||||
assert.NotContains(t, state["https_uri"], "secure.example.com")
|
||||
|
||||
// Non-IP strings should remain unchanged
|
||||
assert.Equal(t, "300.300.300.300", state["not_ip"])
|
||||
assert.Equal(t, "192.168", state["partial_ip"])
|
||||
assert.Equal(t, "1234.5678", state["ip_like_string"])
|
||||
|
||||
// Mixed content should have IPs and domains replaced
|
||||
mixedContent := state["mixed_content"].(string)
|
||||
assert.NotContains(t, mixedContent, "203.0.113.1")
|
||||
assert.NotContains(t, mixedContent, "test.example.com")
|
||||
assert.Contains(t, mixedContent, "Server at ")
|
||||
assert.Contains(t, mixedContent, " on port 80")
|
||||
|
||||
// Special values should remain unchanged
|
||||
assert.Equal(t, "", state["empty_string"])
|
||||
assert.Nil(t, state["null_value"])
|
||||
assert.Equal(t, float64(42), state["numeric_value"])
|
||||
assert.Equal(t, true, state["boolean_value"])
|
||||
|
||||
// Check route state
|
||||
var routeState map[string]any
|
||||
err = json.Unmarshal(testState["route_state"], &routeState)
|
||||
require.NoError(t, err)
|
||||
|
||||
routes := routeState["routes"].([]any)
|
||||
route1 := routes[0].(map[string]any)
|
||||
assert.NotEqual(t, "203.0.113.0/24", route1["network"])
|
||||
assert.Contains(t, route1["network"], "/24")
|
||||
assert.NotEqual(t, "203.0.113.1", route1["gateway"])
|
||||
domains := route1["domains"].([]any)
|
||||
assert.True(t, strings.HasSuffix(domains[0].(string), ".domain"))
|
||||
assert.True(t, strings.HasSuffix(domains[1].(string), ".domain"))
|
||||
|
||||
// Check map keys are anonymized
|
||||
refCountMap := routeState["refCountMap"].(map[string]any)
|
||||
hasPublicIPKey := false
|
||||
hasIPv6Key := false
|
||||
hasPrivateIPKey := false
|
||||
for key := range refCountMap {
|
||||
if strings.Contains(key, "203.0.113.1") {
|
||||
hasPublicIPKey = true
|
||||
}
|
||||
if strings.Contains(key, "2001:db8::1") {
|
||||
hasIPv6Key = true
|
||||
}
|
||||
if key == "10.0.0.1/32" {
|
||||
hasPrivateIPKey = true
|
||||
}
|
||||
}
|
||||
assert.False(t, hasPublicIPKey, "public IP in key should be anonymized")
|
||||
assert.False(t, hasIPv6Key, "IPv6 in key should be anonymized")
|
||||
assert.True(t, hasPrivateIPKey, "private IP in key should remain unchanged")
|
||||
}
|
||||
|
||||
func mustMarshal(v any) json.RawMessage {
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func TestAnonymizeNetworkMap(t *testing.T) {
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
PeerConfig: &mgmProto.PeerConfig{
|
||||
|
@ -61,6 +61,14 @@ type Info struct {
|
||||
Files []File // for posture checks
|
||||
}
|
||||
|
||||
// StaticInfo is an object that contains machine information that does not change
|
||||
type StaticInfo struct {
|
||||
SystemSerialNumber string
|
||||
SystemProductName string
|
||||
SystemManufacturer string
|
||||
Environment Environment
|
||||
}
|
||||
|
||||
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
|
||||
func extractUserAgent(ctx context.Context) string {
|
||||
md, hasMeta := metadata.FromOutgoingContext(ctx)
|
||||
|
@ -10,13 +10,12 @@ import (
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/system/detect_cloud"
|
||||
"github.com/netbirdio/netbird/client/system/detect_platform"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@ -41,11 +40,10 @@ func GetInfo(ctx context.Context) *Info {
|
||||
log.Warnf("failed to discover network addresses: %s", err)
|
||||
}
|
||||
|
||||
serialNum, prodName, manufacturer := sysInfo()
|
||||
|
||||
env := Environment{
|
||||
Cloud: detect_cloud.Detect(ctx),
|
||||
Platform: detect_platform.Detect(ctx),
|
||||
start := time.Now()
|
||||
si := updateStaticInfo()
|
||||
if time.Since(start) > 1*time.Second {
|
||||
log.Warnf("updateStaticInfo took %s", time.Since(start))
|
||||
}
|
||||
|
||||
gio := &Info{
|
||||
@ -57,10 +55,10 @@ func GetInfo(ctx context.Context) *Info {
|
||||
CPUs: runtime.NumCPU(),
|
||||
KernelVersion: release,
|
||||
NetworkAddresses: addrs,
|
||||
SystemSerialNumber: serialNum,
|
||||
SystemProductName: prodName,
|
||||
SystemManufacturer: manufacturer,
|
||||
Environment: env,
|
||||
SystemSerialNumber: si.SystemSerialNumber,
|
||||
SystemProductName: si.SystemProductName,
|
||||
SystemManufacturer: si.SystemManufacturer,
|
||||
Environment: si.Environment,
|
||||
}
|
||||
|
||||
systemHostname, _ := os.Hostname()
|
||||
|
@ -1,5 +1,4 @@
|
||||
//go:build !android
|
||||
// +build !android
|
||||
|
||||
package system
|
||||
|
||||
@ -16,30 +15,13 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/zcalusic/sysinfo"
|
||||
|
||||
"github.com/netbirdio/netbird/client/system/detect_cloud"
|
||||
"github.com/netbirdio/netbird/client/system/detect_platform"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
type SysInfoGetter interface {
|
||||
GetSysInfo() SysInfo
|
||||
}
|
||||
|
||||
type SysInfoWrapper struct {
|
||||
si sysinfo.SysInfo
|
||||
}
|
||||
|
||||
func (s SysInfoWrapper) GetSysInfo() SysInfo {
|
||||
s.si.GetSysInfo()
|
||||
return SysInfo{
|
||||
ChassisSerial: s.si.Chassis.Serial,
|
||||
ProductSerial: s.si.Product.Serial,
|
||||
BoardSerial: s.si.Board.Serial,
|
||||
ProductName: s.si.Product.Name,
|
||||
BoardName: s.si.Board.Name,
|
||||
ProductVendor: s.si.Product.Vendor,
|
||||
}
|
||||
}
|
||||
var (
|
||||
// it is override in tests
|
||||
getSystemInfo = defaultSysInfoImplementation
|
||||
)
|
||||
|
||||
// GetInfo retrieves and parses the system information
|
||||
func GetInfo(ctx context.Context) *Info {
|
||||
@ -65,12 +47,10 @@ func GetInfo(ctx context.Context) *Info {
|
||||
log.Warnf("failed to discover network addresses: %s", err)
|
||||
}
|
||||
|
||||
si := SysInfoWrapper{}
|
||||
serialNum, prodName, manufacturer := sysInfo(si.GetSysInfo())
|
||||
|
||||
env := Environment{
|
||||
Cloud: detect_cloud.Detect(ctx),
|
||||
Platform: detect_platform.Detect(ctx),
|
||||
start := time.Now()
|
||||
si := updateStaticInfo()
|
||||
if time.Since(start) > 1*time.Second {
|
||||
log.Warnf("updateStaticInfo took %s", time.Since(start))
|
||||
}
|
||||
|
||||
gio := &Info{
|
||||
@ -85,10 +65,10 @@ func GetInfo(ctx context.Context) *Info {
|
||||
UIVersion: extractUserAgent(ctx),
|
||||
KernelVersion: osInfo[1],
|
||||
NetworkAddresses: addrs,
|
||||
SystemSerialNumber: serialNum,
|
||||
SystemProductName: prodName,
|
||||
SystemManufacturer: manufacturer,
|
||||
Environment: env,
|
||||
SystemSerialNumber: si.SystemSerialNumber,
|
||||
SystemProductName: si.SystemProductName,
|
||||
SystemManufacturer: si.SystemManufacturer,
|
||||
Environment: si.Environment,
|
||||
}
|
||||
|
||||
return gio
|
||||
@ -108,9 +88,9 @@ func _getInfo() string {
|
||||
return out.String()
|
||||
}
|
||||
|
||||
func sysInfo(si SysInfo) (string, string, string) {
|
||||
func sysInfo() (string, string, string) {
|
||||
isascii := regexp.MustCompile("^[[:ascii:]]+$")
|
||||
|
||||
si := getSystemInfo()
|
||||
serials := []string{si.ChassisSerial, si.ProductSerial}
|
||||
serial := ""
|
||||
|
||||
@ -141,3 +121,16 @@ func sysInfo(si SysInfo) (string, string, string) {
|
||||
}
|
||||
return serial, name, manufacturer
|
||||
}
|
||||
|
||||
func defaultSysInfoImplementation() SysInfo {
|
||||
si := sysinfo.SysInfo{}
|
||||
si.GetSysInfo()
|
||||
return SysInfo{
|
||||
ChassisSerial: si.Chassis.Serial,
|
||||
ProductSerial: si.Product.Serial,
|
||||
BoardSerial: si.Board.Serial,
|
||||
ProductName: si.Product.Name,
|
||||
BoardName: si.Board.Name,
|
||||
ProductVendor: si.Product.Vendor,
|
||||
}
|
||||
}
|
||||
|
@ -6,13 +6,12 @@ import (
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/yusufpapurcu/wmi"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
|
||||
"github.com/netbirdio/netbird/client/system/detect_cloud"
|
||||
"github.com/netbirdio/netbird/client/system/detect_platform"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@ -42,24 +41,10 @@ func GetInfo(ctx context.Context) *Info {
|
||||
log.Warnf("failed to discover network addresses: %s", err)
|
||||
}
|
||||
|
||||
serialNum, err := sysNumber()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system serial number: %s", err)
|
||||
}
|
||||
|
||||
prodName, err := sysProductName()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system product name: %s", err)
|
||||
}
|
||||
|
||||
manufacturer, err := sysManufacturer()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system manufacturer: %s", err)
|
||||
}
|
||||
|
||||
env := Environment{
|
||||
Cloud: detect_cloud.Detect(ctx),
|
||||
Platform: detect_platform.Detect(ctx),
|
||||
start := time.Now()
|
||||
si := updateStaticInfo()
|
||||
if time.Since(start) > 1*time.Second {
|
||||
log.Warnf("updateStaticInfo took %s", time.Since(start))
|
||||
}
|
||||
|
||||
gio := &Info{
|
||||
@ -71,10 +56,10 @@ func GetInfo(ctx context.Context) *Info {
|
||||
CPUs: runtime.NumCPU(),
|
||||
KernelVersion: buildVersion,
|
||||
NetworkAddresses: addrs,
|
||||
SystemSerialNumber: serialNum,
|
||||
SystemProductName: prodName,
|
||||
SystemManufacturer: manufacturer,
|
||||
Environment: env,
|
||||
SystemSerialNumber: si.SystemSerialNumber,
|
||||
SystemProductName: si.SystemProductName,
|
||||
SystemManufacturer: si.SystemManufacturer,
|
||||
Environment: si.Environment,
|
||||
}
|
||||
|
||||
systemHostname, _ := os.Hostname()
|
||||
@ -85,6 +70,26 @@ func GetInfo(ctx context.Context) *Info {
|
||||
return gio
|
||||
}
|
||||
|
||||
func sysInfo() (serialNumber string, productName string, manufacturer string) {
|
||||
var err error
|
||||
serialNumber, err = sysNumber()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system serial number: %s", err)
|
||||
}
|
||||
|
||||
productName, err = sysProductName()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system product name: %s", err)
|
||||
}
|
||||
|
||||
manufacturer, err = sysManufacturer()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system manufacturer: %s", err)
|
||||
}
|
||||
|
||||
return serialNumber, productName, manufacturer
|
||||
}
|
||||
|
||||
func getOSNameAndVersion() (string, string) {
|
||||
var dst []Win32_OperatingSystem
|
||||
query := wmi.CreateQuery(&dst, "")
|
||||
|
46
client/system/static_info.go
Normal file
46
client/system/static_info.go
Normal file
@ -0,0 +1,46 @@
|
||||
//go:build (linux && !android) || windows || (darwin && !ios)
|
||||
|
||||
package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/system/detect_cloud"
|
||||
"github.com/netbirdio/netbird/client/system/detect_platform"
|
||||
)
|
||||
|
||||
var (
|
||||
staticInfo StaticInfo
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
func init() {
|
||||
go func() {
|
||||
_ = updateStaticInfo()
|
||||
}()
|
||||
}
|
||||
|
||||
func updateStaticInfo() StaticInfo {
|
||||
once.Do(func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(3)
|
||||
go func() {
|
||||
staticInfo.SystemSerialNumber, staticInfo.SystemProductName, staticInfo.SystemManufacturer = sysInfo()
|
||||
wg.Done()
|
||||
}()
|
||||
go func() {
|
||||
staticInfo.Environment.Cloud = detect_cloud.Detect(ctx)
|
||||
wg.Done()
|
||||
}()
|
||||
go func() {
|
||||
staticInfo.Environment.Platform = detect_platform.Detect(ctx)
|
||||
wg.Done()
|
||||
}()
|
||||
wg.Wait()
|
||||
})
|
||||
return staticInfo
|
||||
}
|
@ -183,7 +183,10 @@ func Test_sysInfo(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotSerialNum, gotProdName, gotManufacturer := sysInfo(tt.sysInfo)
|
||||
getSystemInfo = func() SysInfo {
|
||||
return tt.sysInfo
|
||||
}
|
||||
gotSerialNum, gotProdName, gotManufacturer := sysInfo()
|
||||
if gotSerialNum != tt.wantSerialNum {
|
||||
t.Errorf("sysInfo() gotSerialNum = %v, want %v", gotSerialNum, tt.wantSerialNum)
|
||||
}
|
||||
|
@ -530,7 +530,7 @@ renderCaddyfile() {
|
||||
{
|
||||
debug
|
||||
servers :80,:443 {
|
||||
protocols h1 h2c
|
||||
protocols h1 h2c h3
|
||||
}
|
||||
}
|
||||
|
||||
@ -788,6 +788,7 @@ services:
|
||||
networks: [ netbird ]
|
||||
ports:
|
||||
- '443:443'
|
||||
- '443:443/udp'
|
||||
- '80:80'
|
||||
- '8080:8080'
|
||||
volumes:
|
||||
|
@ -417,7 +417,20 @@ func (a *Account) getPeerRoutesFirewallRules(ctx context.Context, peerID string,
|
||||
continue
|
||||
}
|
||||
|
||||
policies := getAllRoutePoliciesFromGroups(a, route.AccessControlGroups)
|
||||
distributionPeers := a.getDistributionGroupsPeers(route)
|
||||
|
||||
for _, accessGroup := range route.AccessControlGroups {
|
||||
policies := getAllRoutePoliciesFromGroups(a, []string{accessGroup})
|
||||
rules := a.getRouteFirewallRules(ctx, peerID, policies, route, validatedPeersMap, distributionPeers)
|
||||
routesFirewallRules = append(routesFirewallRules, rules...)
|
||||
}
|
||||
}
|
||||
|
||||
return routesFirewallRules
|
||||
}
|
||||
|
||||
func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}) []*RouteFirewallRule {
|
||||
var fwRules []*RouteFirewallRule
|
||||
for _, policy := range policies {
|
||||
if !policy.Enabled {
|
||||
continue
|
||||
@ -428,14 +441,58 @@ func (a *Account) getPeerRoutesFirewallRules(ctx context.Context, peerID string,
|
||||
continue
|
||||
}
|
||||
|
||||
distributionGroupPeers, _ := a.getAllPeersFromGroups(ctx, route.Groups, peerID, nil, validatedPeersMap)
|
||||
rules := generateRouteFirewallRules(ctx, route, rule, distributionGroupPeers, firewallRuleDirectionIN)
|
||||
routesFirewallRules = append(routesFirewallRules, rules...)
|
||||
rulePeers := a.getRulePeers(rule, peerID, distributionPeers, validatedPeersMap)
|
||||
rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, firewallRuleDirectionIN)
|
||||
fwRules = append(fwRules, rules...)
|
||||
}
|
||||
}
|
||||
return fwRules
|
||||
}
|
||||
|
||||
func (a *Account) getRulePeers(rule *PolicyRule, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer {
|
||||
distPeersWithPolicy := make(map[string]struct{})
|
||||
for _, id := range rule.Sources {
|
||||
group := a.Groups[id]
|
||||
if group == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, pID := range group.Peers {
|
||||
if pID == peerID {
|
||||
continue
|
||||
}
|
||||
_, distPeer := distributionPeers[pID]
|
||||
_, valid := validatedPeersMap[pID]
|
||||
if distPeer && valid {
|
||||
distPeersWithPolicy[pID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return routesFirewallRules
|
||||
distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy))
|
||||
for pID := range distPeersWithPolicy {
|
||||
peer := a.Peers[pID]
|
||||
if peer == nil {
|
||||
continue
|
||||
}
|
||||
distributionGroupPeers = append(distributionGroupPeers, peer)
|
||||
}
|
||||
return distributionGroupPeers
|
||||
}
|
||||
|
||||
func (a *Account) getDistributionGroupsPeers(route *route.Route) map[string]struct{} {
|
||||
distPeers := make(map[string]struct{})
|
||||
for _, id := range route.Groups {
|
||||
group := a.Groups[id]
|
||||
if group == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, pID := range group.Peers {
|
||||
distPeers[pID] = struct{}{}
|
||||
}
|
||||
}
|
||||
return distPeers
|
||||
}
|
||||
|
||||
func getDefaultPermit(route *route.Route) []*RouteFirewallRule {
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sort"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -1486,6 +1487,8 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
|
||||
peerBIp = "100.65.80.39"
|
||||
peerCIp = "100.65.254.139"
|
||||
peerHIp = "100.65.29.55"
|
||||
peerJIp = "100.65.29.65"
|
||||
peerKIp = "100.65.29.66"
|
||||
)
|
||||
|
||||
account := &Account{
|
||||
@ -1541,6 +1544,16 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
|
||||
IP: net.ParseIP(peerHIp),
|
||||
Status: &nbpeer.PeerStatus{},
|
||||
},
|
||||
"peerJ": {
|
||||
ID: "peerJ",
|
||||
IP: net.ParseIP(peerJIp),
|
||||
Status: &nbpeer.PeerStatus{},
|
||||
},
|
||||
"peerK": {
|
||||
ID: "peerK",
|
||||
IP: net.ParseIP(peerKIp),
|
||||
Status: &nbpeer.PeerStatus{},
|
||||
},
|
||||
},
|
||||
Groups: map[string]*nbgroup.Group{
|
||||
"routingPeer1": {
|
||||
@ -1567,6 +1580,11 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
|
||||
Name: "Route2",
|
||||
Peers: []string{},
|
||||
},
|
||||
"route4": {
|
||||
ID: "route4",
|
||||
Name: "route4",
|
||||
Peers: []string{},
|
||||
},
|
||||
"finance": {
|
||||
ID: "finance",
|
||||
Name: "Finance",
|
||||
@ -1584,6 +1602,28 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
|
||||
"peerB",
|
||||
},
|
||||
},
|
||||
"qa": {
|
||||
ID: "qa",
|
||||
Name: "QA",
|
||||
Peers: []string{
|
||||
"peerJ",
|
||||
"peerK",
|
||||
},
|
||||
},
|
||||
"restrictQA": {
|
||||
ID: "restrictQA",
|
||||
Name: "restrictQA",
|
||||
Peers: []string{
|
||||
"peerJ",
|
||||
},
|
||||
},
|
||||
"unrestrictedQA": {
|
||||
ID: "unrestrictedQA",
|
||||
Name: "unrestrictedQA",
|
||||
Peers: []string{
|
||||
"peerK",
|
||||
},
|
||||
},
|
||||
"contractors": {
|
||||
ID: "contractors",
|
||||
Name: "Contractors",
|
||||
@ -1631,6 +1671,19 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
|
||||
Groups: []string{"contractors"},
|
||||
AccessControlGroups: []string{},
|
||||
},
|
||||
"route4": {
|
||||
ID: "route4",
|
||||
Network: netip.MustParsePrefix("192.168.10.0/16"),
|
||||
NetID: "route4",
|
||||
NetworkType: route.IPv4Network,
|
||||
PeerGroups: []string{"routingPeer1"},
|
||||
Description: "Route4",
|
||||
Masquerade: false,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{"qa"},
|
||||
AccessControlGroups: []string{"route4"},
|
||||
},
|
||||
},
|
||||
Policies: []*Policy{
|
||||
{
|
||||
@ -1685,6 +1738,49 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "RuleRoute4",
|
||||
Name: "RuleRoute4",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: "RuleRoute4",
|
||||
Name: "RuleRoute4",
|
||||
Bidirectional: true,
|
||||
Enabled: true,
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Ports: []string{"80"},
|
||||
Sources: []string{
|
||||
"restrictQA",
|
||||
},
|
||||
Destinations: []string{
|
||||
"route4",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "RuleRoute5",
|
||||
Name: "RuleRoute5",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: "RuleRoute5",
|
||||
Name: "RuleRoute5",
|
||||
Bidirectional: true,
|
||||
Enabled: true,
|
||||
Protocol: PolicyRuleProtocolALL,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Sources: []string{
|
||||
"unrestrictedQA",
|
||||
},
|
||||
Destinations: []string{
|
||||
"route4",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@ -1709,7 +1805,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
|
||||
|
||||
t.Run("check peer routes firewall rules", func(t *testing.T) {
|
||||
routesFirewallRules := account.getPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers)
|
||||
assert.Len(t, routesFirewallRules, 2)
|
||||
assert.Len(t, routesFirewallRules, 4)
|
||||
|
||||
expectedRoutesFirewallRules := []*RouteFirewallRule{
|
||||
{
|
||||
@ -1735,12 +1831,32 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
|
||||
Port: 320,
|
||||
},
|
||||
}
|
||||
assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules)
|
||||
additionalFirewallRule := []*RouteFirewallRule{
|
||||
{
|
||||
SourceRanges: []string{
|
||||
fmt.Sprintf(AllowedIPsFormat, peerJIp),
|
||||
},
|
||||
Action: "accept",
|
||||
Destination: "192.168.10.0/16",
|
||||
Protocol: "tcp",
|
||||
Port: 80,
|
||||
},
|
||||
{
|
||||
SourceRanges: []string{
|
||||
fmt.Sprintf(AllowedIPsFormat, peerKIp),
|
||||
},
|
||||
Action: "accept",
|
||||
Destination: "192.168.10.0/16",
|
||||
Protocol: "all",
|
||||
},
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(append(expectedRoutesFirewallRules, additionalFirewallRule...)))
|
||||
|
||||
// peerD is also the routing peer for route1, should contain same routes firewall rules as peerA
|
||||
routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers)
|
||||
assert.Len(t, routesFirewallRules, 2)
|
||||
assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules)
|
||||
assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules))
|
||||
|
||||
// peerE is a single routing peer for route 2 and route 3
|
||||
routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers)
|
||||
@ -1769,7 +1885,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
|
||||
IsDynamic: true,
|
||||
},
|
||||
}
|
||||
assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules)
|
||||
assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules))
|
||||
|
||||
// peerC is part of route1 distribution groups but should not receive the routes firewall rules
|
||||
routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers)
|
||||
@ -1778,6 +1894,14 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
// orderList is a helper function to sort a list of strings
|
||||
func orderRuleSourceRanges(ruleList []*RouteFirewallRule) []*RouteFirewallRule {
|
||||
for _, rule := range ruleList {
|
||||
sort.Strings(rule.SourceRanges)
|
||||
}
|
||||
return ruleList
|
||||
}
|
||||
|
||||
func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
manager, err := createRouterManager(t)
|
||||
require.NoError(t, err, "failed to create account manager")
|
||||
|
Loading…
Reference in New Issue
Block a user