[client] Replace string to netip.Prefix (#3362)

Replace string to netip.Prefix

---------

Co-authored-by: Hakan Sariman <hknsrmn46@gmail.com>
This commit is contained in:
Zoltan Papp
2025-02-24 15:51:43 +01:00
committed by GitHub
parent c8a558f797
commit 0819df916e
13 changed files with 238 additions and 106 deletions

View File

@ -43,13 +43,7 @@ func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error
return nil return nil
} }
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
// parse allowed ips
_, ipNet, err := net.ParseCIDR(allowedIps)
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil { if err != nil {
return err return err
@ -58,7 +52,7 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAli
PublicKey: peerKeyParsed, PublicKey: peerKeyParsed,
ReplaceAllowedIPs: false, ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP // don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: []net.IPNet{*ipNet}, AllowedIPs: allowedIps,
PersistentKeepaliveInterval: &keepAlive, PersistentKeepaliveInterval: &keepAlive,
Endpoint: endpoint, Endpoint: endpoint,
PresharedKey: preSharedKey, PresharedKey: preSharedKey,

View File

@ -52,13 +52,7 @@ func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error
return c.device.IpcSet(toWgUserspaceString(config)) return c.device.IpcSet(toWgUserspaceString(config))
} }
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
// parse allowed ips
_, ipNet, err := net.ParseCIDR(allowedIps)
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil { if err != nil {
return err return err
@ -67,7 +61,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAliv
PublicKey: peerKeyParsed, PublicKey: peerKeyParsed,
ReplaceAllowedIPs: false, ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP // don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: []net.IPNet{*ipNet}, AllowedIPs: allowedIps,
PersistentKeepaliveInterval: &keepAlive, PersistentKeepaliveInterval: &keepAlive,
PresharedKey: preSharedKey, PresharedKey: preSharedKey,
Endpoint: endpoint, Endpoint: endpoint,

View File

@ -11,7 +11,7 @@ import (
type WGConfigurer interface { type WGConfigurer interface {
ConfigureInterface(privateKey string, port int) error ConfigureInterface(privateKey string, port int) error
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP string) error AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP string) error RemoveAllowedIP(peerKey string, allowedIP string) error

View File

@ -3,6 +3,7 @@ package iface
import ( import (
"fmt" "fmt"
"net" "net"
"net/netip"
"sync" "sync"
"time" "time"
@ -112,12 +113,13 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist // UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
// Endpoint is optional // Endpoint is optional
func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
netIPNets := prefixesToIPNets(allowedIps)
log.Debugf("updating interface %s peer %s, endpoint %s", w.tun.DeviceName(), peerKey, endpoint) log.Debugf("updating interface %s peer %s, endpoint %s", w.tun.DeviceName(), peerKey, endpoint)
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) return w.configurer.UpdatePeer(peerKey, netIPNets, keepAlive, endpoint, preSharedKey)
} }
// RemovePeer removes a Wireguard Peer from the interface iface // RemovePeer removes a Wireguard Peer from the interface iface
@ -250,3 +252,14 @@ func (w *WGIface) GetNet() *netstack.Net {
return w.tun.GetNet() return w.tun.GetNet()
} }
func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
ipNets := make([]net.IPNet, len(prefixes))
for i, prefix := range prefixes {
ipNets[i] = net.IPNet{
IP: net.IP(prefix.Addr().AsSlice()), // Convert netip.Addr to net.IP
Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()), // Create subnet mask
}
}
return ipNets
}

View File

@ -373,12 +373,12 @@ func Test_UpdatePeer(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
keepAlive := 15 * time.Second keepAlive := 15 * time.Second
allowedIP := "10.99.99.10/32" allowedIP := netip.MustParsePrefix("10.99.99.10/32")
endpoint, err := net.ResolveUDPAddr("udp", "127.0.0.1:9900") endpoint, err := net.ResolveUDPAddr("udp", "127.0.0.1:9900")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = iface.UpdatePeer(peerPubKey, allowedIP, keepAlive, endpoint, nil) err = iface.UpdatePeer(peerPubKey, []netip.Prefix{allowedIP}, keepAlive, endpoint, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -396,7 +396,7 @@ func Test_UpdatePeer(t *testing.T) {
var foundAllowedIP bool var foundAllowedIP bool
for _, aip := range peer.AllowedIPs { for _, aip := range peer.AllowedIPs {
if aip.String() == allowedIP { if aip.String() == allowedIP.String() {
foundAllowedIP = true foundAllowedIP = true
break break
} }
@ -443,9 +443,8 @@ func Test_RemovePeer(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
keepAlive := 15 * time.Second keepAlive := 15 * time.Second
allowedIP := "10.99.99.14/32" allowedIP := netip.MustParsePrefix("10.99.99.14/32")
err = iface.UpdatePeer(peerPubKey, []netip.Prefix{allowedIP}, keepAlive, nil, nil)
err = iface.UpdatePeer(peerPubKey, allowedIP, keepAlive, nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -462,12 +461,12 @@ func Test_RemovePeer(t *testing.T) {
func Test_ConnectPeers(t *testing.T) { func Test_ConnectPeers(t *testing.T) {
peer1ifaceName := fmt.Sprintf("utun%d", WgIntNumber+400) peer1ifaceName := fmt.Sprintf("utun%d", WgIntNumber+400)
peer1wgIP := "10.99.99.17/30" peer1wgIP := netip.MustParsePrefix("10.99.99.17/30")
peer1Key, _ := wgtypes.GeneratePrivateKey() peer1Key, _ := wgtypes.GeneratePrivateKey()
peer1wgPort := 33100 peer1wgPort := 33100
peer2ifaceName := "utun500" peer2ifaceName := "utun500"
peer2wgIP := "10.99.99.18/30" peer2wgIP := netip.MustParsePrefix("10.99.99.18/30")
peer2Key, _ := wgtypes.GeneratePrivateKey() peer2Key, _ := wgtypes.GeneratePrivateKey()
peer2wgPort := 33200 peer2wgPort := 33200
@ -482,7 +481,7 @@ func Test_ConnectPeers(t *testing.T) {
optsPeer1 := WGIFaceOpts{ optsPeer1 := WGIFaceOpts{
IFaceName: peer1ifaceName, IFaceName: peer1ifaceName,
Address: peer1wgIP, Address: peer1wgIP.String(),
WGPort: peer1wgPort, WGPort: peer1wgPort,
WGPrivKey: peer1Key.String(), WGPrivKey: peer1Key.String(),
MTU: DefaultMTU, MTU: DefaultMTU,
@ -522,7 +521,7 @@ func Test_ConnectPeers(t *testing.T) {
optsPeer2 := WGIFaceOpts{ optsPeer2 := WGIFaceOpts{
IFaceName: peer2ifaceName, IFaceName: peer2ifaceName,
Address: peer2wgIP, Address: peer2wgIP.String(),
WGPort: peer2wgPort, WGPort: peer2wgPort,
WGPrivKey: peer2Key.String(), WGPrivKey: peer2Key.String(),
MTU: DefaultMTU, MTU: DefaultMTU,
@ -558,11 +557,11 @@ func Test_ConnectPeers(t *testing.T) {
} }
}() }()
err = iface1.UpdatePeer(peer2Key.PublicKey().String(), peer2wgIP, keepAlive, peer2endpoint, nil) err = iface1.UpdatePeer(peer2Key.PublicKey().String(), []netip.Prefix{peer2wgIP}, keepAlive, peer2endpoint, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = iface2.UpdatePeer(peer1Key.PublicKey().String(), peer1wgIP, keepAlive, peer1endpoint, nil) err = iface2.UpdatePeer(peer1Key.PublicKey().String(), []netip.Prefix{peer1wgIP}, keepAlive, peer1endpoint, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -527,15 +527,18 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
var modified []*mgmProto.RemotePeerConfig var modified []*mgmProto.RemotePeerConfig
for _, p := range peersUpdate { for _, p := range peersUpdate {
peerPubKey := p.GetWgPubKey() peerPubKey := p.GetWgPubKey()
if allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey); ok { allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey)
if allowedIPs != strings.Join(p.AllowedIps, ",") { if !ok {
modified = append(modified, p) continue
continue }
} if !compareNetIPLists(allowedIPs, p.GetAllowedIps()) {
err := e.statusRecorder.UpdatePeerFQDN(peerPubKey, p.GetFqdn()) modified = append(modified, p)
if err != nil { continue
log.Warnf("error updating peer's %s fqdn in the status recorder, got error: %v", peerPubKey, err) }
}
err := e.statusRecorder.UpdatePeerFQDN(peerPubKey, p.GetFqdn())
if err != nil {
log.Warnf("error updating peer's %s fqdn in the status recorder, got error: %v", peerPubKey, err)
} }
} }
@ -1103,34 +1106,45 @@ func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
// addNewPeer add peer if connection doesn't exist // addNewPeer add peer if connection doesn't exist
func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error { func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
peerKey := peerConfig.GetWgPubKey() peerKey := peerConfig.GetWgPubKey()
peerIPs := peerConfig.GetAllowedIps() peerIPs := make([]netip.Prefix, 0, len(peerConfig.GetAllowedIps()))
if _, ok := e.peerStore.PeerConn(peerKey); !ok { if _, ok := e.peerStore.PeerConn(peerKey); ok {
conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ",")) return nil
if err != nil {
return fmt.Errorf("create peer connection: %w", err)
}
if ok := e.peerStore.AddPeerConn(peerKey, conn); !ok {
conn.Close()
return fmt.Errorf("peer already exists: %s", peerKey)
}
if e.beforePeerHook != nil && e.afterPeerHook != nil {
conn.AddBeforeAddPeerHook(e.beforePeerHook)
conn.AddAfterRemovePeerHook(e.afterPeerHook)
}
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn)
if err != nil {
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
}
conn.Open()
} }
for _, ipString := range peerConfig.GetAllowedIps() {
allowedNetIP, err := netip.ParsePrefix(ipString)
if err != nil {
log.Errorf("failed to parse allowedIPS: %v", err)
return err
}
peerIPs = append(peerIPs, allowedNetIP)
}
conn, err := e.createPeerConn(peerKey, peerIPs)
if err != nil {
return fmt.Errorf("create peer connection: %w", err)
}
if ok := e.peerStore.AddPeerConn(peerKey, conn); !ok {
conn.Close()
return fmt.Errorf("peer already exists: %s", peerKey)
}
if e.beforePeerHook != nil && e.afterPeerHook != nil {
conn.AddBeforeAddPeerHook(e.beforePeerHook)
conn.AddAfterRemovePeerHook(e.afterPeerHook)
}
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn)
if err != nil {
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
}
conn.Open()
return nil return nil
} }
func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, error) { func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer.Conn, error) {
log.Debugf("creating peer connection %s", pubKey) log.Debugf("creating peer connection %s", pubKey)
wgConfig := peer.WgConfig{ wgConfig := peer.WgConfig{
@ -1815,3 +1829,36 @@ func getInterfacePrefixes() ([]netip.Prefix, error) {
return prefixes, nberrors.FormatErrorOrNil(merr) return prefixes, nberrors.FormatErrorOrNil(merr)
} }
// compareNetIPLists compares a list of netip.Prefix with a list of strings.
// return true if both lists are equal, false otherwise.
func compareNetIPLists(list1 []netip.Prefix, list2 []string) bool {
if len(list1) != len(list2) {
return false
}
freq := make(map[string]int, len(list1))
for _, p := range list1 {
freq[p.String()]++
}
for _, s := range list2 {
p, err := netip.ParsePrefix(s)
if err != nil {
return false // invalid prefix in list2.
}
key := p.String()
if freq[key] == 0 {
return false
}
freq[key]--
}
// all counts should be zero if lists are equal.
for _, count := range freq {
if count != 0 {
return false
}
}
return true
}

View File

@ -26,6 +26,7 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
@ -77,7 +78,7 @@ type MockWGIface struct {
ToInterfaceFunc func() *net.Interface ToInterfaceFunc func() *net.Interface
UpFunc func() (*bind.UniversalUDPMuxDefault, error) UpFunc func() (*bind.UniversalUDPMuxDefault, error)
UpdateAddrFunc func(newAddr string) error UpdateAddrFunc func(newAddr string) error
UpdatePeerFunc func(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error UpdatePeerFunc func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeerFunc func(peerKey string) error RemovePeerFunc func(peerKey string) error
AddAllowedIPFunc func(peerKey string, allowedIP string) error AddAllowedIPFunc func(peerKey string, allowedIP string) error
RemoveAllowedIPFunc func(peerKey string, allowedIP string) error RemoveAllowedIPFunc func(peerKey string, allowedIP string) error
@ -128,7 +129,7 @@ func (m *MockWGIface) UpdateAddr(newAddr string) error {
return m.UpdateAddrFunc(newAddr) return m.UpdateAddrFunc(newAddr)
} }
func (m *MockWGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { func (m *MockWGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
return m.UpdatePeerFunc(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) return m.UpdatePeerFunc(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
} }
@ -534,7 +535,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
t.Errorf("expecting Engine.peerConns to contain peer %s", p) t.Errorf("expecting Engine.peerConns to contain peer %s", p)
} }
expectedAllowedIPs := strings.Join(p.AllowedIps, ",") expectedAllowedIPs := strings.Join(p.AllowedIps, ",")
if conn.WgConfig().AllowedIps != expectedAllowedIPs { if !compareNetIPLists(conn.WgConfig().AllowedIps, p.AllowedIps) {
t.Errorf("expecting peer %s to have AllowedIPs= %s, got %s", p.GetWgPubKey(), t.Errorf("expecting peer %s to have AllowedIPs= %s, got %s", p.GetWgPubKey(),
expectedAllowedIPs, conn.WgConfig().AllowedIps) expectedAllowedIPs, conn.WgConfig().AllowedIps)
} }
@ -1237,6 +1238,91 @@ func Test_CheckFilesEqual(t *testing.T) {
} }
} }
func TestCompareNetIPLists(t *testing.T) {
tests := []struct {
name string
list1 []netip.Prefix
list2 []string
expected bool
}{
{
name: "both empty",
list1: []netip.Prefix{},
list2: []string{},
expected: true,
},
{
name: "single match ipv4",
list1: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")},
list2: []string{"192.168.0.0/24"},
expected: true,
},
{
name: "multiple match ipv4, different order",
list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24"), netip.MustParsePrefix("10.0.0.0/8")},
list2: []string{"10.0.0.0/8", "192.168.1.0/24"},
expected: true,
},
{
name: "ipv4 mismatch due to extra element in list2",
list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
list2: []string{"192.168.1.0/24", "10.0.0.0/8"},
expected: false,
},
{
name: "ipv4 mismatch due to duplicate count",
list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24"), netip.MustParsePrefix("192.168.1.0/24")},
list2: []string{"192.168.1.0/24"},
expected: false,
},
{
name: "invalid prefix in list2",
list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
list2: []string{"invalid-prefix"},
expected: false,
},
{
name: "ipv4 mismatch because different prefixes",
list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
list2: []string{"10.0.0.0/8"},
expected: false,
},
{
name: "single match ipv6",
list1: []netip.Prefix{netip.MustParsePrefix("2001:db8::/32")},
list2: []string{"2001:db8::/32"},
expected: true,
},
{
name: "multiple match ipv6, different order",
list1: []netip.Prefix{netip.MustParsePrefix("2001:db8::/32"), netip.MustParsePrefix("fe80::/10")},
list2: []string{"fe80::/10", "2001:db8::/32"},
expected: true,
},
{
name: "mixed ipv4 and ipv6 match",
list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24"), netip.MustParsePrefix("2001:db8::/32")},
list2: []string{"2001:db8::/32", "192.168.1.0/24"},
expected: true,
},
{
name: "ipv6 mismatch with invalid prefix",
list1: []netip.Prefix{netip.MustParsePrefix("2001:db8::/32")},
list2: []string{"invalid-ipv6"},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := compareNetIPLists(tt.list1, tt.list2)
if result != tt.expected {
t.Errorf("compareNetIPLists(%v, %v) = %v; want %v", tt.list1, tt.list2, result, tt.expected)
}
})
}
}
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) { func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
key, err := wgtypes.GeneratePrivateKey() key, err := wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {

View File

@ -2,6 +2,7 @@ package internal
import ( import (
"net" "net"
"net/netip"
"time" "time"
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
@ -24,7 +25,7 @@ type wgIfaceBase interface {
Up() (*bind.UniversalUDPMuxDefault, error) Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(newAddr string) error UpdateAddr(newAddr string) error
GetProxy() wgproxy.Proxy GetProxy() wgproxy.Proxy
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP string) error AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP string) error RemoveAllowedIP(peerKey string, allowedIP string) error

View File

@ -5,9 +5,9 @@ import (
"fmt" "fmt"
"math/rand" "math/rand"
"net" "net"
"net/netip"
"os" "os"
"runtime" "runtime"
"strings"
"sync" "sync"
"time" "time"
@ -56,7 +56,7 @@ type WgConfig struct {
WgListenPort int WgListenPort int
RemoteKey string RemoteKey string
WgInterface WGIface WgInterface WGIface
AllowedIps string AllowedIps []netip.Prefix
PreSharedKey *wgtypes.Key PreSharedKey *wgtypes.Key
} }
@ -91,11 +91,10 @@ type Conn struct {
statusRecorder *Status statusRecorder *Status
signaler *Signaler signaler *Signaler
relayManager *relayClient.Manager relayManager *relayClient.Manager
allowedIP net.IP
handshaker *Handshaker handshaker *Handshaker
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
onDisconnected func(remotePeer string, wgIP string) onDisconnected func(remotePeer string)
statusRelay *AtomicConnStatus statusRelay *AtomicConnStatus
statusICE *AtomicConnStatus statusICE *AtomicConnStatus
@ -120,10 +119,8 @@ type Conn struct {
// NewConn creates a new not opened Conn to the remote peer. // NewConn creates a new not opened Conn to the remote peer.
// To establish a connection run Conn.Open // To establish a connection run Conn.Open
func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher, semaphore *semaphoregroup.SemaphoreGroup) (*Conn, error) { func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher, semaphore *semaphoregroup.SemaphoreGroup) (*Conn, error) {
allowedIP, _, err := net.ParseCIDR(config.WgConfig.AllowedIps) if len(config.WgConfig.AllowedIps) == 0 {
if err != nil { return nil, fmt.Errorf("allowed IPs is empty")
log.Errorf("failed to parse allowedIPS: %v", err)
return nil, err
} }
ctx, ctxCancel := context.WithCancel(engineCtx) ctx, ctxCancel := context.WithCancel(engineCtx)
@ -137,7 +134,6 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
signaler: signaler, signaler: signaler,
relayManager: relayManager, relayManager: relayManager,
allowedIP: allowedIP,
statusRelay: NewAtomicConnStatus(), statusRelay: NewAtomicConnStatus(),
statusICE: NewAtomicConnStatus(), statusICE: NewAtomicConnStatus(),
semaphore: semaphore, semaphore: semaphore,
@ -147,10 +143,11 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, conn, relayManager) conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, conn, relayManager)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
conn.workerICE, err = NewWorkerICE(ctx, connLog, config, conn, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally) workerICE, err := NewWorkerICE(ctx, connLog, config, conn, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally)
if err != nil { if err != nil {
return nil, err return nil, err
} }
conn.workerICE = workerICE
conn.handshaker = NewHandshaker(ctx, connLog, config, signaler, conn.workerICE, conn.workerRelay) conn.handshaker = NewHandshaker(ctx, connLog, config, signaler, conn.workerICE, conn.workerRelay)
@ -179,7 +176,7 @@ func (conn *Conn) Open() {
peerState := State{ peerState := State{
PubKey: conn.config.Key, PubKey: conn.config.Key,
IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0], IP: conn.config.WgConfig.AllowedIps[0].Addr().String(),
ConnStatusUpdate: time.Now(), ConnStatusUpdate: time.Now(),
ConnStatus: StatusDisconnected, ConnStatus: StatusDisconnected,
Mux: new(sync.RWMutex), Mux: new(sync.RWMutex),
@ -245,7 +242,7 @@ func (conn *Conn) Close() {
conn.freeUpConnID() conn.freeUpConnID()
if conn.evalStatus() == StatusConnected && conn.onDisconnected != nil { if conn.evalStatus() == StatusConnected && conn.onDisconnected != nil {
conn.onDisconnected(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps) conn.onDisconnected(conn.config.WgConfig.RemoteKey)
} }
conn.setStatusToDisconnected() conn.setStatusToDisconnected()
@ -276,7 +273,7 @@ func (conn *Conn) SetOnConnected(handler func(remoteWireGuardKey string, remoteR
} }
// SetOnDisconnected sets a handler function to be triggered by Conn when a connection to a remote disconnected // SetOnDisconnected sets a handler function to be triggered by Conn when a connection to a remote disconnected
func (conn *Conn) SetOnDisconnected(handler func(remotePeer string, wgIP string)) { func (conn *Conn) SetOnDisconnected(handler func(remotePeer string)) {
conn.onDisconnected = handler conn.onDisconnected = handler
} }
@ -601,7 +598,7 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd
} }
if conn.onConnected != nil { if conn.onConnected != nil {
conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedIP.String(), remoteRosenpassAddr) conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.config.WgConfig.AllowedIps[0].Addr().String(), remoteRosenpassAddr)
} }
} }
@ -698,7 +695,7 @@ func (conn *Conn) freeUpConnID() {
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) { func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
conn.log.Debugf("setup proxied WireGuard connection") conn.log.Debugf("setup proxied WireGuard connection")
udpAddr := &net.UDPAddr{ udpAddr := &net.UDPAddr{
IP: conn.allowedIP, IP: conn.config.WgConfig.AllowedIps[0].Addr().AsSlice(),
Port: conn.config.WgConfig.WgListenPort, Port: conn.config.WgConfig.WgListenPort,
} }
@ -752,8 +749,8 @@ func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
} }
// AllowedIP returns the allowed IP of the remote peer // AllowedIP returns the allowed IP of the remote peer
func (conn *Conn) AllowedIP() net.IP { func (conn *Conn) AllowedIP() netip.Addr {
return conn.allowedIP return conn.config.WgConfig.AllowedIps[0].Addr()
} }
func isController(config ConnConfig) bool { func isController(config ConnConfig) bool {

View File

@ -2,15 +2,17 @@ package peer
import ( import (
"net" "net"
"net/netip"
"time" "time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
type WGIface interface { type WGIface interface {
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error RemovePeer(peerKey string) error
GetStats(peerKey string) (configurer.WGStats, error) GetStats(peerKey string) (configurer.WGStats, error)
GetProxy() wgproxy.Proxy GetProxy() wgproxy.Proxy

View File

@ -1,7 +1,7 @@
package peerstore package peerstore
import ( import (
"net" "net/netip"
"sync" "sync"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
@ -46,18 +46,7 @@ func (s *Store) Remove(pubKey string) (*peer.Conn, bool) {
return p, true return p, true
} }
func (s *Store) AllowedIPs(pubKey string) (string, bool) { func (s *Store) AllowedIPs(pubKey string) ([]netip.Prefix, bool) {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
p, ok := s.peerConns[pubKey]
if !ok {
return "", false
}
return p.WgConfig().AllowedIps, true
}
func (s *Store) AllowedIP(pubKey string) (net.IP, bool) {
s.peerConnsMu.RLock() s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock() defer s.peerConnsMu.RUnlock()
@ -65,6 +54,17 @@ func (s *Store) AllowedIP(pubKey string) (net.IP, bool) {
if !ok { if !ok {
return nil, false return nil, false
} }
return p.WgConfig().AllowedIps, true
}
func (s *Store) AllowedIP(pubKey string) (netip.Addr, bool) {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
p, ok := s.peerConns[pubKey]
if !ok {
return netip.Addr{}, false
}
return p.AllowedIP(), true return p.AllowedIP(), true
} }

View File

@ -126,7 +126,7 @@ func (m *Manager) generateConfig() (rp.Config, error) {
return cfg, nil return cfg, nil
} }
func (m *Manager) OnDisconnected(peerKey string, wgIP string) { func (m *Manager) OnDisconnected(peerKey string) {
m.lock.Lock() m.lock.Lock()
defer m.lock.Unlock() defer m.lock.Unlock()

View File

@ -3,7 +3,6 @@ package dnsinterceptor
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"strings" "strings"
"sync" "sync"
@ -165,14 +164,14 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
Timeout: 5 * time.Second, Timeout: 5 * time.Second,
Net: "udp", Net: "udp",
} }
upstream := fmt.Sprintf("%s:%d", upstreamIP, dnsfwd.ListenPort) upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
reply, _, err := client.ExchangeContext(context.Background(), r, upstream) reply, _, err := client.ExchangeContext(context.Background(), r, upstream)
var answer []dns.RR var answer []dns.RR
if reply != nil { if reply != nil {
answer = reply.Answer answer = reply.Answer
} }
log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP, peerKey, r.Question[0].Name, answer) log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
if err != nil { if err != nil {
log.Errorf("failed to exchange DNS request with %s: %v", upstream, err) log.Errorf("failed to exchange DNS request with %s: %v", upstream, err)
@ -201,10 +200,10 @@ func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg,
} }
} }
func (d *DnsInterceptor) getUpstreamIP(peerKey string) (net.IP, error) { func (d *DnsInterceptor) getUpstreamIP(peerKey string) (netip.Addr, error) {
peerAllowedIP, exists := d.peerStore.AllowedIP(peerKey) peerAllowedIP, exists := d.peerStore.AllowedIP(peerKey)
if !exists { if !exists {
return nil, fmt.Errorf("peer connection not found for key: %s", peerKey) return netip.Addr{}, fmt.Errorf("peer connection not found for key: %s", peerKey)
} }
return peerAllowedIP, nil return peerAllowedIP, nil
} }