mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-24 23:59:25 +01:00
337d3edcc4
The ConnStatus is a custom type based on iota like an enum. The problem was nowhere used to the benefits of this implementation. All ConnStatus instances has been compared with strings. I suppose the reason to do it to avoid a circle dependency. In this commit the separated status package has been moved to peer package. Remove unused, exported functions from engine
283 lines
7.1 KiB
Go
283 lines
7.1 KiB
Go
package routemanager
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/netip"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
|
|
"github.com/netbirdio/netbird/client/internal/peer"
|
|
"github.com/netbirdio/netbird/iface"
|
|
"github.com/netbirdio/netbird/route"
|
|
)
|
|
|
|
type routerPeerStatus struct {
|
|
connected bool
|
|
relayed bool
|
|
direct bool
|
|
}
|
|
|
|
type routesUpdate struct {
|
|
updateSerial uint64
|
|
routes []*route.Route
|
|
}
|
|
|
|
type clientNetwork struct {
|
|
ctx context.Context
|
|
stop context.CancelFunc
|
|
statusRecorder *peer.Status
|
|
wgInterface *iface.WGIface
|
|
routes map[string]*route.Route
|
|
routeUpdate chan routesUpdate
|
|
peerStateUpdate chan struct{}
|
|
routePeersNotifiers map[string]chan struct{}
|
|
chosenRoute *route.Route
|
|
network netip.Prefix
|
|
updateSerial uint64
|
|
}
|
|
|
|
func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *peer.Status, network netip.Prefix) *clientNetwork {
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
client := &clientNetwork{
|
|
ctx: ctx,
|
|
stop: cancel,
|
|
statusRecorder: statusRecorder,
|
|
wgInterface: wgInterface,
|
|
routes: make(map[string]*route.Route),
|
|
routePeersNotifiers: make(map[string]chan struct{}),
|
|
routeUpdate: make(chan routesUpdate),
|
|
peerStateUpdate: make(chan struct{}),
|
|
network: network,
|
|
}
|
|
return client
|
|
}
|
|
|
|
func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus {
|
|
routePeerStatuses := make(map[string]routerPeerStatus)
|
|
for _, r := range c.routes {
|
|
peerStatus, err := c.statusRecorder.GetPeer(r.Peer)
|
|
if err != nil {
|
|
log.Debugf("couldn't fetch peer state: %v", err)
|
|
continue
|
|
}
|
|
routePeerStatuses[r.ID] = routerPeerStatus{
|
|
connected: peerStatus.ConnStatus == peer.StatusConnected,
|
|
relayed: peerStatus.Relayed,
|
|
direct: peerStatus.Direct,
|
|
}
|
|
}
|
|
return routePeerStatuses
|
|
}
|
|
|
|
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string {
|
|
var chosen string
|
|
chosenScore := 0
|
|
|
|
currID := ""
|
|
if c.chosenRoute != nil {
|
|
currID = c.chosenRoute.ID
|
|
}
|
|
|
|
for _, r := range c.routes {
|
|
tempScore := 0
|
|
peerStatus, found := routePeerStatuses[r.ID]
|
|
if !found || !peerStatus.connected {
|
|
continue
|
|
}
|
|
if r.Metric < route.MaxMetric {
|
|
metricDiff := route.MaxMetric - r.Metric
|
|
tempScore = metricDiff * 10
|
|
}
|
|
if !peerStatus.relayed {
|
|
tempScore++
|
|
}
|
|
if !peerStatus.direct {
|
|
tempScore++
|
|
}
|
|
if tempScore > chosenScore || (tempScore == chosenScore && currID == r.ID) {
|
|
chosen = r.ID
|
|
chosenScore = tempScore
|
|
}
|
|
}
|
|
|
|
if chosen == "" {
|
|
var peers []string
|
|
for _, r := range c.routes {
|
|
peers = append(peers, r.Peer)
|
|
}
|
|
log.Warnf("no route was chosen for network %s because no peers from list %s were connected", c.network, peers)
|
|
} else if chosen != currID {
|
|
log.Infof("new chosen route is %s with peer %s with score %d", chosen, c.routes[chosen].Peer, chosenScore)
|
|
}
|
|
|
|
return chosen
|
|
}
|
|
|
|
func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey string, peerStateUpdate chan struct{}, closer chan struct{}) {
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-closer:
|
|
return
|
|
case <-c.statusRecorder.GetPeerStateChangeNotifier(peerKey):
|
|
state, err := c.statusRecorder.GetPeer(peerKey)
|
|
if err != nil || state.ConnStatus == peer.StatusConnecting {
|
|
continue
|
|
}
|
|
peerStateUpdate <- struct{}{}
|
|
log.Debugf("triggered route state update for Peer %s, state: %s", peerKey, state.ConnStatus)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *clientNetwork) startPeersStatusChangeWatcher() {
|
|
for _, r := range c.routes {
|
|
_, found := c.routePeersNotifiers[r.Peer]
|
|
if !found {
|
|
c.routePeersNotifiers[r.Peer] = make(chan struct{})
|
|
go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, c.routePeersNotifiers[r.Peer])
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
|
|
state, err := c.statusRecorder.GetPeer(peerKey)
|
|
if err != nil || state.ConnStatus != peer.StatusConnected {
|
|
return nil
|
|
}
|
|
|
|
err = c.wgInterface.RemoveAllowedIP(peerKey, c.network.String())
|
|
if err != nil {
|
|
return fmt.Errorf("couldn't remove allowed IP %s removed for peer %s, err: %v",
|
|
c.network, c.chosenRoute.Peer, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
|
|
if c.chosenRoute != nil {
|
|
err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = removeFromRouteTableIfNonSystem(c.network, c.wgInterface.Address().IP.String())
|
|
if err != nil {
|
|
return fmt.Errorf("couldn't remove route %s from system, err: %v",
|
|
c.network, err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
|
|
|
var err error
|
|
|
|
routerPeerStatuses := c.getRouterPeerStatuses()
|
|
|
|
chosen := c.getBestRouteFromStatuses(routerPeerStatuses)
|
|
if chosen == "" {
|
|
err = c.removeRouteFromPeerAndSystem()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
c.chosenRoute = nil
|
|
|
|
return nil
|
|
}
|
|
|
|
if c.chosenRoute != nil && c.chosenRoute.ID == chosen {
|
|
if c.chosenRoute.IsEqual(c.routes[chosen]) {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
if c.chosenRoute != nil {
|
|
err = c.removeRouteFromWireguardPeer(c.chosenRoute.Peer)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
err = addToRouteTableIfNoExists(c.network, c.wgInterface.Address().IP.String())
|
|
if err != nil {
|
|
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
|
|
c.network.String(), c.wgInterface.Address().IP.String(), err)
|
|
}
|
|
}
|
|
|
|
c.chosenRoute = c.routes[chosen]
|
|
err = c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String())
|
|
if err != nil {
|
|
log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v",
|
|
c.network, c.chosenRoute.Peer, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
|
|
go func() {
|
|
c.routeUpdate <- update
|
|
}()
|
|
}
|
|
|
|
func (c *clientNetwork) handleUpdate(update routesUpdate) {
|
|
updateMap := make(map[string]*route.Route)
|
|
|
|
for _, r := range update.routes {
|
|
updateMap[r.ID] = r
|
|
}
|
|
|
|
for id, r := range c.routes {
|
|
_, found := updateMap[id]
|
|
if !found {
|
|
close(c.routePeersNotifiers[r.Peer])
|
|
delete(c.routePeersNotifiers, r.Peer)
|
|
}
|
|
}
|
|
|
|
c.routes = updateMap
|
|
}
|
|
|
|
// peersStateAndUpdateWatcher is the main point of reacting on client network routing events.
|
|
// All the processing related to the client network should be done here. Thread-safe.
|
|
func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
|
for {
|
|
select {
|
|
case <-c.ctx.Done():
|
|
log.Debugf("stopping watcher for network %s", c.network)
|
|
err := c.removeRouteFromPeerAndSystem()
|
|
if err != nil {
|
|
log.Error(err)
|
|
}
|
|
return
|
|
case <-c.peerStateUpdate:
|
|
err := c.recalculateRouteAndUpdatePeerAndSystem()
|
|
if err != nil {
|
|
log.Error(err)
|
|
}
|
|
case update := <-c.routeUpdate:
|
|
if update.updateSerial < c.updateSerial {
|
|
log.Warnf("received a routes update with smaller serial number, ignoring it")
|
|
continue
|
|
}
|
|
|
|
log.Debugf("received a new client network route update for %s", c.network)
|
|
|
|
c.handleUpdate(update)
|
|
|
|
c.updateSerial = update.updateSerial
|
|
|
|
err := c.recalculateRouteAndUpdatePeerAndSystem()
|
|
if err != nil {
|
|
log.Error(err)
|
|
}
|
|
|
|
c.startPeersStatusChangeWatcher()
|
|
}
|
|
}
|
|
}
|