[client] Replace string to netip.Prefix (#3362)

Replace string to netip.Prefix

---------

Co-authored-by: Hakan Sariman <hknsrmn46@gmail.com>
This commit is contained in:
Zoltan Papp
2025-02-24 15:51:43 +01:00
committed by GitHub
parent c8a558f797
commit 0819df916e
13 changed files with 238 additions and 106 deletions

View File

@ -43,13 +43,7 @@ func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error
return nil
}
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
// parse allowed ips
_, ipNet, err := net.ParseCIDR(allowedIps)
if err != nil {
return err
}
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
@ -58,7 +52,7 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAli
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: []net.IPNet{*ipNet},
AllowedIPs: allowedIps,
PersistentKeepaliveInterval: &keepAlive,
Endpoint: endpoint,
PresharedKey: preSharedKey,

View File

@ -52,13 +52,7 @@ func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error
return c.device.IpcSet(toWgUserspaceString(config))
}
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
// parse allowed ips
_, ipNet, err := net.ParseCIDR(allowedIps)
if err != nil {
return err
}
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
@ -67,7 +61,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAliv
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: []net.IPNet{*ipNet},
AllowedIPs: allowedIps,
PersistentKeepaliveInterval: &keepAlive,
PresharedKey: preSharedKey,
Endpoint: endpoint,

View File

@ -11,7 +11,7 @@ import (
type WGConfigurer interface {
ConfigureInterface(privateKey string, port int) error
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP string) error

View File

@ -3,6 +3,7 @@ package iface
import (
"fmt"
"net"
"net/netip"
"sync"
"time"
@ -112,12 +113,13 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
// Endpoint is optional
func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
w.mu.Lock()
defer w.mu.Unlock()
netIPNets := prefixesToIPNets(allowedIps)
log.Debugf("updating interface %s peer %s, endpoint %s", w.tun.DeviceName(), peerKey, endpoint)
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
return w.configurer.UpdatePeer(peerKey, netIPNets, keepAlive, endpoint, preSharedKey)
}
// RemovePeer removes a Wireguard Peer from the interface iface
@ -250,3 +252,14 @@ func (w *WGIface) GetNet() *netstack.Net {
return w.tun.GetNet()
}
func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
ipNets := make([]net.IPNet, len(prefixes))
for i, prefix := range prefixes {
ipNets[i] = net.IPNet{
IP: net.IP(prefix.Addr().AsSlice()), // Convert netip.Addr to net.IP
Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()), // Create subnet mask
}
}
return ipNets
}

View File

@ -373,12 +373,12 @@ func Test_UpdatePeer(t *testing.T) {
t.Fatal(err)
}
keepAlive := 15 * time.Second
allowedIP := "10.99.99.10/32"
allowedIP := netip.MustParsePrefix("10.99.99.10/32")
endpoint, err := net.ResolveUDPAddr("udp", "127.0.0.1:9900")
if err != nil {
t.Fatal(err)
}
err = iface.UpdatePeer(peerPubKey, allowedIP, keepAlive, endpoint, nil)
err = iface.UpdatePeer(peerPubKey, []netip.Prefix{allowedIP}, keepAlive, endpoint, nil)
if err != nil {
t.Fatal(err)
}
@ -396,7 +396,7 @@ func Test_UpdatePeer(t *testing.T) {
var foundAllowedIP bool
for _, aip := range peer.AllowedIPs {
if aip.String() == allowedIP {
if aip.String() == allowedIP.String() {
foundAllowedIP = true
break
}
@ -443,9 +443,8 @@ func Test_RemovePeer(t *testing.T) {
t.Fatal(err)
}
keepAlive := 15 * time.Second
allowedIP := "10.99.99.14/32"
err = iface.UpdatePeer(peerPubKey, allowedIP, keepAlive, nil, nil)
allowedIP := netip.MustParsePrefix("10.99.99.14/32")
err = iface.UpdatePeer(peerPubKey, []netip.Prefix{allowedIP}, keepAlive, nil, nil)
if err != nil {
t.Fatal(err)
}
@ -462,12 +461,12 @@ func Test_RemovePeer(t *testing.T) {
func Test_ConnectPeers(t *testing.T) {
peer1ifaceName := fmt.Sprintf("utun%d", WgIntNumber+400)
peer1wgIP := "10.99.99.17/30"
peer1wgIP := netip.MustParsePrefix("10.99.99.17/30")
peer1Key, _ := wgtypes.GeneratePrivateKey()
peer1wgPort := 33100
peer2ifaceName := "utun500"
peer2wgIP := "10.99.99.18/30"
peer2wgIP := netip.MustParsePrefix("10.99.99.18/30")
peer2Key, _ := wgtypes.GeneratePrivateKey()
peer2wgPort := 33200
@ -482,7 +481,7 @@ func Test_ConnectPeers(t *testing.T) {
optsPeer1 := WGIFaceOpts{
IFaceName: peer1ifaceName,
Address: peer1wgIP,
Address: peer1wgIP.String(),
WGPort: peer1wgPort,
WGPrivKey: peer1Key.String(),
MTU: DefaultMTU,
@ -522,7 +521,7 @@ func Test_ConnectPeers(t *testing.T) {
optsPeer2 := WGIFaceOpts{
IFaceName: peer2ifaceName,
Address: peer2wgIP,
Address: peer2wgIP.String(),
WGPort: peer2wgPort,
WGPrivKey: peer2Key.String(),
MTU: DefaultMTU,
@ -558,11 +557,11 @@ func Test_ConnectPeers(t *testing.T) {
}
}()
err = iface1.UpdatePeer(peer2Key.PublicKey().String(), peer2wgIP, keepAlive, peer2endpoint, nil)
err = iface1.UpdatePeer(peer2Key.PublicKey().String(), []netip.Prefix{peer2wgIP}, keepAlive, peer2endpoint, nil)
if err != nil {
t.Fatal(err)
}
err = iface2.UpdatePeer(peer1Key.PublicKey().String(), peer1wgIP, keepAlive, peer1endpoint, nil)
err = iface2.UpdatePeer(peer1Key.PublicKey().String(), []netip.Prefix{peer1wgIP}, keepAlive, peer1endpoint, nil)
if err != nil {
t.Fatal(err)
}

View File

@ -527,17 +527,20 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
var modified []*mgmProto.RemotePeerConfig
for _, p := range peersUpdate {
peerPubKey := p.GetWgPubKey()
if allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey); ok {
if allowedIPs != strings.Join(p.AllowedIps, ",") {
allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey)
if !ok {
continue
}
if !compareNetIPLists(allowedIPs, p.GetAllowedIps()) {
modified = append(modified, p)
continue
}
err := e.statusRecorder.UpdatePeerFQDN(peerPubKey, p.GetFqdn())
if err != nil {
log.Warnf("error updating peer's %s fqdn in the status recorder, got error: %v", peerPubKey, err)
}
}
}
// second, close all modified connections and remove them from the state map
for _, p := range modified {
@ -1103,9 +1106,21 @@ func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
// addNewPeer add peer if connection doesn't exist
func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
peerKey := peerConfig.GetWgPubKey()
peerIPs := peerConfig.GetAllowedIps()
if _, ok := e.peerStore.PeerConn(peerKey); !ok {
conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ","))
peerIPs := make([]netip.Prefix, 0, len(peerConfig.GetAllowedIps()))
if _, ok := e.peerStore.PeerConn(peerKey); ok {
return nil
}
for _, ipString := range peerConfig.GetAllowedIps() {
allowedNetIP, err := netip.ParsePrefix(ipString)
if err != nil {
log.Errorf("failed to parse allowedIPS: %v", err)
return err
}
peerIPs = append(peerIPs, allowedNetIP)
}
conn, err := e.createPeerConn(peerKey, peerIPs)
if err != nil {
return fmt.Errorf("create peer connection: %w", err)
}
@ -1126,11 +1141,10 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
}
conn.Open()
}
return nil
}
func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, error) {
func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer.Conn, error) {
log.Debugf("creating peer connection %s", pubKey)
wgConfig := peer.WgConfig{
@ -1815,3 +1829,36 @@ func getInterfacePrefixes() ([]netip.Prefix, error) {
return prefixes, nberrors.FormatErrorOrNil(merr)
}
// compareNetIPLists compares a list of netip.Prefix with a list of strings.
// return true if both lists are equal, false otherwise.
func compareNetIPLists(list1 []netip.Prefix, list2 []string) bool {
if len(list1) != len(list2) {
return false
}
freq := make(map[string]int, len(list1))
for _, p := range list1 {
freq[p.String()]++
}
for _, s := range list2 {
p, err := netip.ParsePrefix(s)
if err != nil {
return false // invalid prefix in list2.
}
key := p.String()
if freq[key] == 0 {
return false
}
freq[key]--
}
// all counts should be zero if lists are equal.
for _, count := range freq {
if count != 0 {
return false
}
}
return true
}

View File

@ -26,6 +26,7 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
@ -77,7 +78,7 @@ type MockWGIface struct {
ToInterfaceFunc func() *net.Interface
UpFunc func() (*bind.UniversalUDPMuxDefault, error)
UpdateAddrFunc func(newAddr string) error
UpdatePeerFunc func(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
UpdatePeerFunc func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeerFunc func(peerKey string) error
AddAllowedIPFunc func(peerKey string, allowedIP string) error
RemoveAllowedIPFunc func(peerKey string, allowedIP string) error
@ -128,7 +129,7 @@ func (m *MockWGIface) UpdateAddr(newAddr string) error {
return m.UpdateAddrFunc(newAddr)
}
func (m *MockWGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
func (m *MockWGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
return m.UpdatePeerFunc(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
}
@ -534,7 +535,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
t.Errorf("expecting Engine.peerConns to contain peer %s", p)
}
expectedAllowedIPs := strings.Join(p.AllowedIps, ",")
if conn.WgConfig().AllowedIps != expectedAllowedIPs {
if !compareNetIPLists(conn.WgConfig().AllowedIps, p.AllowedIps) {
t.Errorf("expecting peer %s to have AllowedIPs= %s, got %s", p.GetWgPubKey(),
expectedAllowedIPs, conn.WgConfig().AllowedIps)
}
@ -1237,6 +1238,91 @@ func Test_CheckFilesEqual(t *testing.T) {
}
}
func TestCompareNetIPLists(t *testing.T) {
tests := []struct {
name string
list1 []netip.Prefix
list2 []string
expected bool
}{
{
name: "both empty",
list1: []netip.Prefix{},
list2: []string{},
expected: true,
},
{
name: "single match ipv4",
list1: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")},
list2: []string{"192.168.0.0/24"},
expected: true,
},
{
name: "multiple match ipv4, different order",
list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24"), netip.MustParsePrefix("10.0.0.0/8")},
list2: []string{"10.0.0.0/8", "192.168.1.0/24"},
expected: true,
},
{
name: "ipv4 mismatch due to extra element in list2",
list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
list2: []string{"192.168.1.0/24", "10.0.0.0/8"},
expected: false,
},
{
name: "ipv4 mismatch due to duplicate count",
list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24"), netip.MustParsePrefix("192.168.1.0/24")},
list2: []string{"192.168.1.0/24"},
expected: false,
},
{
name: "invalid prefix in list2",
list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
list2: []string{"invalid-prefix"},
expected: false,
},
{
name: "ipv4 mismatch because different prefixes",
list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
list2: []string{"10.0.0.0/8"},
expected: false,
},
{
name: "single match ipv6",
list1: []netip.Prefix{netip.MustParsePrefix("2001:db8::/32")},
list2: []string{"2001:db8::/32"},
expected: true,
},
{
name: "multiple match ipv6, different order",
list1: []netip.Prefix{netip.MustParsePrefix("2001:db8::/32"), netip.MustParsePrefix("fe80::/10")},
list2: []string{"fe80::/10", "2001:db8::/32"},
expected: true,
},
{
name: "mixed ipv4 and ipv6 match",
list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24"), netip.MustParsePrefix("2001:db8::/32")},
list2: []string{"2001:db8::/32", "192.168.1.0/24"},
expected: true,
},
{
name: "ipv6 mismatch with invalid prefix",
list1: []netip.Prefix{netip.MustParsePrefix("2001:db8::/32")},
list2: []string{"invalid-ipv6"},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := compareNetIPLists(tt.list1, tt.list2)
if result != tt.expected {
t.Errorf("compareNetIPLists(%v, %v) = %v; want %v", tt.list1, tt.list2, result, tt.expected)
}
})
}
}
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {

View File

@ -2,6 +2,7 @@ package internal
import (
"net"
"net/netip"
"time"
wgdevice "golang.zx2c4.com/wireguard/device"
@ -24,7 +25,7 @@ type wgIfaceBase interface {
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(newAddr string) error
GetProxy() wgproxy.Proxy
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP string) error

View File

@ -5,9 +5,9 @@ import (
"fmt"
"math/rand"
"net"
"net/netip"
"os"
"runtime"
"strings"
"sync"
"time"
@ -56,7 +56,7 @@ type WgConfig struct {
WgListenPort int
RemoteKey string
WgInterface WGIface
AllowedIps string
AllowedIps []netip.Prefix
PreSharedKey *wgtypes.Key
}
@ -91,11 +91,10 @@ type Conn struct {
statusRecorder *Status
signaler *Signaler
relayManager *relayClient.Manager
allowedIP net.IP
handshaker *Handshaker
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
onDisconnected func(remotePeer string, wgIP string)
onDisconnected func(remotePeer string)
statusRelay *AtomicConnStatus
statusICE *AtomicConnStatus
@ -120,10 +119,8 @@ type Conn struct {
// NewConn creates a new not opened Conn to the remote peer.
// To establish a connection run Conn.Open
func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher, semaphore *semaphoregroup.SemaphoreGroup) (*Conn, error) {
allowedIP, _, err := net.ParseCIDR(config.WgConfig.AllowedIps)
if err != nil {
log.Errorf("failed to parse allowedIPS: %v", err)
return nil, err
if len(config.WgConfig.AllowedIps) == 0 {
return nil, fmt.Errorf("allowed IPs is empty")
}
ctx, ctxCancel := context.WithCancel(engineCtx)
@ -137,7 +134,6 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
statusRecorder: statusRecorder,
signaler: signaler,
relayManager: relayManager,
allowedIP: allowedIP,
statusRelay: NewAtomicConnStatus(),
statusICE: NewAtomicConnStatus(),
semaphore: semaphore,
@ -147,10 +143,11 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, conn, relayManager)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
conn.workerICE, err = NewWorkerICE(ctx, connLog, config, conn, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally)
workerICE, err := NewWorkerICE(ctx, connLog, config, conn, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally)
if err != nil {
return nil, err
}
conn.workerICE = workerICE
conn.handshaker = NewHandshaker(ctx, connLog, config, signaler, conn.workerICE, conn.workerRelay)
@ -179,7 +176,7 @@ func (conn *Conn) Open() {
peerState := State{
PubKey: conn.config.Key,
IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0],
IP: conn.config.WgConfig.AllowedIps[0].Addr().String(),
ConnStatusUpdate: time.Now(),
ConnStatus: StatusDisconnected,
Mux: new(sync.RWMutex),
@ -245,7 +242,7 @@ func (conn *Conn) Close() {
conn.freeUpConnID()
if conn.evalStatus() == StatusConnected && conn.onDisconnected != nil {
conn.onDisconnected(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps)
conn.onDisconnected(conn.config.WgConfig.RemoteKey)
}
conn.setStatusToDisconnected()
@ -276,7 +273,7 @@ func (conn *Conn) SetOnConnected(handler func(remoteWireGuardKey string, remoteR
}
// SetOnDisconnected sets a handler function to be triggered by Conn when a connection to a remote disconnected
func (conn *Conn) SetOnDisconnected(handler func(remotePeer string, wgIP string)) {
func (conn *Conn) SetOnDisconnected(handler func(remotePeer string)) {
conn.onDisconnected = handler
}
@ -601,7 +598,7 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd
}
if conn.onConnected != nil {
conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedIP.String(), remoteRosenpassAddr)
conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.config.WgConfig.AllowedIps[0].Addr().String(), remoteRosenpassAddr)
}
}
@ -698,7 +695,7 @@ func (conn *Conn) freeUpConnID() {
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
conn.log.Debugf("setup proxied WireGuard connection")
udpAddr := &net.UDPAddr{
IP: conn.allowedIP,
IP: conn.config.WgConfig.AllowedIps[0].Addr().AsSlice(),
Port: conn.config.WgConfig.WgListenPort,
}
@ -752,8 +749,8 @@ func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
}
// AllowedIP returns the allowed IP of the remote peer
func (conn *Conn) AllowedIP() net.IP {
return conn.allowedIP
func (conn *Conn) AllowedIP() netip.Addr {
return conn.config.WgConfig.AllowedIps[0].Addr()
}
func isController(config ConnConfig) bool {

View File

@ -2,15 +2,17 @@ package peer
import (
"net"
"net/netip"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgproxy"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type WGIface interface {
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error
GetStats(peerKey string) (configurer.WGStats, error)
GetProxy() wgproxy.Proxy

View File

@ -1,7 +1,7 @@
package peerstore
import (
"net"
"net/netip"
"sync"
"golang.org/x/exp/maps"
@ -46,18 +46,7 @@ func (s *Store) Remove(pubKey string) (*peer.Conn, bool) {
return p, true
}
func (s *Store) AllowedIPs(pubKey string) (string, bool) {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
p, ok := s.peerConns[pubKey]
if !ok {
return "", false
}
return p.WgConfig().AllowedIps, true
}
func (s *Store) AllowedIP(pubKey string) (net.IP, bool) {
func (s *Store) AllowedIPs(pubKey string) ([]netip.Prefix, bool) {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
@ -65,6 +54,17 @@ func (s *Store) AllowedIP(pubKey string) (net.IP, bool) {
if !ok {
return nil, false
}
return p.WgConfig().AllowedIps, true
}
func (s *Store) AllowedIP(pubKey string) (netip.Addr, bool) {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
p, ok := s.peerConns[pubKey]
if !ok {
return netip.Addr{}, false
}
return p.AllowedIP(), true
}

View File

@ -126,7 +126,7 @@ func (m *Manager) generateConfig() (rp.Config, error) {
return cfg, nil
}
func (m *Manager) OnDisconnected(peerKey string, wgIP string) {
func (m *Manager) OnDisconnected(peerKey string) {
m.lock.Lock()
defer m.lock.Unlock()

View File

@ -3,7 +3,6 @@ package dnsinterceptor
import (
"context"
"fmt"
"net"
"net/netip"
"strings"
"sync"
@ -165,14 +164,14 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
Timeout: 5 * time.Second,
Net: "udp",
}
upstream := fmt.Sprintf("%s:%d", upstreamIP, dnsfwd.ListenPort)
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
reply, _, err := client.ExchangeContext(context.Background(), r, upstream)
var answer []dns.RR
if reply != nil {
answer = reply.Answer
}
log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP, peerKey, r.Question[0].Name, answer)
log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
if err != nil {
log.Errorf("failed to exchange DNS request with %s: %v", upstream, err)
@ -201,10 +200,10 @@ func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg,
}
}
func (d *DnsInterceptor) getUpstreamIP(peerKey string) (net.IP, error) {
func (d *DnsInterceptor) getUpstreamIP(peerKey string) (netip.Addr, error) {
peerAllowedIP, exists := d.peerStore.AllowedIP(peerKey)
if !exists {
return nil, fmt.Errorf("peer connection not found for key: %s", peerKey)
return netip.Addr{}, fmt.Errorf("peer connection not found for key: %s", peerKey)
}
return peerAllowedIP, nil
}