mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-14 09:18:51 +02:00
[client] Replace string to netip.Prefix (#3362)
Replace string to netip.Prefix --------- Co-authored-by: Hakan Sariman <hknsrmn46@gmail.com>
This commit is contained in:
@ -43,13 +43,7 @@ func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||
// parse allowed ips
|
||||
_, ipNet, err := net.ParseCIDR(allowedIps)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -58,7 +52,7 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAli
|
||||
PublicKey: peerKeyParsed,
|
||||
ReplaceAllowedIPs: false,
|
||||
// don't replace allowed ips, wg will handle duplicated peer IP
|
||||
AllowedIPs: []net.IPNet{*ipNet},
|
||||
AllowedIPs: allowedIps,
|
||||
PersistentKeepaliveInterval: &keepAlive,
|
||||
Endpoint: endpoint,
|
||||
PresharedKey: preSharedKey,
|
||||
|
@ -52,13 +52,7 @@ func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error
|
||||
return c.device.IpcSet(toWgUserspaceString(config))
|
||||
}
|
||||
|
||||
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||
// parse allowed ips
|
||||
_, ipNet, err := net.ParseCIDR(allowedIps)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -67,7 +61,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAliv
|
||||
PublicKey: peerKeyParsed,
|
||||
ReplaceAllowedIPs: false,
|
||||
// don't replace allowed ips, wg will handle duplicated peer IP
|
||||
AllowedIPs: []net.IPNet{*ipNet},
|
||||
AllowedIPs: allowedIps,
|
||||
PersistentKeepaliveInterval: &keepAlive,
|
||||
PresharedKey: preSharedKey,
|
||||
Endpoint: endpoint,
|
||||
|
@ -11,7 +11,7 @@ import (
|
||||
|
||||
type WGConfigurer interface {
|
||||
ConfigureInterface(privateKey string, port int) error
|
||||
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
RemovePeer(peerKey string) error
|
||||
AddAllowedIP(peerKey string, allowedIP string) error
|
||||
RemoveAllowedIP(peerKey string, allowedIP string) error
|
||||
|
@ -3,6 +3,7 @@ package iface
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@ -112,12 +113,13 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
|
||||
|
||||
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
|
||||
// Endpoint is optional
|
||||
func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||
func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
netIPNets := prefixesToIPNets(allowedIps)
|
||||
log.Debugf("updating interface %s peer %s, endpoint %s", w.tun.DeviceName(), peerKey, endpoint)
|
||||
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
|
||||
return w.configurer.UpdatePeer(peerKey, netIPNets, keepAlive, endpoint, preSharedKey)
|
||||
}
|
||||
|
||||
// RemovePeer removes a Wireguard Peer from the interface iface
|
||||
@ -250,3 +252,14 @@ func (w *WGIface) GetNet() *netstack.Net {
|
||||
|
||||
return w.tun.GetNet()
|
||||
}
|
||||
|
||||
func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
|
||||
ipNets := make([]net.IPNet, len(prefixes))
|
||||
for i, prefix := range prefixes {
|
||||
ipNets[i] = net.IPNet{
|
||||
IP: net.IP(prefix.Addr().AsSlice()), // Convert netip.Addr to net.IP
|
||||
Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()), // Create subnet mask
|
||||
}
|
||||
}
|
||||
return ipNets
|
||||
}
|
||||
|
@ -373,12 +373,12 @@ func Test_UpdatePeer(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
keepAlive := 15 * time.Second
|
||||
allowedIP := "10.99.99.10/32"
|
||||
allowedIP := netip.MustParsePrefix("10.99.99.10/32")
|
||||
endpoint, err := net.ResolveUDPAddr("udp", "127.0.0.1:9900")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = iface.UpdatePeer(peerPubKey, allowedIP, keepAlive, endpoint, nil)
|
||||
err = iface.UpdatePeer(peerPubKey, []netip.Prefix{allowedIP}, keepAlive, endpoint, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -396,7 +396,7 @@ func Test_UpdatePeer(t *testing.T) {
|
||||
|
||||
var foundAllowedIP bool
|
||||
for _, aip := range peer.AllowedIPs {
|
||||
if aip.String() == allowedIP {
|
||||
if aip.String() == allowedIP.String() {
|
||||
foundAllowedIP = true
|
||||
break
|
||||
}
|
||||
@ -443,9 +443,8 @@ func Test_RemovePeer(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
keepAlive := 15 * time.Second
|
||||
allowedIP := "10.99.99.14/32"
|
||||
|
||||
err = iface.UpdatePeer(peerPubKey, allowedIP, keepAlive, nil, nil)
|
||||
allowedIP := netip.MustParsePrefix("10.99.99.14/32")
|
||||
err = iface.UpdatePeer(peerPubKey, []netip.Prefix{allowedIP}, keepAlive, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -462,12 +461,12 @@ func Test_RemovePeer(t *testing.T) {
|
||||
|
||||
func Test_ConnectPeers(t *testing.T) {
|
||||
peer1ifaceName := fmt.Sprintf("utun%d", WgIntNumber+400)
|
||||
peer1wgIP := "10.99.99.17/30"
|
||||
peer1wgIP := netip.MustParsePrefix("10.99.99.17/30")
|
||||
peer1Key, _ := wgtypes.GeneratePrivateKey()
|
||||
peer1wgPort := 33100
|
||||
|
||||
peer2ifaceName := "utun500"
|
||||
peer2wgIP := "10.99.99.18/30"
|
||||
peer2wgIP := netip.MustParsePrefix("10.99.99.18/30")
|
||||
peer2Key, _ := wgtypes.GeneratePrivateKey()
|
||||
peer2wgPort := 33200
|
||||
|
||||
@ -482,7 +481,7 @@ func Test_ConnectPeers(t *testing.T) {
|
||||
|
||||
optsPeer1 := WGIFaceOpts{
|
||||
IFaceName: peer1ifaceName,
|
||||
Address: peer1wgIP,
|
||||
Address: peer1wgIP.String(),
|
||||
WGPort: peer1wgPort,
|
||||
WGPrivKey: peer1Key.String(),
|
||||
MTU: DefaultMTU,
|
||||
@ -522,7 +521,7 @@ func Test_ConnectPeers(t *testing.T) {
|
||||
|
||||
optsPeer2 := WGIFaceOpts{
|
||||
IFaceName: peer2ifaceName,
|
||||
Address: peer2wgIP,
|
||||
Address: peer2wgIP.String(),
|
||||
WGPort: peer2wgPort,
|
||||
WGPrivKey: peer2Key.String(),
|
||||
MTU: DefaultMTU,
|
||||
@ -558,11 +557,11 @@ func Test_ConnectPeers(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
err = iface1.UpdatePeer(peer2Key.PublicKey().String(), peer2wgIP, keepAlive, peer2endpoint, nil)
|
||||
err = iface1.UpdatePeer(peer2Key.PublicKey().String(), []netip.Prefix{peer2wgIP}, keepAlive, peer2endpoint, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = iface2.UpdatePeer(peer1Key.PublicKey().String(), peer1wgIP, keepAlive, peer1endpoint, nil)
|
||||
err = iface2.UpdatePeer(peer1Key.PublicKey().String(), []netip.Prefix{peer1wgIP}, keepAlive, peer1endpoint, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -527,15 +527,18 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||
var modified []*mgmProto.RemotePeerConfig
|
||||
for _, p := range peersUpdate {
|
||||
peerPubKey := p.GetWgPubKey()
|
||||
if allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey); ok {
|
||||
if allowedIPs != strings.Join(p.AllowedIps, ",") {
|
||||
modified = append(modified, p)
|
||||
continue
|
||||
}
|
||||
err := e.statusRecorder.UpdatePeerFQDN(peerPubKey, p.GetFqdn())
|
||||
if err != nil {
|
||||
log.Warnf("error updating peer's %s fqdn in the status recorder, got error: %v", peerPubKey, err)
|
||||
}
|
||||
allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if !compareNetIPLists(allowedIPs, p.GetAllowedIps()) {
|
||||
modified = append(modified, p)
|
||||
continue
|
||||
}
|
||||
|
||||
err := e.statusRecorder.UpdatePeerFQDN(peerPubKey, p.GetFqdn())
|
||||
if err != nil {
|
||||
log.Warnf("error updating peer's %s fqdn in the status recorder, got error: %v", peerPubKey, err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1103,34 +1106,45 @@ func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||
// addNewPeer add peer if connection doesn't exist
|
||||
func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
|
||||
peerKey := peerConfig.GetWgPubKey()
|
||||
peerIPs := peerConfig.GetAllowedIps()
|
||||
if _, ok := e.peerStore.PeerConn(peerKey); !ok {
|
||||
conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ","))
|
||||
if err != nil {
|
||||
return fmt.Errorf("create peer connection: %w", err)
|
||||
}
|
||||
|
||||
if ok := e.peerStore.AddPeerConn(peerKey, conn); !ok {
|
||||
conn.Close()
|
||||
return fmt.Errorf("peer already exists: %s", peerKey)
|
||||
}
|
||||
|
||||
if e.beforePeerHook != nil && e.afterPeerHook != nil {
|
||||
conn.AddBeforeAddPeerHook(e.beforePeerHook)
|
||||
conn.AddAfterRemovePeerHook(e.afterPeerHook)
|
||||
}
|
||||
|
||||
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn)
|
||||
if err != nil {
|
||||
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
|
||||
}
|
||||
|
||||
conn.Open()
|
||||
peerIPs := make([]netip.Prefix, 0, len(peerConfig.GetAllowedIps()))
|
||||
if _, ok := e.peerStore.PeerConn(peerKey); ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, ipString := range peerConfig.GetAllowedIps() {
|
||||
allowedNetIP, err := netip.ParsePrefix(ipString)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse allowedIPS: %v", err)
|
||||
return err
|
||||
}
|
||||
peerIPs = append(peerIPs, allowedNetIP)
|
||||
}
|
||||
|
||||
conn, err := e.createPeerConn(peerKey, peerIPs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create peer connection: %w", err)
|
||||
}
|
||||
|
||||
if ok := e.peerStore.AddPeerConn(peerKey, conn); !ok {
|
||||
conn.Close()
|
||||
return fmt.Errorf("peer already exists: %s", peerKey)
|
||||
}
|
||||
|
||||
if e.beforePeerHook != nil && e.afterPeerHook != nil {
|
||||
conn.AddBeforeAddPeerHook(e.beforePeerHook)
|
||||
conn.AddAfterRemovePeerHook(e.afterPeerHook)
|
||||
}
|
||||
|
||||
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn)
|
||||
if err != nil {
|
||||
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
|
||||
}
|
||||
|
||||
conn.Open()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, error) {
|
||||
func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer.Conn, error) {
|
||||
log.Debugf("creating peer connection %s", pubKey)
|
||||
|
||||
wgConfig := peer.WgConfig{
|
||||
@ -1815,3 +1829,36 @@ func getInterfacePrefixes() ([]netip.Prefix, error) {
|
||||
|
||||
return prefixes, nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// compareNetIPLists compares a list of netip.Prefix with a list of strings.
|
||||
// return true if both lists are equal, false otherwise.
|
||||
func compareNetIPLists(list1 []netip.Prefix, list2 []string) bool {
|
||||
if len(list1) != len(list2) {
|
||||
return false
|
||||
}
|
||||
|
||||
freq := make(map[string]int, len(list1))
|
||||
for _, p := range list1 {
|
||||
freq[p.String()]++
|
||||
}
|
||||
|
||||
for _, s := range list2 {
|
||||
p, err := netip.ParsePrefix(s)
|
||||
if err != nil {
|
||||
return false // invalid prefix in list2.
|
||||
}
|
||||
key := p.String()
|
||||
if freq[key] == 0 {
|
||||
return false
|
||||
}
|
||||
freq[key]--
|
||||
}
|
||||
|
||||
// all counts should be zero if lists are equal.
|
||||
for _, count := range freq {
|
||||
if count != 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
@ -26,6 +26,7 @@ import (
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
@ -77,7 +78,7 @@ type MockWGIface struct {
|
||||
ToInterfaceFunc func() *net.Interface
|
||||
UpFunc func() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddrFunc func(newAddr string) error
|
||||
UpdatePeerFunc func(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
UpdatePeerFunc func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
RemovePeerFunc func(peerKey string) error
|
||||
AddAllowedIPFunc func(peerKey string, allowedIP string) error
|
||||
RemoveAllowedIPFunc func(peerKey string, allowedIP string) error
|
||||
@ -128,7 +129,7 @@ func (m *MockWGIface) UpdateAddr(newAddr string) error {
|
||||
return m.UpdateAddrFunc(newAddr)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||
func (m *MockWGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||
return m.UpdatePeerFunc(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
|
||||
}
|
||||
|
||||
@ -534,7 +535,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
t.Errorf("expecting Engine.peerConns to contain peer %s", p)
|
||||
}
|
||||
expectedAllowedIPs := strings.Join(p.AllowedIps, ",")
|
||||
if conn.WgConfig().AllowedIps != expectedAllowedIPs {
|
||||
if !compareNetIPLists(conn.WgConfig().AllowedIps, p.AllowedIps) {
|
||||
t.Errorf("expecting peer %s to have AllowedIPs= %s, got %s", p.GetWgPubKey(),
|
||||
expectedAllowedIPs, conn.WgConfig().AllowedIps)
|
||||
}
|
||||
@ -1237,6 +1238,91 @@ func Test_CheckFilesEqual(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareNetIPLists(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
list1 []netip.Prefix
|
||||
list2 []string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "both empty",
|
||||
list1: []netip.Prefix{},
|
||||
list2: []string{},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "single match ipv4",
|
||||
list1: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")},
|
||||
list2: []string{"192.168.0.0/24"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "multiple match ipv4, different order",
|
||||
list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24"), netip.MustParsePrefix("10.0.0.0/8")},
|
||||
list2: []string{"10.0.0.0/8", "192.168.1.0/24"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "ipv4 mismatch due to extra element in list2",
|
||||
list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||
list2: []string{"192.168.1.0/24", "10.0.0.0/8"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "ipv4 mismatch due to duplicate count",
|
||||
list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24"), netip.MustParsePrefix("192.168.1.0/24")},
|
||||
list2: []string{"192.168.1.0/24"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "invalid prefix in list2",
|
||||
list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||
list2: []string{"invalid-prefix"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "ipv4 mismatch because different prefixes",
|
||||
list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||
list2: []string{"10.0.0.0/8"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "single match ipv6",
|
||||
list1: []netip.Prefix{netip.MustParsePrefix("2001:db8::/32")},
|
||||
list2: []string{"2001:db8::/32"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "multiple match ipv6, different order",
|
||||
list1: []netip.Prefix{netip.MustParsePrefix("2001:db8::/32"), netip.MustParsePrefix("fe80::/10")},
|
||||
list2: []string{"fe80::/10", "2001:db8::/32"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "mixed ipv4 and ipv6 match",
|
||||
list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24"), netip.MustParsePrefix("2001:db8::/32")},
|
||||
list2: []string{"2001:db8::/32", "192.168.1.0/24"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "ipv6 mismatch with invalid prefix",
|
||||
list1: []netip.Prefix{netip.MustParsePrefix("2001:db8::/32")},
|
||||
list2: []string{"invalid-ipv6"},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := compareNetIPLists(tt.list1, tt.list2)
|
||||
if result != tt.expected {
|
||||
t.Errorf("compareNetIPLists(%v, %v) = %v; want %v", tt.list1, tt.list2, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
|
@ -2,6 +2,7 @@ package internal
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
@ -24,7 +25,7 @@ type wgIfaceBase interface {
|
||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(newAddr string) error
|
||||
GetProxy() wgproxy.Proxy
|
||||
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
RemovePeer(peerKey string) error
|
||||
AddAllowedIP(peerKey string, allowedIP string) error
|
||||
RemoveAllowedIP(peerKey string, allowedIP string) error
|
||||
|
@ -5,9 +5,9 @@ import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@ -56,7 +56,7 @@ type WgConfig struct {
|
||||
WgListenPort int
|
||||
RemoteKey string
|
||||
WgInterface WGIface
|
||||
AllowedIps string
|
||||
AllowedIps []netip.Prefix
|
||||
PreSharedKey *wgtypes.Key
|
||||
}
|
||||
|
||||
@ -91,11 +91,10 @@ type Conn struct {
|
||||
statusRecorder *Status
|
||||
signaler *Signaler
|
||||
relayManager *relayClient.Manager
|
||||
allowedIP net.IP
|
||||
handshaker *Handshaker
|
||||
|
||||
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
|
||||
onDisconnected func(remotePeer string, wgIP string)
|
||||
onDisconnected func(remotePeer string)
|
||||
|
||||
statusRelay *AtomicConnStatus
|
||||
statusICE *AtomicConnStatus
|
||||
@ -120,10 +119,8 @@ type Conn struct {
|
||||
// NewConn creates a new not opened Conn to the remote peer.
|
||||
// To establish a connection run Conn.Open
|
||||
func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher, semaphore *semaphoregroup.SemaphoreGroup) (*Conn, error) {
|
||||
allowedIP, _, err := net.ParseCIDR(config.WgConfig.AllowedIps)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse allowedIPS: %v", err)
|
||||
return nil, err
|
||||
if len(config.WgConfig.AllowedIps) == 0 {
|
||||
return nil, fmt.Errorf("allowed IPs is empty")
|
||||
}
|
||||
|
||||
ctx, ctxCancel := context.WithCancel(engineCtx)
|
||||
@ -137,7 +134,6 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
|
||||
statusRecorder: statusRecorder,
|
||||
signaler: signaler,
|
||||
relayManager: relayManager,
|
||||
allowedIP: allowedIP,
|
||||
statusRelay: NewAtomicConnStatus(),
|
||||
statusICE: NewAtomicConnStatus(),
|
||||
semaphore: semaphore,
|
||||
@ -147,10 +143,11 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
|
||||
conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, conn, relayManager)
|
||||
|
||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||
conn.workerICE, err = NewWorkerICE(ctx, connLog, config, conn, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally)
|
||||
workerICE, err := NewWorkerICE(ctx, connLog, config, conn, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn.workerICE = workerICE
|
||||
|
||||
conn.handshaker = NewHandshaker(ctx, connLog, config, signaler, conn.workerICE, conn.workerRelay)
|
||||
|
||||
@ -179,7 +176,7 @@ func (conn *Conn) Open() {
|
||||
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0],
|
||||
IP: conn.config.WgConfig.AllowedIps[0].Addr().String(),
|
||||
ConnStatusUpdate: time.Now(),
|
||||
ConnStatus: StatusDisconnected,
|
||||
Mux: new(sync.RWMutex),
|
||||
@ -245,7 +242,7 @@ func (conn *Conn) Close() {
|
||||
conn.freeUpConnID()
|
||||
|
||||
if conn.evalStatus() == StatusConnected && conn.onDisconnected != nil {
|
||||
conn.onDisconnected(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps)
|
||||
conn.onDisconnected(conn.config.WgConfig.RemoteKey)
|
||||
}
|
||||
|
||||
conn.setStatusToDisconnected()
|
||||
@ -276,7 +273,7 @@ func (conn *Conn) SetOnConnected(handler func(remoteWireGuardKey string, remoteR
|
||||
}
|
||||
|
||||
// SetOnDisconnected sets a handler function to be triggered by Conn when a connection to a remote disconnected
|
||||
func (conn *Conn) SetOnDisconnected(handler func(remotePeer string, wgIP string)) {
|
||||
func (conn *Conn) SetOnDisconnected(handler func(remotePeer string)) {
|
||||
conn.onDisconnected = handler
|
||||
}
|
||||
|
||||
@ -601,7 +598,7 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd
|
||||
}
|
||||
|
||||
if conn.onConnected != nil {
|
||||
conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedIP.String(), remoteRosenpassAddr)
|
||||
conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.config.WgConfig.AllowedIps[0].Addr().String(), remoteRosenpassAddr)
|
||||
}
|
||||
}
|
||||
|
||||
@ -698,7 +695,7 @@ func (conn *Conn) freeUpConnID() {
|
||||
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
|
||||
conn.log.Debugf("setup proxied WireGuard connection")
|
||||
udpAddr := &net.UDPAddr{
|
||||
IP: conn.allowedIP,
|
||||
IP: conn.config.WgConfig.AllowedIps[0].Addr().AsSlice(),
|
||||
Port: conn.config.WgConfig.WgListenPort,
|
||||
}
|
||||
|
||||
@ -752,8 +749,8 @@ func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
|
||||
}
|
||||
|
||||
// AllowedIP returns the allowed IP of the remote peer
|
||||
func (conn *Conn) AllowedIP() net.IP {
|
||||
return conn.allowedIP
|
||||
func (conn *Conn) AllowedIP() netip.Addr {
|
||||
return conn.config.WgConfig.AllowedIps[0].Addr()
|
||||
}
|
||||
|
||||
func isController(config ConnConfig) bool {
|
||||
|
@ -2,15 +2,17 @@ package peer
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
type WGIface interface {
|
||||
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
RemovePeer(peerKey string) error
|
||||
GetStats(peerKey string) (configurer.WGStats, error)
|
||||
GetProxy() wgproxy.Proxy
|
||||
|
@ -1,7 +1,7 @@
|
||||
package peerstore
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
@ -46,18 +46,7 @@ func (s *Store) Remove(pubKey string) (*peer.Conn, bool) {
|
||||
return p, true
|
||||
}
|
||||
|
||||
func (s *Store) AllowedIPs(pubKey string) (string, bool) {
|
||||
s.peerConnsMu.RLock()
|
||||
defer s.peerConnsMu.RUnlock()
|
||||
|
||||
p, ok := s.peerConns[pubKey]
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
return p.WgConfig().AllowedIps, true
|
||||
}
|
||||
|
||||
func (s *Store) AllowedIP(pubKey string) (net.IP, bool) {
|
||||
func (s *Store) AllowedIPs(pubKey string) ([]netip.Prefix, bool) {
|
||||
s.peerConnsMu.RLock()
|
||||
defer s.peerConnsMu.RUnlock()
|
||||
|
||||
@ -65,6 +54,17 @@ func (s *Store) AllowedIP(pubKey string) (net.IP, bool) {
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return p.WgConfig().AllowedIps, true
|
||||
}
|
||||
|
||||
func (s *Store) AllowedIP(pubKey string) (netip.Addr, bool) {
|
||||
s.peerConnsMu.RLock()
|
||||
defer s.peerConnsMu.RUnlock()
|
||||
|
||||
p, ok := s.peerConns[pubKey]
|
||||
if !ok {
|
||||
return netip.Addr{}, false
|
||||
}
|
||||
return p.AllowedIP(), true
|
||||
}
|
||||
|
||||
|
@ -126,7 +126,7 @@ func (m *Manager) generateConfig() (rp.Config, error) {
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (m *Manager) OnDisconnected(peerKey string, wgIP string) {
|
||||
func (m *Manager) OnDisconnected(peerKey string) {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
|
@ -3,7 +3,6 @@ package dnsinterceptor
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
@ -165,14 +164,14 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
Timeout: 5 * time.Second,
|
||||
Net: "udp",
|
||||
}
|
||||
upstream := fmt.Sprintf("%s:%d", upstreamIP, dnsfwd.ListenPort)
|
||||
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
|
||||
reply, _, err := client.ExchangeContext(context.Background(), r, upstream)
|
||||
|
||||
var answer []dns.RR
|
||||
if reply != nil {
|
||||
answer = reply.Answer
|
||||
}
|
||||
log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP, peerKey, r.Question[0].Name, answer)
|
||||
log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("failed to exchange DNS request with %s: %v", upstream, err)
|
||||
@ -201,10 +200,10 @@ func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg,
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) getUpstreamIP(peerKey string) (net.IP, error) {
|
||||
func (d *DnsInterceptor) getUpstreamIP(peerKey string) (netip.Addr, error) {
|
||||
peerAllowedIP, exists := d.peerStore.AllowedIP(peerKey)
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("peer connection not found for key: %s", peerKey)
|
||||
return netip.Addr{}, fmt.Errorf("peer connection not found for key: %s", peerKey)
|
||||
}
|
||||
return peerAllowedIP, nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user