Ignore candidates whose IP falls into a routed network. (#2084)

This will prevent peer connections via other peers.
This commit is contained in:
Viktor Liu 2024-06-03 17:31:37 +02:00 committed by GitHub
parent 456629811b
commit 9b3449753e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 53 additions and 6 deletions

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"maps"
"math/rand" "math/rand"
"net" "net"
"net/netip" "net/netip"
@ -118,6 +119,7 @@ type Engine struct {
// clientRoutes is the most recent list of clientRoutes received from the Management Service // clientRoutes is the most recent list of clientRoutes received from the Management Service
clientRoutes route.HAMap clientRoutes route.HAMap
clientRoutesMu sync.RWMutex
clientCtx context.Context clientCtx context.Context
clientCancel context.CancelFunc clientCancel context.CancelFunc
@ -240,7 +242,9 @@ func (e *Engine) Stop() error {
return err return err
} }
e.clientRoutesMu.Lock()
e.clientRoutes = nil e.clientRoutes = nil
e.clientRoutesMu.Unlock()
// very ugly but we want to remove peers from the WireGuard interface first before removing interface. // very ugly but we want to remove peers from the WireGuard interface first before removing interface.
// Removing peers happens in the conn.Close() asynchronously // Removing peers happens in the conn.Close() asynchronously
@ -738,7 +742,9 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
log.Errorf("failed to update clientRoutes, err: %v", err) log.Errorf("failed to update clientRoutes, err: %v", err)
} }
e.clientRoutesMu.Lock()
e.clientRoutes = clientRoutes e.clientRoutes = clientRoutes
e.clientRoutesMu.Unlock()
protoDNSConfig := networkMap.GetDNSConfig() protoDNSConfig := networkMap.GetDNSConfig()
if protoDNSConfig == nil { if protoDNSConfig == nil {
@ -1088,7 +1094,8 @@ func (e *Engine) receiveSignalEvents() {
log.Errorf("failed on parsing remote candidate %s -> %s", candidate, err) log.Errorf("failed on parsing remote candidate %s -> %s", candidate, err)
return err return err
} }
conn.OnRemoteCandidate(candidate)
conn.OnRemoteCandidate(candidate, e.GetClientRoutes())
case sProto.Body_MODE: case sProto.Body_MODE:
} }
@ -1282,11 +1289,17 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
// GetClientRoutes returns the current routes from the route map // GetClientRoutes returns the current routes from the route map
func (e *Engine) GetClientRoutes() route.HAMap { func (e *Engine) GetClientRoutes() route.HAMap {
return e.clientRoutes e.clientRoutesMu.RLock()
defer e.clientRoutesMu.RUnlock()
return maps.Clone(e.clientRoutes)
} }
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only // GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
func (e *Engine) GetClientRoutesWithNetID() map[route.NetID][]*route.Route { func (e *Engine) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
e.clientRoutesMu.RLock()
defer e.clientRoutesMu.RUnlock()
routes := make(map[route.NetID][]*route.Route, len(e.clientRoutes)) routes := make(map[route.NetID][]*route.Route, len(e.clientRoutes))
for id, v := range e.clientRoutes { for id, v := range e.clientRoutes {
routes[id.NetID()] = v routes[id.NetID()] = v

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/netip"
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
@ -18,6 +19,7 @@ import (
"github.com/netbirdio/netbird/client/internal/wgproxy" "github.com/netbirdio/netbird/client/internal/wgproxy"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/iface/bind" "github.com/netbirdio/netbird/iface/bind"
"github.com/netbirdio/netbird/route"
signal "github.com/netbirdio/netbird/signal/client" signal "github.com/netbirdio/netbird/signal/client"
sProto "github.com/netbirdio/netbird/signal/proto" sProto "github.com/netbirdio/netbird/signal/proto"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
@ -353,7 +355,7 @@ func (conn *Conn) Open(ctx context.Context) error {
err = conn.agent.GatherCandidates() err = conn.agent.GatherCandidates()
if err != nil { if err != nil {
return err return fmt.Errorf("gather candidates: %v", err)
} }
// will block until connection succeeded // will block until connection succeeded
@ -370,7 +372,7 @@ func (conn *Conn) Open(ctx context.Context) error {
return err return err
} }
// dynamically set remote WireGuard port is other side specified a different one from the default one // dynamically set remote WireGuard port if other side specified a different one from the default one
remoteWgPort := iface.DefaultWgPort remoteWgPort := iface.DefaultWgPort
if remoteOfferAnswer.WgListenPort != 0 { if remoteOfferAnswer.WgListenPort != 0 {
remoteWgPort = remoteOfferAnswer.WgListenPort remoteWgPort = remoteOfferAnswer.WgListenPort
@ -779,7 +781,7 @@ func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool {
} }
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer. // OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate) { func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
log.Debugf("OnRemoteCandidate from peer %s -> %s", conn.config.Key, candidate.String()) log.Debugf("OnRemoteCandidate from peer %s -> %s", conn.config.Key, candidate.String())
go func() { go func() {
conn.mu.Lock() conn.mu.Lock()
@ -789,6 +791,10 @@ func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate) {
return return
} }
if candidateViaRoutes(candidate, haRoutes) {
return
}
err := conn.agent.AddRemoteCandidate(candidate) err := conn.agent.AddRemoteCandidate(candidate)
if err != nil { if err != nil {
log.Errorf("error while handling remote candidate from peer %s", conn.config.Key) log.Errorf("error while handling remote candidate from peer %s", conn.config.Key)
@ -806,3 +812,31 @@ func (conn *Conn) RegisterProtoSupportMeta(support []uint32) {
protoSupport := signal.ParseFeaturesSupported(support) protoSupport := signal.ParseFeaturesSupported(support)
conn.meta.protoSupport = protoSupport conn.meta.protoSupport = protoSupport
} }
func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool {
var routePrefixes []netip.Prefix
for _, routes := range clientRoutes {
if len(routes) > 0 && routes[0] != nil {
routePrefixes = append(routePrefixes, routes[0].Network)
}
}
addr, err := netip.ParseAddr(candidate.Address())
if err != nil {
log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err)
return false
}
for _, prefix := range routePrefixes {
// default route is
if prefix.Bits() == 0 {
continue
}
if prefix.Contains(addr) {
log.Debugf("Ignoring candidate [%s], its address is part of routed network %s", candidate.String(), prefix)
return true
}
}
return false
}