mirror of
https://github.com/netbirdio/netbird.git
synced 2025-07-28 11:09:14 +02:00
520 lines
12 KiB
Go
520 lines
12 KiB
Go
package configurer
|
|
|
|
import (
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"os"
|
|
"runtime"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
"golang.zx2c4.com/wireguard/device"
|
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
|
|
nbnet "github.com/netbirdio/netbird/util/net"
|
|
)
|
|
|
|
// Keys of the line-based "key=value" text exchanged with the userspace
// WireGuard device through IpcGet/IpcSet.
const (
	// Interface-level keys.
	privateKey = "private_key"
	listenPort = "listen_port"
	fwmark     = "fwmark"

	// Peer-level keys.
	publicKey    = "public_key"
	presharedKey = "preshared_key"
	endpoint     = "endpoint"
	allowedIP    = "allowed_ip"

	// Peer statistics keys.
	ipcKeyLastHandshakeTimeSec  = "last_handshake_time_sec"
	ipcKeyLastHandshakeTimeNsec = "last_handshake_time_nsec"
	ipcKeyTxBytes               = "tx_bytes"
	ipcKeyRxBytes               = "rx_bytes"
)
|
|
|
|
// ErrAllowedIPNotFound is returned by RemoveAllowedIP when the requested
// prefix is not listed among the target peer's allowed IPs.
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
|
|
|
|
type WGUSPConfigurer struct {
|
|
device *device.Device
|
|
deviceName string
|
|
|
|
uapiListener net.Listener
|
|
}
|
|
|
|
func NewUSPConfigurer(device *device.Device, deviceName string) *WGUSPConfigurer {
|
|
wgCfg := &WGUSPConfigurer{
|
|
device: device,
|
|
deviceName: deviceName,
|
|
}
|
|
wgCfg.startUAPI()
|
|
return wgCfg
|
|
}
|
|
|
|
func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error {
|
|
log.Debugf("adding Wireguard private key")
|
|
key, err := wgtypes.ParseKey(privateKey)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
fwmark := getFwmark()
|
|
config := wgtypes.Config{
|
|
PrivateKey: &key,
|
|
ReplacePeers: true,
|
|
FirewallMark: &fwmark,
|
|
ListenPort: &port,
|
|
}
|
|
|
|
return c.device.IpcSet(toWgUserspaceString(config))
|
|
}
|
|
|
|
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
peer := wgtypes.PeerConfig{
|
|
PublicKey: peerKeyParsed,
|
|
ReplaceAllowedIPs: false,
|
|
// don't replace allowed ips, wg will handle duplicated peer IP
|
|
AllowedIPs: prefixesToIPNets(allowedIps),
|
|
PersistentKeepaliveInterval: &keepAlive,
|
|
PresharedKey: preSharedKey,
|
|
Endpoint: endpoint,
|
|
}
|
|
|
|
config := wgtypes.Config{
|
|
Peers: []wgtypes.PeerConfig{peer},
|
|
}
|
|
|
|
return c.device.IpcSet(toWgUserspaceString(config))
|
|
}
|
|
|
|
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
|
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
peer := wgtypes.PeerConfig{
|
|
PublicKey: peerKeyParsed,
|
|
Remove: true,
|
|
}
|
|
|
|
config := wgtypes.Config{
|
|
Peers: []wgtypes.PeerConfig{peer},
|
|
}
|
|
return c.device.IpcSet(toWgUserspaceString(config))
|
|
}
|
|
|
|
func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
|
|
ipNet := net.IPNet{
|
|
IP: allowedIP.Addr().AsSlice(),
|
|
Mask: net.CIDRMask(allowedIP.Bits(), allowedIP.Addr().BitLen()),
|
|
}
|
|
|
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
peer := wgtypes.PeerConfig{
|
|
PublicKey: peerKeyParsed,
|
|
UpdateOnly: true,
|
|
ReplaceAllowedIPs: false,
|
|
AllowedIPs: []net.IPNet{ipNet},
|
|
}
|
|
|
|
config := wgtypes.Config{
|
|
Peers: []wgtypes.PeerConfig{peer},
|
|
}
|
|
|
|
return c.device.IpcSet(toWgUserspaceString(config))
|
|
}
|
|
|
|
func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
|
|
ipc, err := c.device.IpcGet()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
hexKey := hex.EncodeToString(peerKeyParsed[:])
|
|
|
|
lines := strings.Split(ipc, "\n")
|
|
|
|
peer := wgtypes.PeerConfig{
|
|
PublicKey: peerKeyParsed,
|
|
UpdateOnly: true,
|
|
ReplaceAllowedIPs: true,
|
|
AllowedIPs: []net.IPNet{},
|
|
}
|
|
|
|
foundPeer := false
|
|
removedAllowedIP := false
|
|
ip := allowedIP.String()
|
|
|
|
for _, line := range lines {
|
|
line = strings.TrimSpace(line)
|
|
|
|
// If we're within the details of the found peer and encounter another public key,
|
|
// this means we're starting another peer's details. So, reset the flag.
|
|
if strings.HasPrefix(line, "public_key=") && foundPeer {
|
|
foundPeer = false
|
|
}
|
|
|
|
// Identify the peer with the specific public key
|
|
if line == fmt.Sprintf("public_key=%s", hexKey) {
|
|
foundPeer = true
|
|
}
|
|
|
|
// If we're within the details of the found peer and find the specific allowed IP, skip this line
|
|
if foundPeer && line == "allowed_ip="+ip {
|
|
removedAllowedIP = true
|
|
continue
|
|
}
|
|
|
|
// Append the line to the output string
|
|
if foundPeer && strings.HasPrefix(line, "allowed_ip=") {
|
|
allowedIPStr := strings.TrimPrefix(line, "allowed_ip=")
|
|
_, ipNet, err := net.ParseCIDR(allowedIPStr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
peer.AllowedIPs = append(peer.AllowedIPs, *ipNet)
|
|
}
|
|
}
|
|
|
|
if !removedAllowedIP {
|
|
return ErrAllowedIPNotFound
|
|
}
|
|
config := wgtypes.Config{
|
|
Peers: []wgtypes.PeerConfig{peer},
|
|
}
|
|
return c.device.IpcSet(toWgUserspaceString(config))
|
|
}
|
|
|
|
func (c *WGUSPConfigurer) FullStats() (*Stats, error) {
|
|
ipcStr, err := c.device.IpcGet()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("IpcGet failed: %w", err)
|
|
}
|
|
|
|
return parseStatus(c.deviceName, ipcStr)
|
|
}
|
|
|
|
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
|
|
func (t *WGUSPConfigurer) startUAPI() {
|
|
var err error
|
|
t.uapiListener, err = openUAPI(t.deviceName)
|
|
if err != nil {
|
|
log.Errorf("failed to open uapi listener: %v", err)
|
|
return
|
|
}
|
|
|
|
go func(uapi net.Listener) {
|
|
for {
|
|
uapiConn, uapiErr := uapi.Accept()
|
|
if uapiErr != nil {
|
|
log.Tracef("%s", uapiErr)
|
|
return
|
|
}
|
|
go func() {
|
|
t.device.IpcHandle(uapiConn)
|
|
}()
|
|
}
|
|
}(t.uapiListener)
|
|
}
|
|
|
|
func (t *WGUSPConfigurer) Close() {
|
|
if t.uapiListener != nil {
|
|
err := t.uapiListener.Close()
|
|
if err != nil {
|
|
log.Errorf("failed to close uapi listener: %v", err)
|
|
}
|
|
}
|
|
|
|
if runtime.GOOS == "linux" {
|
|
sockPath := "/var/run/wireguard/" + t.deviceName + ".sock"
|
|
if _, statErr := os.Stat(sockPath); statErr == nil {
|
|
_ = os.Remove(sockPath)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (t *WGUSPConfigurer) GetStats() (map[string]WGStats, error) {
|
|
ipc, err := t.device.IpcGet()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("ipc get: %w", err)
|
|
}
|
|
|
|
return parseTransfers(ipc)
|
|
}
|
|
|
|
func parseTransfers(ipc string) (map[string]WGStats, error) {
|
|
stats := make(map[string]WGStats)
|
|
var (
|
|
currentKey string
|
|
currentStats WGStats
|
|
hasPeer bool
|
|
)
|
|
lines := strings.Split(ipc, "\n")
|
|
for _, line := range lines {
|
|
line = strings.TrimSpace(line)
|
|
|
|
// If we're within the details of the found peer and encounter another public key,
|
|
// this means we're starting another peer's details. So, stop.
|
|
if strings.HasPrefix(line, "public_key=") {
|
|
peerID := strings.TrimPrefix(line, "public_key=")
|
|
h, err := hex.DecodeString(peerID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("decode peerID: %w", err)
|
|
}
|
|
currentKey = base64.StdEncoding.EncodeToString(h)
|
|
currentStats = WGStats{} // Reset stats for the new peer
|
|
hasPeer = true
|
|
stats[currentKey] = currentStats
|
|
continue
|
|
}
|
|
|
|
if !hasPeer {
|
|
continue
|
|
}
|
|
|
|
key := strings.SplitN(line, "=", 2)
|
|
if len(key) != 2 {
|
|
continue
|
|
}
|
|
switch key[0] {
|
|
case ipcKeyLastHandshakeTimeSec:
|
|
hs, err := toLastHandshake(key[1])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
currentStats.LastHandshake = hs
|
|
stats[currentKey] = currentStats
|
|
case ipcKeyRxBytes:
|
|
rxBytes, err := toBytes(key[1])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse rx_bytes: %w", err)
|
|
}
|
|
currentStats.RxBytes = rxBytes
|
|
stats[currentKey] = currentStats
|
|
case ipcKeyTxBytes:
|
|
TxBytes, err := toBytes(key[1])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse tx_bytes: %w", err)
|
|
}
|
|
currentStats.TxBytes = TxBytes
|
|
stats[currentKey] = currentStats
|
|
}
|
|
}
|
|
|
|
return stats, nil
|
|
}
|
|
|
|
func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
|
var sb strings.Builder
|
|
if wgCfg.PrivateKey != nil {
|
|
hexKey := hex.EncodeToString(wgCfg.PrivateKey[:])
|
|
sb.WriteString(fmt.Sprintf("private_key=%s\n", hexKey))
|
|
}
|
|
|
|
if wgCfg.ListenPort != nil {
|
|
sb.WriteString(fmt.Sprintf("listen_port=%d\n", *wgCfg.ListenPort))
|
|
}
|
|
|
|
if wgCfg.ReplacePeers {
|
|
sb.WriteString("replace_peers=true\n")
|
|
}
|
|
|
|
if wgCfg.FirewallMark != nil {
|
|
sb.WriteString(fmt.Sprintf("fwmark=%d\n", *wgCfg.FirewallMark))
|
|
}
|
|
|
|
for _, p := range wgCfg.Peers {
|
|
hexKey := hex.EncodeToString(p.PublicKey[:])
|
|
sb.WriteString(fmt.Sprintf("public_key=%s\n", hexKey))
|
|
|
|
if p.PresharedKey != nil {
|
|
preSharedHexKey := hex.EncodeToString(p.PresharedKey[:])
|
|
sb.WriteString(fmt.Sprintf("preshared_key=%s\n", preSharedHexKey))
|
|
}
|
|
|
|
if p.Remove {
|
|
sb.WriteString("remove=true")
|
|
}
|
|
|
|
if p.ReplaceAllowedIPs {
|
|
sb.WriteString("replace_allowed_ips=true\n")
|
|
}
|
|
|
|
for _, aip := range p.AllowedIPs {
|
|
sb.WriteString(fmt.Sprintf("allowed_ip=%s\n", aip.String()))
|
|
}
|
|
|
|
if p.Endpoint != nil {
|
|
sb.WriteString(fmt.Sprintf("endpoint=%s\n", p.Endpoint.String()))
|
|
}
|
|
|
|
if p.PersistentKeepaliveInterval != nil {
|
|
sb.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", int(p.PersistentKeepaliveInterval.Seconds())))
|
|
}
|
|
}
|
|
return sb.String()
|
|
}
|
|
|
|
// toLastHandshake converts a base-10 Unix timestamp in seconds into a
// time.Time with a zero nanosecond part. On a parse failure it returns the
// zero time and a wrapped error.
func toLastHandshake(stringVar string) (time.Time, error) {
	var handshake time.Time

	seconds, err := strconv.ParseInt(stringVar, 10, 64)
	if err != nil {
		return handshake, fmt.Errorf("parse handshake sec: %w", err)
	}

	handshake = time.Unix(seconds, 0)
	return handshake, nil
}
|
|
|
|
// toBytes parses s as a base-10 signed 64-bit byte counter and returns the
// strconv result and error unchanged.
func toBytes(s string) (int64, error) {
	count, err := strconv.ParseInt(s, 10, 64)
	return count, err
}
|
|
|
|
func getFwmark() int {
|
|
if nbnet.AdvancedRouting() {
|
|
return nbnet.ControlPlaneMark
|
|
}
|
|
return 0
|
|
}
|
|
|
|
func hexToWireguardKey(hexKey string) (wgtypes.Key, error) {
|
|
// Decode hex string to bytes
|
|
keyBytes, err := hex.DecodeString(hexKey)
|
|
if err != nil {
|
|
return wgtypes.Key{}, fmt.Errorf("failed to decode hex key: %w", err)
|
|
}
|
|
|
|
// Check if we have the right number of bytes (WireGuard keys are 32 bytes)
|
|
if len(keyBytes) != 32 {
|
|
return wgtypes.Key{}, fmt.Errorf("invalid key length: expected 32 bytes, got %d", len(keyBytes))
|
|
}
|
|
|
|
// Convert to wgtypes.Key
|
|
var key wgtypes.Key
|
|
copy(key[:], keyBytes)
|
|
|
|
return key, nil
|
|
}
|
|
|
|
func parseStatus(deviceName, ipcStr string) (*Stats, error) {
|
|
stats := &Stats{DeviceName: deviceName}
|
|
var currentPeer *Peer
|
|
for _, line := range strings.Split(strings.TrimSpace(ipcStr), "\n") {
|
|
if line == "" {
|
|
continue
|
|
}
|
|
parts := strings.SplitN(line, "=", 2)
|
|
if len(parts) != 2 {
|
|
continue
|
|
}
|
|
key := parts[0]
|
|
val := parts[1]
|
|
|
|
switch key {
|
|
case privateKey:
|
|
key, err := hexToWireguardKey(val)
|
|
if err != nil {
|
|
log.Errorf("failed to parse private key: %v", err)
|
|
continue
|
|
}
|
|
stats.PublicKey = key.PublicKey().String()
|
|
case publicKey:
|
|
// Save previous peer
|
|
if currentPeer != nil {
|
|
stats.Peers = append(stats.Peers, *currentPeer)
|
|
}
|
|
key, err := hexToWireguardKey(val)
|
|
if err != nil {
|
|
log.Errorf("failed to parse public key: %v", err)
|
|
continue
|
|
}
|
|
currentPeer = &Peer{
|
|
PublicKey: key.String(),
|
|
}
|
|
case listenPort:
|
|
if port, err := strconv.Atoi(val); err == nil {
|
|
stats.ListenPort = port
|
|
}
|
|
case fwmark:
|
|
if fwmark, err := strconv.Atoi(val); err == nil {
|
|
stats.FWMark = fwmark
|
|
}
|
|
case endpoint:
|
|
if currentPeer == nil {
|
|
continue
|
|
}
|
|
|
|
host, portStr, err := net.SplitHostPort(strings.Trim(val, "[]"))
|
|
if err != nil {
|
|
log.Errorf("failed to parse endpoint: %v", err)
|
|
continue
|
|
}
|
|
port, err := strconv.Atoi(portStr)
|
|
if err != nil {
|
|
log.Errorf("failed to parse endpoint port: %v", err)
|
|
continue
|
|
}
|
|
currentPeer.Endpoint = net.UDPAddr{
|
|
IP: net.ParseIP(host),
|
|
Port: port,
|
|
}
|
|
case allowedIP:
|
|
if currentPeer == nil {
|
|
continue
|
|
}
|
|
_, ipnet, err := net.ParseCIDR(val)
|
|
if err == nil {
|
|
currentPeer.AllowedIPs = append(currentPeer.AllowedIPs, *ipnet)
|
|
}
|
|
case ipcKeyTxBytes:
|
|
if currentPeer == nil {
|
|
continue
|
|
}
|
|
rxBytes, err := toBytes(val)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
currentPeer.TxBytes = rxBytes
|
|
case ipcKeyRxBytes:
|
|
if currentPeer == nil {
|
|
continue
|
|
}
|
|
rxBytes, err := toBytes(val)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
currentPeer.RxBytes = rxBytes
|
|
|
|
case ipcKeyLastHandshakeTimeSec:
|
|
if currentPeer == nil {
|
|
continue
|
|
}
|
|
|
|
ts, err := toLastHandshake(val)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
currentPeer.LastHandshake = ts
|
|
case presharedKey:
|
|
if currentPeer == nil {
|
|
continue
|
|
}
|
|
if val != "" {
|
|
currentPeer.PresharedKey = true
|
|
}
|
|
}
|
|
}
|
|
if currentPeer != nil {
|
|
stats.Peers = append(stats.Peers, *currentPeer)
|
|
}
|
|
return stats, nil
|
|
}
|