Add routes and dns servers to status command (#1680)

* Add routes (client and server) to status command
* Add DNS servers to status output
This commit is contained in:
Viktor Liu
2024-03-12 19:06:16 +01:00
committed by GitHub
parent ba33572ec9
commit 4a1aee1ae0
20 changed files with 723 additions and 180 deletions

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net/netip"
"strings"
"sync"
"github.com/miekg/dns"
@@ -11,6 +12,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
nbdns "github.com/netbirdio/netbird/dns"
)
@@ -59,6 +61,8 @@ type DefaultServer struct {
// make sense on mobile only
searchDomainNotifier *notifier
iosDnsManager IosDnsManager
statusRecorder *peer.Status
}
type handlerWithStop interface {
@@ -73,7 +77,12 @@ type muxUpdate struct {
}
// NewDefaultServer returns a new dns server
func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string) (*DefaultServer, error) {
func NewDefaultServer(
ctx context.Context,
wgInterface WGIface,
customAddress string,
statusRecorder *peer.Status,
) (*DefaultServer, error) {
var addrPort *netip.AddrPort
if customAddress != "" {
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
@@ -90,13 +99,20 @@ func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress st
dnsService = newServiceViaListener(wgInterface, addrPort)
}
return newDefaultServer(ctx, wgInterface, dnsService), nil
return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder), nil
}
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface, hostsDnsList []string, config nbdns.Config, listener listener.NetworkChangeListener) *DefaultServer {
func NewDefaultServerPermanentUpstream(
ctx context.Context,
wgInterface WGIface,
hostsDnsList []string,
config nbdns.Config,
listener listener.NetworkChangeListener,
statusRecorder *peer.Status,
) *DefaultServer {
log.Debugf("host dns address list is: %v", hostsDnsList)
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface))
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
ds.permanent = true
ds.hostsDnsList = hostsDnsList
ds.addHostRootZone()
@@ -108,13 +124,18 @@ func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface,
}
// NewDefaultServerIos returns a new dns server. It optimized for ios
func NewDefaultServerIos(ctx context.Context, wgInterface WGIface, iosDnsManager IosDnsManager) *DefaultServer {
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface))
func NewDefaultServerIos(
ctx context.Context,
wgInterface WGIface,
iosDnsManager IosDnsManager,
statusRecorder *peer.Status,
) *DefaultServer {
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
ds.iosDnsManager = iosDnsManager
return ds
}
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service) *DefaultServer {
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status) *DefaultServer {
ctx, stop := context.WithCancel(ctx)
defaultServer := &DefaultServer{
ctx: ctx,
@@ -124,7 +145,8 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
localResolver: &localResolver{
registeredMap: make(registrationMap),
},
wgInterface: wgInterface,
wgInterface: wgInterface,
statusRecorder: statusRecorder,
}
return defaultServer
@@ -299,6 +321,8 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
}
s.updateNSGroupStates(update.NameServerGroups)
return nil
}
@@ -338,7 +362,13 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
continue
}
handler, err := newUpstreamResolver(s.ctx, s.wgInterface.Name(), s.wgInterface.Address().IP, s.wgInterface.Address().Network)
handler, err := newUpstreamResolver(
s.ctx,
s.wgInterface.Name(),
s.wgInterface.Address().IP,
s.wgInterface.Address().Network,
s.statusRecorder,
)
if err != nil {
return nil, fmt.Errorf("unable to create a new upstream resolver, error: %v", err)
}
@@ -460,14 +490,14 @@ func getNSHostPort(ns nbdns.NameServer) string {
func (s *DefaultServer) upstreamCallbacks(
nsGroup *nbdns.NameServerGroup,
handler dns.Handler,
) (deactivate func(), reactivate func()) {
) (deactivate func(error), reactivate func()) {
var removeIndex map[string]int
deactivate = func() {
deactivate = func(err error) {
s.mux.Lock()
defer s.mux.Unlock()
l := log.WithField("nameservers", nsGroup.NameServers)
l.Info("temporary deactivate nameservers group due timeout")
l.Info("Temporarily deactivating nameservers group due to timeout")
removeIndex = make(map[string]int)
for _, domain := range nsGroup.Domains {
@@ -486,8 +516,11 @@ func (s *DefaultServer) upstreamCallbacks(
}
}
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
l.WithError(err).Error("fail to apply nameserver deactivation on the host")
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
}
s.updateNSState(nsGroup, err, false)
}
reactivate = func() {
s.mux.Lock()
@@ -510,12 +543,20 @@ func (s *DefaultServer) upstreamCallbacks(
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
}
s.updateNSState(nsGroup, nil, true)
}
return
}
func (s *DefaultServer) addHostRootZone() {
handler, err := newUpstreamResolver(s.ctx, s.wgInterface.Name(), s.wgInterface.Address().IP, s.wgInterface.Address().Network)
handler, err := newUpstreamResolver(
s.ctx,
s.wgInterface.Name(),
s.wgInterface.Address().IP,
s.wgInterface.Address().Network,
s.statusRecorder,
)
if err != nil {
log.Errorf("unable to create a new upstream resolver, error: %v", err)
return
@@ -535,7 +576,50 @@ func (s *DefaultServer) addHostRootZone() {
handler.upstreamServers[n] = fmt.Sprintf("%s:53", ipString)
}
handler.deactivate = func() {}
handler.deactivate = func(error) {}
handler.reactivate = func() {}
s.service.RegisterMux(nbdns.RootZone, handler)
}
func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
var states []peer.NSGroupState
for _, group := range groups {
var servers []string
for _, ns := range group.NameServers {
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port))
}
state := peer.NSGroupState{
ID: generateGroupKey(group),
Servers: servers,
Domains: group.Domains,
// The probe will determine the state, default enabled
Enabled: true,
Error: nil,
}
states = append(states, state)
}
s.statusRecorder.UpdateDNSStates(states)
}
func (s *DefaultServer) updateNSState(nsGroup *nbdns.NameServerGroup, err error, enabled bool) {
states := s.statusRecorder.GetDNSStates()
id := generateGroupKey(nsGroup)
for i, state := range states {
if state.ID == id {
states[i].Enabled = enabled
states[i].Error = err
break
}
}
s.statusRecorder.UpdateDNSStates(states)
}
func generateGroupKey(nsGroup *nbdns.NameServerGroup) string {
var servers []string
for _, ns := range nsGroup.NameServers {
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port))
}
return fmt.Sprintf("%s_%s_%s", nsGroup.ID, nsGroup.Name, strings.Join(servers, ","))
}

View File

@@ -15,6 +15,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/stdnet"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/formatter"
@@ -274,7 +275,7 @@ func TestUpdateDNSServer(t *testing.T) {
t.Log(err)
}
}()
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{})
if err != nil {
t.Fatal(err)
}
@@ -375,7 +376,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
return
}
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{})
if err != nil {
t.Errorf("create DNS server: %v", err)
return
@@ -470,7 +471,7 @@ func TestDNSServerStartStop(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort)
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{})
if err != nil {
t.Fatalf("%v", err)
}
@@ -541,6 +542,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
{false, "domain2", false},
},
},
statusRecorder: &peer.Status{},
}
var domainsUpdate string
@@ -563,7 +565,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
},
}, nil)
deactivate()
deactivate(nil)
expected := "domain0,domain2"
domains := []string{}
for _, item := range server.currentConfig.Domains {
@@ -601,7 +603,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
var dnsList []string
dnsConfig := nbdns.Config{}
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil)
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, &peer.Status{})
err = dnsServer.Initialize()
if err != nil {
t.Errorf("failed to initialize DNS server: %v", err)
@@ -625,7 +627,7 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
}
defer wgIFace.Close()
dnsConfig := nbdns.Config{}
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil)
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{})
err = dnsServer.Initialize()
if err != nil {
t.Errorf("failed to initialize DNS server: %v", err)
@@ -717,7 +719,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
}
defer wgIFace.Close()
dnsConfig := nbdns.Config{}
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil)
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{})
err = dnsServer.Initialize()
if err != nil {
t.Errorf("failed to initialize DNS server: %v", err)

View File

@@ -11,8 +11,11 @@ import (
"time"
"github.com/cenkalti/backoff/v4"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
)
const (
@@ -45,12 +48,13 @@ type upstreamResolverBase struct {
reactivatePeriod time.Duration
upstreamTimeout time.Duration
deactivate func()
reactivate func()
deactivate func(error)
reactivate func()
statusRecorder *peer.Status
}
func newUpstreamResolverBase(parentCTX context.Context) *upstreamResolverBase {
ctx, cancel := context.WithCancel(parentCTX)
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) *upstreamResolverBase {
ctx, cancel := context.WithCancel(ctx)
return &upstreamResolverBase{
ctx: ctx,
@@ -58,6 +62,7 @@ func newUpstreamResolverBase(parentCTX context.Context) *upstreamResolverBase {
upstreamTimeout: upstreamTimeout,
reactivatePeriod: reactivatePeriod,
failsTillDeact: failsTillDeact,
statusRecorder: statusRecorder,
}
}
@@ -68,7 +73,10 @@ func (u *upstreamResolverBase) stop() {
// ServeDNS handles a DNS request
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
defer u.checkUpstreamFails()
var err error
defer func() {
u.checkUpstreamFails(err)
}()
log.WithField("question", r.Question[0]).Trace("received an upstream question")
@@ -81,7 +89,6 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
for _, upstream := range u.upstreamServers {
var rm *dns.Msg
var t time.Duration
var err error
func() {
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
@@ -132,7 +139,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
// If fails count is greater that failsTillDeact, upstream resolving
// will be disabled for reactivatePeriod, after that time period fails counter
// will be reset and upstream will be reactivated.
func (u *upstreamResolverBase) checkUpstreamFails() {
func (u *upstreamResolverBase) checkUpstreamFails(err error) {
u.mutex.Lock()
defer u.mutex.Unlock()
@@ -146,7 +153,7 @@ func (u *upstreamResolverBase) checkUpstreamFails() {
default:
}
u.disable()
u.disable(err)
}
// probeAvailability tests all upstream servers simultaneously and
@@ -165,13 +172,16 @@ func (u *upstreamResolverBase) probeAvailability() {
var mu sync.Mutex
var wg sync.WaitGroup
var errors *multierror.Error
for _, upstream := range u.upstreamServers {
upstream := upstream
wg.Add(1)
go func() {
defer wg.Done()
if err := u.testNameserver(upstream); err != nil {
err := u.testNameserver(upstream)
if err != nil {
errors = multierror.Append(errors, err)
log.Warnf("probing upstream nameserver %s: %s", upstream, err)
return
}
@@ -186,7 +196,7 @@ func (u *upstreamResolverBase) probeAvailability() {
// didn't find a working upstream server, let's disable and try later
if !success {
u.disable()
u.disable(errors.ErrorOrNil())
}
}
@@ -245,15 +255,15 @@ func isTimeout(err error) bool {
return false
}
func (u *upstreamResolverBase) disable() {
func (u *upstreamResolverBase) disable(err error) {
if u.disabled {
return
}
// todo test the deactivation logic, it seems to affect the client
if runtime.GOOS != "ios" {
log.Warnf("upstream resolving is Disabled for %v", reactivatePeriod)
u.deactivate()
log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod)
u.deactivate(err)
u.disabled = true
go u.waitUntilResponse()
}

View File

@@ -11,6 +11,8 @@ import (
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/peer"
)
type upstreamResolverIOS struct {
@@ -20,8 +22,14 @@ type upstreamResolverIOS struct {
iIndex int
}
func newUpstreamResolver(parentCTX context.Context, interfaceName string, ip net.IP, net *net.IPNet) (*upstreamResolverIOS, error) {
upstreamResolverBase := newUpstreamResolverBase(parentCTX)
func newUpstreamResolver(
ctx context.Context,
interfaceName string,
ip net.IP,
net *net.IPNet,
statusRecorder *peer.Status,
) (*upstreamResolverIOS, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
index, err := getInterfaceIndex(interfaceName)
if err != nil {

View File

@@ -8,14 +8,22 @@ import (
"time"
"github.com/miekg/dns"
"github.com/netbirdio/netbird/client/internal/peer"
)
type upstreamResolverNonIOS struct {
*upstreamResolverBase
}
func newUpstreamResolver(parentCTX context.Context, interfaceName string, ip net.IP, net *net.IPNet) (*upstreamResolverNonIOS, error) {
upstreamResolverBase := newUpstreamResolverBase(parentCTX)
func newUpstreamResolver(
ctx context.Context,
_ string,
_ net.IP,
_ *net.IPNet,
statusRecorder *peer.Status,
) (*upstreamResolverNonIOS, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
nonIOS := &upstreamResolverNonIOS{
upstreamResolverBase: upstreamResolverBase,
}

View File

@@ -58,7 +58,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO())
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{})
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil)
resolver.upstreamServers = testCase.InputServers
resolver.upstreamTimeout = testCase.timeout
if testCase.cancelCTX {
@@ -131,7 +131,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
}
failed := false
resolver.deactivate = func() {
resolver.deactivate = func(error) {
failed = true
}

View File

@@ -1188,14 +1188,21 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
if err != nil {
return nil, nil, err
}
dnsServer := dns.NewDefaultServerPermanentUpstream(e.ctx, e.wgInterface, e.mobileDep.HostDNSAddresses, *dnsConfig, e.mobileDep.NetworkChangeListener)
dnsServer := dns.NewDefaultServerPermanentUpstream(
e.ctx,
e.wgInterface,
e.mobileDep.HostDNSAddresses,
*dnsConfig,
e.mobileDep.NetworkChangeListener,
e.statusRecorder,
)
go e.mobileDep.DnsReadyListener.OnReady()
return routes, dnsServer, nil
case "ios":
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager)
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder)
return nil, dnsServer, nil
default:
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress)
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder)
if err != nil {
return nil, nil, err
}

View File

@@ -29,6 +29,7 @@ type State struct {
BytesTx int64
BytesRx int64
RosenpassEnabled bool
Routes map[string]struct{}
}
// LocalPeerState contains the latest state of the local peer
@@ -37,6 +38,7 @@ type LocalPeerState struct {
PubKey string
KernelInterface bool
FQDN string
Routes map[string]struct{}
}
// SignalState contains the latest state of a signal connection
@@ -59,6 +61,16 @@ type RosenpassState struct {
Permissive bool
}
// NSGroupState represents the status of a DNS server group, including associated domains,
// whether it's enabled, and the last error message encountered during probing.
type NSGroupState struct {
ID string
Servers []string
Domains []string
Enabled bool
Error error
}
// FullStatus contains the full state held by the Status instance
type FullStatus struct {
Peers []State
@@ -67,6 +79,7 @@ type FullStatus struct {
LocalPeerState LocalPeerState
RosenpassState RosenpassState
Relays []relay.ProbeResult
NSGroupStates []NSGroupState
}
// Status holds a state of peers, signal, management connections and relays
@@ -86,6 +99,7 @@ type Status struct {
notifier *notifier
rosenpassEnabled bool
rosenpassPermissive bool
nsGroupStates []NSGroupState
// To reduce the number of notification invocation this bool will be true when need to call the notification
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
@@ -174,6 +188,10 @@ func (d *Status) UpdatePeerState(receivedState State) error {
peerState.IP = receivedState.IP
}
if receivedState.Routes != nil {
peerState.Routes = receivedState.Routes
}
skipNotification := shouldSkipNotify(receivedState, peerState)
if receivedState.ConnStatus != peerState.ConnStatus {
@@ -278,6 +296,13 @@ func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
return ch
}
// GetLocalPeerState returns the local peer state
func (d *Status) GetLocalPeerState() LocalPeerState {
d.mux.Lock()
defer d.mux.Unlock()
return d.localPeer
}
// UpdateLocalPeerState updates local peer status
func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
d.mux.Lock()
@@ -364,6 +389,12 @@ func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) {
d.relayStates = relayResults
}
func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) {
d.mux.Lock()
defer d.mux.Unlock()
d.nsGroupStates = dnsStates
}
func (d *Status) GetRosenpassState() RosenpassState {
return RosenpassState{
d.rosenpassEnabled,
@@ -409,6 +440,10 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
return d.relayStates
}
func (d *Status) GetDNSStates() []NSGroupState {
return d.nsGroupStates
}
// GetFullStatus gets full status
func (d *Status) GetFullStatus() FullStatus {
d.mux.Lock()
@@ -420,6 +455,7 @@ func (d *Status) GetFullStatus() FullStatus {
LocalPeerState: d.localPeer,
Relays: d.GetRelayStates(),
RosenpassState: d.GetRosenpassState(),
NSGroupStates: d.GetDNSStates(),
}
for _, status := range d.peers {

View File

@@ -160,6 +160,12 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
if err != nil {
return err
}
delete(state.Routes, c.network.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
if state.ConnStatus != peer.StatusConnected {
return nil
}
@@ -225,6 +231,20 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
}
c.chosenRoute = c.routes[chosen]
state, err := c.statusRecorder.GetPeer(c.chosenRoute.Peer)
if err != nil {
log.Errorf("Failed to get peer state: %v", err)
} else {
if state.Routes == nil {
state.Routes = map[string]struct{}{}
}
state.Routes[c.network.String()] = struct{}{}
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
}
err = c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String())
if err != nil {
log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v",

View File

@@ -58,7 +58,7 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
var err error
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall)
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
if err != nil {
return err
}

View File

@@ -7,9 +7,10 @@ import (
"fmt"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
)
func newServerRouter(context.Context, *iface.WGIface, firewall.Manager) (serverRouter, error) {
func newServerRouter(context.Context, *iface.WGIface, firewall.Manager, *peer.Status) (serverRouter, error) {
return nil, fmt.Errorf("server route not supported on this os")
}

View File

@@ -10,24 +10,27 @@ import (
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
)
type defaultServerRouter struct {
mux sync.Mutex
ctx context.Context
routes map[string]*route.Route
firewall firewall.Manager
wgInterface *iface.WGIface
mux sync.Mutex
ctx context.Context
routes map[string]*route.Route
firewall firewall.Manager
wgInterface *iface.WGIface
statusRecorder *peer.Status
}
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface, firewall firewall.Manager) (serverRouter, error) {
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) {
return &defaultServerRouter{
ctx: ctx,
routes: make(map[string]*route.Route),
firewall: firewall,
wgInterface: wgInterface,
ctx: ctx,
routes: make(map[string]*route.Route),
firewall: firewall,
wgInterface: wgInterface,
statusRecorder: statusRecorder,
}, nil
}
@@ -88,6 +91,11 @@ func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error
return err
}
delete(m.routes, route.ID)
state := m.statusRecorder.GetLocalPeerState()
delete(state.Routes, route.Network.String())
m.statusRecorder.UpdateLocalPeerState(state)
return nil
}
}
@@ -105,6 +113,14 @@ func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
return err
}
m.routes[route.ID] = route
state := m.statusRecorder.GetLocalPeerState()
if state.Routes == nil {
state.Routes = map[string]struct{}{}
}
state.Routes[route.Network.String()] = struct{}{}
m.statusRecorder.UpdateLocalPeerState(state)
return nil
}
}
@@ -117,6 +133,10 @@ func (m *defaultServerRouter) cleanUp() {
if err != nil {
log.Warnf("failed to remove clean up route: %s", r.ID)
}
state := m.statusRecorder.GetLocalPeerState()
state.Routes = nil
m.statusRecorder.UpdateLocalPeerState(state)
}
}