[client] Remove strings from allowed IPs (#3920)

This commit is contained in:
Viktor Liu
2025-06-10 14:26:28 +02:00
committed by GitHub
parent de27d6df36
commit 6127a01196
9 changed files with 68 additions and 52 deletions

View File

@ -0,0 +1,17 @@
package configurer
import (
"net"
"net/netip"
)
func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
ipNets := make([]net.IPNet, len(prefixes))
for i, prefix := range prefixes {
ipNets[i] = net.IPNet{
IP: prefix.Addr().AsSlice(), // Convert netip.Addr to net.IP
Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()), // Create subnet mask
}
}
return ipNets
}

View File

@ -5,6 +5,7 @@ package configurer
import ( import (
"fmt" "fmt"
"net" "net"
"net/netip"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -45,7 +46,7 @@ func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error
return nil return nil
} }
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil { if err != nil {
return err return err
@ -54,7 +55,7 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, ke
PublicKey: peerKeyParsed, PublicKey: peerKeyParsed,
ReplaceAllowedIPs: false, ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP // don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: allowedIps, AllowedIPs: prefixesToIPNets(allowedIps),
PersistentKeepaliveInterval: &keepAlive, PersistentKeepaliveInterval: &keepAlive,
Endpoint: endpoint, Endpoint: endpoint,
PresharedKey: preSharedKey, PresharedKey: preSharedKey,
@ -91,10 +92,10 @@ func (c *KernelConfigurer) RemovePeer(peerKey string) error {
return nil return nil
} }
func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP string) error { func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
_, ipNet, err := net.ParseCIDR(allowedIP) ipNet := net.IPNet{
if err != nil { IP: allowedIP.Addr().AsSlice(),
return err Mask: net.CIDRMask(allowedIP.Bits(), allowedIP.Addr().BitLen()),
} }
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
@ -105,7 +106,7 @@ func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP string) error
PublicKey: peerKeyParsed, PublicKey: peerKeyParsed,
UpdateOnly: true, UpdateOnly: true,
ReplaceAllowedIPs: false, ReplaceAllowedIPs: false,
AllowedIPs: []net.IPNet{*ipNet}, AllowedIPs: []net.IPNet{ipNet},
} }
config := wgtypes.Config{ config := wgtypes.Config{
@ -118,10 +119,10 @@ func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP string) error
return nil return nil
} }
func (c *KernelConfigurer) RemoveAllowedIP(peerKey string, allowedIP string) error { func (c *KernelConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
_, ipNet, err := net.ParseCIDR(allowedIP) ipNet := net.IPNet{
if err != nil { IP: allowedIP.Addr().AsSlice(),
return fmt.Errorf("parse allowed IP: %w", err) Mask: net.CIDRMask(allowedIP.Bits(), allowedIP.Addr().BitLen()),
} }
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
@ -189,7 +190,11 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error {
if err != nil { if err != nil {
return err return err
} }
defer wg.Close() defer func() {
if err := wg.Close(); err != nil {
log.Errorf("Failed to close wgctrl client: %v", err)
}
}()
// validate if device with name exists // validate if device with name exists
_, err = wg.Device(c.deviceName) _, err = wg.Device(c.deviceName)

View File

@ -5,6 +5,7 @@ import (
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"net" "net"
"net/netip"
"os" "os"
"runtime" "runtime"
"strconv" "strconv"
@ -67,7 +68,7 @@ func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error
return c.device.IpcSet(toWgUserspaceString(config)) return c.device.IpcSet(toWgUserspaceString(config))
} }
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil { if err != nil {
return err return err
@ -76,7 +77,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, kee
PublicKey: peerKeyParsed, PublicKey: peerKeyParsed,
ReplaceAllowedIPs: false, ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP // don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: allowedIps, AllowedIPs: prefixesToIPNets(allowedIps),
PersistentKeepaliveInterval: &keepAlive, PersistentKeepaliveInterval: &keepAlive,
PresharedKey: preSharedKey, PresharedKey: preSharedKey,
Endpoint: endpoint, Endpoint: endpoint,
@ -106,10 +107,10 @@ func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
return c.device.IpcSet(toWgUserspaceString(config)) return c.device.IpcSet(toWgUserspaceString(config))
} }
func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP string) error { func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
_, ipNet, err := net.ParseCIDR(allowedIP) ipNet := net.IPNet{
if err != nil { IP: allowedIP.Addr().AsSlice(),
return err Mask: net.CIDRMask(allowedIP.Bits(), allowedIP.Addr().BitLen()),
} }
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
@ -120,7 +121,7 @@ func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
PublicKey: peerKeyParsed, PublicKey: peerKeyParsed,
UpdateOnly: true, UpdateOnly: true,
ReplaceAllowedIPs: false, ReplaceAllowedIPs: false,
AllowedIPs: []net.IPNet{*ipNet}, AllowedIPs: []net.IPNet{ipNet},
} }
config := wgtypes.Config{ config := wgtypes.Config{
@ -130,7 +131,7 @@ func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
return c.device.IpcSet(toWgUserspaceString(config)) return c.device.IpcSet(toWgUserspaceString(config))
} }
func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error { func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
ipc, err := c.device.IpcGet() ipc, err := c.device.IpcGet()
if err != nil { if err != nil {
return err return err
@ -153,6 +154,8 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
foundPeer := false foundPeer := false
removedAllowedIP := false removedAllowedIP := false
ip := allowedIP.String()
for _, line := range lines { for _, line := range lines {
line = strings.TrimSpace(line) line = strings.TrimSpace(line)
@ -175,8 +178,8 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
// Append the line to the output string // Append the line to the output string
if foundPeer && strings.HasPrefix(line, "allowed_ip=") { if foundPeer && strings.HasPrefix(line, "allowed_ip=") {
allowedIP := strings.TrimPrefix(line, "allowed_ip=") allowedIPStr := strings.TrimPrefix(line, "allowed_ip=")
_, ipNet, err := net.ParseCIDR(allowedIP) _, ipNet, err := net.ParseCIDR(allowedIPStr)
if err != nil { if err != nil {
return err return err
} }

View File

@ -2,6 +2,7 @@ package device
import ( import (
"net" "net"
"net/netip"
"time" "time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@ -11,10 +12,10 @@ import (
type WGConfigurer interface { type WGConfigurer interface {
ConfigureInterface(privateKey string, port int) error ConfigureInterface(privateKey string, port int) error
UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP string) error AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
RemoveAllowedIP(peerKey string, allowedIP string) error RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
Close() Close()
GetStats() (map[string]configurer.WGStats, error) GetStats() (map[string]configurer.WGStats, error)
FullStats() (*configurer.Stats, error) FullStats() (*configurer.Stats, error)

View File

@ -111,14 +111,14 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
} }
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist // UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
// Endpoint is optional // Endpoint is optional.
// If allowedIps is given it will be added to the existing ones.
func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
netIPNets := prefixesToIPNets(allowedIps) log.Debugf("updating interface %s peer %s, endpoint %s, allowedIPs %v", w.tun.DeviceName(), peerKey, endpoint, allowedIps)
log.Debugf("updating interface %s peer %s, endpoint %s", w.tun.DeviceName(), peerKey, endpoint) return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
return w.configurer.UpdatePeer(peerKey, netIPNets, keepAlive, endpoint, preSharedKey)
} }
// RemovePeer removes a Wireguard Peer from the interface iface // RemovePeer removes a Wireguard Peer from the interface iface
@ -131,7 +131,7 @@ func (w *WGIface) RemovePeer(peerKey string) error {
} }
// AddAllowedIP adds a prefix to the allowed IPs list of peer // AddAllowedIP adds a prefix to the allowed IPs list of peer
func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error { func (w *WGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
@ -140,7 +140,7 @@ func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error {
} }
// RemoveAllowedIP removes a prefix from the allowed IPs list of peer // RemoveAllowedIP removes a prefix from the allowed IPs list of peer
func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error { func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
@ -254,14 +254,3 @@ func (w *WGIface) GetNet() *netstack.Net {
return w.tun.GetNet() return w.tun.GetNet()
} }
func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
ipNets := make([]net.IPNet, len(prefixes))
for i, prefix := range prefixes {
ipNets[i] = net.IPNet{
IP: net.IP(prefix.Addr().AsSlice()), // Convert netip.Addr to net.IP
Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()), // Create subnet mask
}
}
return ipNets
}

View File

@ -86,8 +86,8 @@ type MockWGIface struct {
UpdateAddrFunc func(newAddr string) error UpdateAddrFunc func(newAddr string) error
UpdatePeerFunc func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error UpdatePeerFunc func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeerFunc func(peerKey string) error RemovePeerFunc func(peerKey string) error
AddAllowedIPFunc func(peerKey string, allowedIP string) error AddAllowedIPFunc func(peerKey string, allowedIP netip.Prefix) error
RemoveAllowedIPFunc func(peerKey string, allowedIP string) error RemoveAllowedIPFunc func(peerKey string, allowedIP netip.Prefix) error
CloseFunc func() error CloseFunc func() error
SetFilterFunc func(filter device.PacketFilter) error SetFilterFunc func(filter device.PacketFilter) error
GetFilterFunc func() device.PacketFilter GetFilterFunc func() device.PacketFilter
@ -147,11 +147,11 @@ func (m *MockWGIface) RemovePeer(peerKey string) error {
return m.RemovePeerFunc(peerKey) return m.RemovePeerFunc(peerKey)
} }
func (m *MockWGIface) AddAllowedIP(peerKey string, allowedIP string) error { func (m *MockWGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
return m.AddAllowedIPFunc(peerKey, allowedIP) return m.AddAllowedIPFunc(peerKey, allowedIP)
} }
func (m *MockWGIface) RemoveAllowedIP(peerKey string, allowedIP string) error { func (m *MockWGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
return m.RemoveAllowedIPFunc(peerKey, allowedIP) return m.RemoveAllowedIPFunc(peerKey, allowedIP)
} }

View File

@ -28,8 +28,8 @@ type wgIfaceBase interface {
GetProxy() wgproxy.Proxy GetProxy() wgproxy.Proxy
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP string) error AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
RemoveAllowedIP(peerKey string, allowedIP string) error RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
Close() error Close() error
SetFilter(filter device.PacketFilter) error SetFilter(filter device.PacketFilter) error
GetFilter() device.PacketFilter GetFilter() device.PacketFilter

View File

@ -2,14 +2,15 @@ package iface
import ( import (
"net" "net"
"net/netip"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type wgIfaceBase interface { type wgIfaceBase interface {
AddAllowedIP(peerKey string, allowedIP string) error AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
RemoveAllowedIP(peerKey string, allowedIP string) error RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
Name() string Name() string
Address() wgaddr.Address Address() wgaddr.Address

View File

@ -159,10 +159,10 @@ func (m *DefaultManager) setupRefCounters(useNoop bool) {
m.allowedIPsRefCounter = refcounter.New( m.allowedIPsRefCounter = refcounter.New(
func(prefix netip.Prefix, peerKey string) (string, error) { func(prefix netip.Prefix, peerKey string) (string, error) {
// save peerKey to use it in the remove function // save peerKey to use it in the remove function
return peerKey, m.wgInterface.AddAllowedIP(peerKey, prefix.String()) return peerKey, m.wgInterface.AddAllowedIP(peerKey, prefix)
}, },
func(prefix netip.Prefix, peerKey string) error { func(prefix netip.Prefix, peerKey string) error {
if err := m.wgInterface.RemoveAllowedIP(peerKey, prefix.String()); err != nil { if err := m.wgInterface.RemoveAllowedIP(peerKey, prefix); err != nil {
if !errors.Is(err, configurer.ErrPeerNotFound) && !errors.Is(err, configurer.ErrAllowedIPNotFound) { if !errors.Is(err, configurer.ErrPeerNotFound) && !errors.Is(err, configurer.ErrAllowedIPNotFound) {
return err return err
} }