mirror of
https://github.com/netbirdio/netbird.git
synced 2025-03-04 09:51:16 +01:00
Initial approach
This commit is contained in:
parent
e40a29ba17
commit
13537e6640
@ -33,7 +33,8 @@ type Server interface {
|
|||||||
Initialize() error
|
Initialize() error
|
||||||
Stop()
|
Stop()
|
||||||
DnsIP() string
|
DnsIP() string
|
||||||
UpdateDNSServer(serial uint64, update nbdns.Config) error
|
DnsPort() int
|
||||||
|
UpdateDNSServer(update nbdns.Config, hasDNSRoute bool) error
|
||||||
OnUpdatedHostDNSServer(strings []string)
|
OnUpdatedHostDNSServer(strings []string)
|
||||||
SearchDomains() []string
|
SearchDomains() []string
|
||||||
ProbeAvailability()
|
ProbeAvailability()
|
||||||
@ -51,7 +52,6 @@ type DefaultServer struct {
|
|||||||
localResolver *localResolver
|
localResolver *localResolver
|
||||||
wgInterface WGIface
|
wgInterface WGIface
|
||||||
hostManager hostManager
|
hostManager hostManager
|
||||||
updateSerial uint64
|
|
||||||
previousConfigHash uint64
|
previousConfigHash uint64
|
||||||
currentConfig HostDNSConfig
|
currentConfig HostDNSConfig
|
||||||
|
|
||||||
@ -183,6 +183,11 @@ func (s *DefaultServer) DnsIP() string {
|
|||||||
return s.service.RuntimeIP()
|
return s.service.RuntimeIP()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) DnsPort() int {
|
||||||
|
// Todo: review what will be if the service is not running yet
|
||||||
|
return s.service.RuntimePort()
|
||||||
|
}
|
||||||
|
|
||||||
// Stop stops the server
|
// Stop stops the server
|
||||||
func (s *DefaultServer) Stop() {
|
func (s *DefaultServer) Stop() {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
@ -215,16 +220,12 @@ func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdateDNSServer processes an update received from the management service
|
// UpdateDNSServer processes an update received from the management service
|
||||||
func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
func (s *DefaultServer) UpdateDNSServer(update nbdns.Config, hasDNSRoute bool) error {
|
||||||
select {
|
select {
|
||||||
case <-s.ctx.Done():
|
case <-s.ctx.Done():
|
||||||
log.Infof("not updating DNS server as context is closed")
|
log.Infof("not updating DNS server as context is closed")
|
||||||
return s.ctx.Err()
|
return s.ctx.Err()
|
||||||
default:
|
default:
|
||||||
if serial < s.updateSerial {
|
|
||||||
return fmt.Errorf("not applying dns update, error: "+
|
|
||||||
"network update is %d behind the last applied update", s.updateSerial-serial)
|
|
||||||
}
|
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
@ -244,17 +245,14 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
|
|||||||
|
|
||||||
if s.previousConfigHash == hash {
|
if s.previousConfigHash == hash {
|
||||||
log.Debugf("not applying the dns configuration update as there is nothing new")
|
log.Debugf("not applying the dns configuration update as there is nothing new")
|
||||||
s.updateSerial = serial
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.applyConfiguration(update); err != nil {
|
if err := s.applyConfiguration(update, hasDNSRoute); err != nil {
|
||||||
return fmt.Errorf("apply configuration: %w", err)
|
return fmt.Errorf("apply configuration: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.updateSerial = serial
|
|
||||||
s.previousConfigHash = hash
|
s.previousConfigHash = hash
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -288,15 +286,18 @@ func (s *DefaultServer) ProbeAvailability() {
|
|||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
func (s *DefaultServer) applyConfiguration(update nbdns.Config, hasDNSRoute bool) error {
|
||||||
// is the service should be Disabled, we stop the listener or fake resolver
|
// is the service should be Disabled, we stop the listener or fake resolver
|
||||||
// and proceed with a regular update to clean up the handlers and records
|
// and proceed with a regular update to clean up the handlers and records
|
||||||
if update.ServiceEnable {
|
if update.ServiceEnable || hasDNSRoute {
|
||||||
_ = s.service.Listen()
|
_ = s.service.Listen()
|
||||||
} else if !s.permanent {
|
} else if !s.permanent {
|
||||||
s.service.Stop()
|
s.service.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// trace the dns configuration update
|
||||||
|
log.Infof("---- dns server listen address: %s:%d", s.service.RuntimeIP(), s.service.RuntimePort())
|
||||||
|
|
||||||
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||||
|
@ -44,7 +44,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
mgm "github.com/netbirdio/netbird/management/client"
|
mgm "github.com/netbirdio/netbird/management/client"
|
||||||
"github.com/netbirdio/netbird/management/domain"
|
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
||||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||||
@ -802,14 +801,14 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
e.acl.ApplyFiltering(networkMap)
|
e.acl.ApplyFiltering(networkMap)
|
||||||
}
|
}
|
||||||
|
|
||||||
protoRoutes := networkMap.GetRoutes()
|
// todo keep the state because of the serial or eliminate the serial usage from dns and route mgr
|
||||||
if protoRoutes == nil {
|
networkMapMgr := networkMapHandler{
|
||||||
protoRoutes = []*mgmProto.Route{}
|
DNSServer: e.dnsServer,
|
||||||
|
RouteManager: e.routeManager,
|
||||||
}
|
}
|
||||||
|
if err := networkMapMgr.update(serial, networkMap); err != nil {
|
||||||
_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
|
log.Warnf("failed to update apply network map: %v", err)
|
||||||
if err != nil {
|
// todo: consider to return here with error
|
||||||
log.Errorf("failed to update clientRoutes, err: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
e.clientRoutesMu.Lock()
|
e.clientRoutesMu.Lock()
|
||||||
@ -858,16 +857,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
protoDNSConfig := networkMap.GetDNSConfig()
|
|
||||||
if protoDNSConfig == nil {
|
|
||||||
protoDNSConfig = &mgmProto.DNSConfig{}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig))
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to update dns server, err: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
e.networkSerial = serial
|
e.networkSerial = serial
|
||||||
|
|
||||||
// Test received (upstream) servers for availability right away instead of upon usage.
|
// Test received (upstream) servers for availability right away instead of upon usage.
|
||||||
@ -877,76 +866,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
|
|
||||||
routes := make([]*route.Route, 0)
|
|
||||||
for _, protoRoute := range protoRoutes {
|
|
||||||
var prefix netip.Prefix
|
|
||||||
if len(protoRoute.Domains) == 0 {
|
|
||||||
var err error
|
|
||||||
if prefix, err = netip.ParsePrefix(protoRoute.Network); err != nil {
|
|
||||||
log.Errorf("Failed to parse prefix %s: %v", protoRoute.Network, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
convertedRoute := &route.Route{
|
|
||||||
ID: route.ID(protoRoute.ID),
|
|
||||||
Network: prefix,
|
|
||||||
Domains: domain.FromPunycodeList(protoRoute.Domains),
|
|
||||||
NetID: route.NetID(protoRoute.NetID),
|
|
||||||
NetworkType: route.NetworkType(protoRoute.NetworkType),
|
|
||||||
Peer: protoRoute.Peer,
|
|
||||||
Metric: int(protoRoute.Metric),
|
|
||||||
Masquerade: protoRoute.Masquerade,
|
|
||||||
KeepRoute: protoRoute.KeepRoute,
|
|
||||||
}
|
|
||||||
routes = append(routes, convertedRoute)
|
|
||||||
}
|
|
||||||
return routes
|
|
||||||
}
|
|
||||||
|
|
||||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
|
|
||||||
dnsUpdate := nbdns.Config{
|
|
||||||
ServiceEnable: protoDNSConfig.GetServiceEnable(),
|
|
||||||
CustomZones: make([]nbdns.CustomZone, 0),
|
|
||||||
NameServerGroups: make([]*nbdns.NameServerGroup, 0),
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, zone := range protoDNSConfig.GetCustomZones() {
|
|
||||||
dnsZone := nbdns.CustomZone{
|
|
||||||
Domain: zone.GetDomain(),
|
|
||||||
}
|
|
||||||
for _, record := range zone.Records {
|
|
||||||
dnsRecord := nbdns.SimpleRecord{
|
|
||||||
Name: record.GetName(),
|
|
||||||
Type: int(record.GetType()),
|
|
||||||
Class: record.GetClass(),
|
|
||||||
TTL: int(record.GetTTL()),
|
|
||||||
RData: record.GetRData(),
|
|
||||||
}
|
|
||||||
dnsZone.Records = append(dnsZone.Records, dnsRecord)
|
|
||||||
}
|
|
||||||
dnsUpdate.CustomZones = append(dnsUpdate.CustomZones, dnsZone)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, nsGroup := range protoDNSConfig.GetNameServerGroups() {
|
|
||||||
dnsNSGroup := &nbdns.NameServerGroup{
|
|
||||||
Primary: nsGroup.GetPrimary(),
|
|
||||||
Domains: nsGroup.GetDomains(),
|
|
||||||
SearchDomainsEnabled: nsGroup.GetSearchDomainsEnabled(),
|
|
||||||
}
|
|
||||||
for _, ns := range nsGroup.GetNameServers() {
|
|
||||||
dnsNS := nbdns.NameServer{
|
|
||||||
IP: netip.MustParseAddr(ns.GetIP()),
|
|
||||||
NSType: nbdns.NameServerType(ns.GetNSType()),
|
|
||||||
Port: int(ns.GetPort()),
|
|
||||||
}
|
|
||||||
dnsNSGroup.NameServers = append(dnsNSGroup.NameServers, dnsNS)
|
|
||||||
}
|
|
||||||
dnsUpdate.NameServerGroups = append(dnsUpdate.NameServerGroups, dnsNSGroup)
|
|
||||||
}
|
|
||||||
return dnsUpdate
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) {
|
func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) {
|
||||||
replacement := make([]peer.State, len(offlinePeers))
|
replacement := make([]peer.State, len(offlinePeers))
|
||||||
for i, offlinePeer := range offlinePeers {
|
for i, offlinePeer := range offlinePeers {
|
||||||
@ -1235,7 +1154,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
routes := toRoutes(netMap.GetRoutes())
|
_, routes := toRoutes(netMap.GetRoutes())
|
||||||
dnsCfg := toDNSConfig(netMap.GetDNSConfig())
|
dnsCfg := toDNSConfig(netMap.GetDNSConfig())
|
||||||
return routes, &dnsCfg, nil
|
return routes, &dnsCfg, nil
|
||||||
}
|
}
|
||||||
|
@ -509,7 +509,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
|||||||
expectedSerial uint64
|
expectedSerial uint64
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Routes Config Should Be Passed To Manager",
|
name: "Routes Config Should Be Passed To networkMapHandler",
|
||||||
networkMap: &mgmtProto.NetworkMap{
|
networkMap: &mgmtProto.NetworkMap{
|
||||||
Serial: 1,
|
Serial: 1,
|
||||||
PeerConfig: nil,
|
PeerConfig: nil,
|
||||||
|
171
client/internal/netmap_handler.go
Normal file
171
client/internal/netmap_handler.go
Normal file
@ -0,0 +1,171 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type networkMapHandler struct {
|
||||||
|
DNSServer dns.Server
|
||||||
|
RouteManager routemanager.Manager
|
||||||
|
Firewall firewall.Manager
|
||||||
|
|
||||||
|
updateSerial uint64
|
||||||
|
dnsRules []firewall.Rule
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *networkMapHandler) update(serial uint64, networkMap *mgmProto.NetworkMap) error {
|
||||||
|
if serial < h.updateSerial {
|
||||||
|
return fmt.Errorf("not applying dns update, error: "+
|
||||||
|
"network update is %d behind the last applied update", h.updateSerial-serial)
|
||||||
|
}
|
||||||
|
|
||||||
|
hasDNSRoute, routes := toRoutes(networkMap.GetRoutes())
|
||||||
|
DNSConfig := toDNSConfig(networkMap.GetDNSConfig())
|
||||||
|
|
||||||
|
if err := h.DNSServer.UpdateDNSServer(DNSConfig, hasDNSRoute); err != nil {
|
||||||
|
log.Errorf("failed to update dns server, err: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
h.updateSerial = serial
|
||||||
|
|
||||||
|
// todo: consider to eliminate the serial management from the client.go
|
||||||
|
_, err := h.RouteManager.UpdateRoutes(serial, routes)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to update routes, err: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasDNSRoute {
|
||||||
|
if err := h.allowDNSFirewall(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := h.dropDNSFirewall(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *networkMapHandler) allowDNSFirewall() error {
|
||||||
|
dport := &firewall.Port{
|
||||||
|
IsRange: false,
|
||||||
|
Values: []int{h.DNSServer.DnsPort()},
|
||||||
|
}
|
||||||
|
dnsRules, err := h.Firewall.AddPeerFiltering(net.ParseIP("0.0.0.0"), firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "")
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to add allow DNS router rules, err: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
h.dnsRules = dnsRules
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *networkMapHandler) dropDNSFirewall() error {
|
||||||
|
if len(h.dnsRules) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rule := range h.dnsRules {
|
||||||
|
if err := h.Firewall.DeletePeerRule(rule); err != nil {
|
||||||
|
log.Errorf("failed to delete DNS router rules, err: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
h.dnsRules = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
|
||||||
|
if protoDNSConfig == nil {
|
||||||
|
protoDNSConfig = &mgmProto.DNSConfig{}
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsUpdate := nbdns.Config{
|
||||||
|
ServiceEnable: protoDNSConfig.GetServiceEnable(),
|
||||||
|
CustomZones: make([]nbdns.CustomZone, 0),
|
||||||
|
NameServerGroups: make([]*nbdns.NameServerGroup, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, zone := range protoDNSConfig.GetCustomZones() {
|
||||||
|
dnsZone := nbdns.CustomZone{
|
||||||
|
Domain: zone.GetDomain(),
|
||||||
|
}
|
||||||
|
for _, record := range zone.Records {
|
||||||
|
dnsRecord := nbdns.SimpleRecord{
|
||||||
|
Name: record.GetName(),
|
||||||
|
Type: int(record.GetType()),
|
||||||
|
Class: record.GetClass(),
|
||||||
|
TTL: int(record.GetTTL()),
|
||||||
|
RData: record.GetRData(),
|
||||||
|
}
|
||||||
|
dnsZone.Records = append(dnsZone.Records, dnsRecord)
|
||||||
|
}
|
||||||
|
dnsUpdate.CustomZones = append(dnsUpdate.CustomZones, dnsZone)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, nsGroup := range protoDNSConfig.GetNameServerGroups() {
|
||||||
|
dnsNSGroup := &nbdns.NameServerGroup{
|
||||||
|
Primary: nsGroup.GetPrimary(),
|
||||||
|
Domains: nsGroup.GetDomains(),
|
||||||
|
SearchDomainsEnabled: nsGroup.GetSearchDomainsEnabled(),
|
||||||
|
}
|
||||||
|
for _, ns := range nsGroup.GetNameServers() {
|
||||||
|
dnsNS := nbdns.NameServer{
|
||||||
|
IP: netip.MustParseAddr(ns.GetIP()),
|
||||||
|
NSType: nbdns.NameServerType(ns.GetNSType()),
|
||||||
|
Port: int(ns.GetPort()),
|
||||||
|
}
|
||||||
|
dnsNSGroup.NameServers = append(dnsNSGroup.NameServers, dnsNS)
|
||||||
|
}
|
||||||
|
dnsUpdate.NameServerGroups = append(dnsUpdate.NameServerGroups, dnsNSGroup)
|
||||||
|
}
|
||||||
|
return dnsUpdate
|
||||||
|
}
|
||||||
|
|
||||||
|
func toRoutes(protoRoutes []*mgmProto.Route) (bool, []*route.Route) {
|
||||||
|
if protoRoutes == nil {
|
||||||
|
protoRoutes = []*mgmProto.Route{}
|
||||||
|
}
|
||||||
|
var hasDNSRoute bool
|
||||||
|
routes := make([]*route.Route, 0)
|
||||||
|
for _, protoRoute := range protoRoutes {
|
||||||
|
var prefix netip.Prefix
|
||||||
|
if len(protoRoute.Domains) == 0 {
|
||||||
|
var err error
|
||||||
|
if prefix, err = netip.ParsePrefix(protoRoute.Network); err != nil {
|
||||||
|
log.Errorf("Failed to parse prefix %s: %v", protoRoute.Network, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
hasDNSRoute = true
|
||||||
|
|
||||||
|
convertedRoute := &route.Route{
|
||||||
|
ID: route.ID(protoRoute.ID),
|
||||||
|
Network: prefix,
|
||||||
|
Domains: domain.FromPunycodeList(protoRoute.Domains),
|
||||||
|
NetID: route.NetID(protoRoute.NetID),
|
||||||
|
NetworkType: route.NetworkType(protoRoute.NetworkType),
|
||||||
|
Peer: protoRoute.Peer,
|
||||||
|
Metric: int(protoRoute.Metric),
|
||||||
|
Masquerade: protoRoute.Masquerade,
|
||||||
|
KeepRoute: protoRoute.KeepRoute,
|
||||||
|
}
|
||||||
|
routes = append(routes, convertedRoute)
|
||||||
|
}
|
||||||
|
return hasDNSRoute, routes
|
||||||
|
}
|
@ -33,7 +33,7 @@ import (
|
|||||||
// Manager is a route manager interface
|
// Manager is a route manager interface
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
|
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
|
||||||
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
|
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (route.HAMap, error)
|
||||||
TriggerSelection(route.HAMap)
|
TriggerSelection(route.HAMap)
|
||||||
GetRouteSelector() *routeselector.RouteSelector
|
GetRouteSelector() *routeselector.RouteSelector
|
||||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||||
@ -60,6 +60,7 @@ type DefaultManager struct {
|
|||||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
|
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
|
||||||
dnsRouteInterval time.Duration
|
dnsRouteInterval time.Duration
|
||||||
stateManager *statemanager.Manager
|
stateManager *statemanager.Manager
|
||||||
|
dnsRule []firewall.Rule // todo: remove rule in stop action
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManager(
|
func NewManager(
|
||||||
@ -210,11 +211,11 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
|
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
|
||||||
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
|
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (route.HAMap, error) {
|
||||||
select {
|
select {
|
||||||
case <-m.ctx.Done():
|
case <-m.ctx.Done():
|
||||||
log.Infof("not updating routes as context is closed")
|
log.Infof("not updating routes as context is closed")
|
||||||
return nil, nil, m.ctx.Err()
|
return nil, m.ctx.Err()
|
||||||
default:
|
default:
|
||||||
m.mux.Lock()
|
m.mux.Lock()
|
||||||
defer m.mux.Unlock()
|
defer m.mux.Unlock()
|
||||||
@ -226,13 +227,12 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
|
|||||||
m.notifier.OnNewRoutes(filteredClientRoutes)
|
m.notifier.OnNewRoutes(filteredClientRoutes)
|
||||||
|
|
||||||
if m.serverRouter != nil {
|
if m.serverRouter != nil {
|
||||||
err := m.serverRouter.updateRoutes(newServerRoutesMap)
|
if err := m.serverRouter.updateRoutes(newServerRoutesMap); err != nil {
|
||||||
if err != nil {
|
return nil, fmt.Errorf("update routes: %w", err)
|
||||||
return nil, nil, fmt.Errorf("update routes: %w", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return newServerRoutesMap, newClientRoutesIDMap, nil
|
return newClientRoutesIDMap, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -436,11 +436,11 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(testCase.inputInitRoutes) > 0 {
|
if len(testCase.inputInitRoutes) > 0 {
|
||||||
_, _, err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
|
_, err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
|
||||||
require.NoError(t, err, "should update routes with init routes")
|
require.NoError(t, err, "should update routes with init routes")
|
||||||
}
|
}
|
||||||
|
|
||||||
_, _, err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
|
_, err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
|
||||||
require.NoError(t, err, "should update routes")
|
require.NoError(t, err, "should update routes")
|
||||||
|
|
||||||
expectedWatchers := testCase.clientNetworkWatchersExpected
|
expectedWatchers := testCase.clientNetworkWatchersExpected
|
||||||
|
Loading…
Reference in New Issue
Block a user