Mobile prerefactor (#680)

Small code cleaning in the iface package. These changes necessary to 
get a clean code in case if we involve more platforms. The OS related 
functions has been distributed into separate files and it has been 
mixed with not OS related logic. The goal is to get a clear picture 
of the layer between WireGuard and business logic.
This commit is contained in:
Zoltan Papp 2023-02-13 18:34:56 +01:00 committed by GitHub
parent eb45310c8f
commit b64f5ffcb4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 424 additions and 446 deletions

View File

@ -76,12 +76,12 @@ func newNetworkManagerDbusConfigurator(wgInterface *iface.WGIface) (hostManager,
} }
defer closeConn() defer closeConn()
var s string var s string
err = obj.Call(networkManagerDbusGetDeviceByIPIfaceMethod, dbusDefaultFlag, wgInterface.GetName()).Store(&s) err = obj.Call(networkManagerDbusGetDeviceByIPIfaceMethod, dbusDefaultFlag, wgInterface.Name()).Store(&s)
if err != nil { if err != nil {
return nil, err return nil, err
} }
log.Debugf("got network manager dbus Link Object: %s from net interface %s", s, wgInterface.GetName()) log.Debugf("got network manager dbus Link Object: %s from net interface %s", s, wgInterface.Name())
return &networkManagerDbusConfigurator{ return &networkManagerDbusConfigurator{
dbusLinkObject: dbus.ObjectPath(s), dbusLinkObject: dbus.ObjectPath(s),

View File

@ -17,7 +17,7 @@ type resolvconf struct {
func newResolvConfConfigurator(wgInterface *iface.WGIface) (hostManager, error) { func newResolvConfConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
return &resolvconf{ return &resolvconf{
ifaceName: wgInterface.GetName(), ifaceName: wgInterface.Name(),
}, nil }, nil
} }

View File

@ -136,7 +136,7 @@ func (s *DefaultServer) Start() {
func (s *DefaultServer) getFirstListenerAvailable() (string, int, error) { func (s *DefaultServer) getFirstListenerAvailable() (string, int, error) {
ips := []string{defaultIP, customIP} ips := []string{defaultIP, customIP}
if runtime.GOOS != "darwin" && s.wgInterface != nil { if runtime.GOOS != "darwin" && s.wgInterface != nil {
ips = append([]string{s.wgInterface.GetAddress().IP.String()}, ips...) ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
} }
ports := []int{defaultPort, customPort} ports := []int{defaultPort, customPort}
for _, port := range ports { for _, port := range ports {

View File

@ -51,7 +51,7 @@ type systemdDbusLinkDomainsInput struct {
} }
func newSystemdDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) { func newSystemdDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
iface, err := net.InterfaceByName(wgInterface.GetName()) iface, err := net.InterfaceByName(wgInterface.Name())
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -158,12 +158,10 @@ func (e *Engine) Stop() error {
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName) log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
if e.wgInterface.Interface != nil { err = e.wgInterface.Close()
err = e.wgInterface.Close() if err != nil {
if err != nil { log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err) return err
return err
}
} }
if e.udpMux != nil { if e.udpMux != nil {
@ -501,7 +499,7 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
//nil sshServer means it has not yet been started //nil sshServer means it has not yet been started
var err error var err error
e.sshServer, err = e.sshServerFunc(e.config.SSHKey, e.sshServer, err = e.sshServerFunc(e.config.SSHKey,
fmt.Sprintf("%s:%d", e.wgInterface.Address.IP.String(), nbssh.DefaultSSHPort)) fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort))
if err != nil { if err != nil {
return err return err
} }
@ -534,8 +532,8 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
} }
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
if e.wgInterface.Address.String() != conf.Address { if e.wgInterface.Address().String() != conf.Address {
oldAddr := e.wgInterface.Address.String() oldAddr := e.wgInterface.Address().String()
log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address) log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address)
err := e.wgInterface.UpdateAddr(conf.Address) err := e.wgInterface.UpdateAddr(conf.Address)
if err != nil { if err != nil {

View File

@ -857,7 +857,7 @@ loop:
} }
// cleanup test // cleanup test
for n, peerEngine := range engines { for n, peerEngine := range engines {
t.Logf("stopping peer with interface %s from multipeer test, loopIndex %d", peerEngine.wgInterface.Name, n) t.Logf("stopping peer with interface %s from multipeer test, loopIndex %d", peerEngine.wgInterface.Name(), n)
errStop := peerEngine.mgmClient.Close() errStop := peerEngine.mgmClient.Close()
if errStop != nil { if errStop != nil {
log.Infoln("got error trying to close management clients from engine: ", errStop) log.Infoln("got error trying to close management clients from engine: ", errStop)
@ -905,7 +905,7 @@ func Test_ParseNATExternalIPMappings(t *testing.T) {
expectedOutput: []string{"1.1.1.1", "8.8.8.8/" + testingIP}, expectedOutput: []string{"1.1.1.1", "8.8.8.8/" + testingIP},
}, },
{ {
name: "Only Interface Name Should Return Nil", name: "Only Interface name Should Return Nil",
inputBlacklistInterface: defaultInterfaceBlacklist, inputBlacklistInterface: defaultInterfaceBlacklist,
inputMapList: []string{testingInterface}, inputMapList: []string{testingInterface},
expectedOutput: nil, expectedOutput: nil,

View File

@ -162,7 +162,7 @@ func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
if err != nil { if err != nil {
return err return err
} }
err = removeFromRouteTableIfNonSystem(c.network, c.wgInterface.GetAddress().IP.String()) err = removeFromRouteTableIfNonSystem(c.network, c.wgInterface.Address().IP.String())
if err != nil { if err != nil {
return fmt.Errorf("couldn't remove route %s from system, err: %v", return fmt.Errorf("couldn't remove route %s from system, err: %v",
c.network, err) c.network, err)
@ -201,10 +201,10 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
return err return err
} }
} else { } else {
err = addToRouteTableIfNoExists(c.network, c.wgInterface.GetAddress().IP.String()) err = addToRouteTableIfNoExists(c.network, c.wgInterface.Address().IP.String())
if err != nil { if err != nil {
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v", return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
c.network.String(), c.wgInterface.GetAddress().IP.String(), err) c.network.String(), c.wgInterface.Address().IP.String(), err)
} }
} }

View File

@ -40,7 +40,7 @@ func (m *DefaultManager) removeFromServerNetwork(route *route.Route) error {
default: default:
m.serverRouter.mux.Lock() m.serverRouter.mux.Lock()
defer m.serverRouter.mux.Unlock() defer m.serverRouter.mux.Unlock()
err := m.serverRouter.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address.String(), route)) err := m.serverRouter.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
if err != nil { if err != nil {
return err return err
} }
@ -57,7 +57,7 @@ func (m *DefaultManager) addToServerNetwork(route *route.Route) error {
default: default:
m.serverRouter.mux.Lock() m.serverRouter.mux.Lock()
defer m.serverRouter.mux.Unlock() defer m.serverRouter.mux.Unlock()
err := m.serverRouter.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address.String(), route)) err := m.serverRouter.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
if err != nil { if err != nil {
return err return err
} }

View File

@ -39,18 +39,18 @@ func TestAddRemoveRoutes(t *testing.T) {
err = wgInterface.Create() err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface") require.NoError(t, err, "should create testing wireguard interface")
err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.GetAddress().IP.String()) err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.Address().IP.String())
require.NoError(t, err, "should not return err") require.NoError(t, err, "should not return err")
prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix) prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix)
require.NoError(t, err, "should not return err") require.NoError(t, err, "should not return err")
if testCase.shouldRouteToWireguard { if testCase.shouldRouteToWireguard {
require.Equal(t, wgInterface.GetAddress().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") require.Equal(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP")
} else { } else {
require.NotEqual(t, wgInterface.GetAddress().IP.String(), prefixGateway.String(), "route should point to a different interface") require.NotEqual(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to a different interface")
} }
err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.GetAddress().IP.String()) err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.Address().IP.String())
require.NoError(t, err, "should not return err") require.NoError(t, err, "should not return err")
prefixGateway, err = getExistingRIBRouteGateway(testCase.prefix) prefixGateway, err = getExistingRIBRouteGateway(testCase.prefix)

29
iface/address.go Normal file
View File

@ -0,0 +1,29 @@
package iface
import (
"fmt"
"net"
)
// WGAddress Wireguard parsed address
type WGAddress struct {
IP net.IP
Network *net.IPNet
}
// parseWGAddress parse a string ("1.2.3.4/24") address to WG Address
func parseWGAddress(address string) (WGAddress, error) {
ip, network, err := net.ParseCIDR(address)
if err != nil {
return WGAddress{}, err
}
return WGAddress{
IP: ip,
Network: network,
}, nil
}
func (addr WGAddress) String() string {
maskSize, _ := addr.Network.Mask.Size()
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
}

View File

@ -1,258 +0,0 @@
package iface
import (
"fmt"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"net"
"time"
)
// GetName returns the interface name
func (w *WGIface) GetName() string {
return w.Name
}
// GetAddress returns the interface address
func (w *WGIface) GetAddress() WGAddress {
return w.Address
}
// configureDevice configures the wireguard device
func (w *WGIface) configureDevice(config wgtypes.Config) error {
wg, err := wgctrl.New()
if err != nil {
return err
}
defer wg.Close()
// validate if device with name exists
_, err = wg.Device(w.Name)
if err != nil {
return err
}
log.Debugf("got Wireguard device %s", w.Name)
return wg.ConfigureDevice(w.Name, config)
}
// Configure configures a Wireguard interface
// The interface must exist before calling this method (e.g. call interface.Create() before)
func (w *WGIface) Configure(privateKey string, port int) error {
w.mu.Lock()
defer w.mu.Unlock()
log.Debugf("configuring Wireguard interface %s", w.Name)
log.Debugf("adding Wireguard private key")
key, err := wgtypes.ParseKey(privateKey)
if err != nil {
return err
}
fwmark := 0
config := wgtypes.Config{
PrivateKey: &key,
ReplacePeers: true,
FirewallMark: &fwmark,
ListenPort: &port,
}
err = w.configureDevice(config)
if err != nil {
return fmt.Errorf("received error \"%v\" while configuring interface %s with port %d", err, w.Name, port)
}
return nil
}
// GetListenPort returns the listening port of the Wireguard endpoint
func (w *WGIface) GetListenPort() (*int, error) {
log.Debugf("getting Wireguard listen port of interface %s", w.Name)
//discover Wireguard current configuration
wg, err := wgctrl.New()
if err != nil {
return nil, err
}
defer wg.Close()
d, err := wg.Device(w.Name)
if err != nil {
return nil, err
}
log.Debugf("got Wireguard device listen port %s, %d", w.Name, d.ListenPort)
return &d.ListenPort, nil
}
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
// Endpoint is optional
func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
w.mu.Lock()
defer w.mu.Unlock()
log.Debugf("updating interface %s peer %s: endpoint %s ", w.Name, peerKey, endpoint)
//parse allowed ips
_, ipNet, err := net.ParseCIDR(allowedIps)
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: true,
AllowedIPs: []net.IPNet{*ipNet},
PersistentKeepaliveInterval: &keepAlive,
PresharedKey: preSharedKey,
Endpoint: endpoint,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
err = w.configureDevice(config)
if err != nil {
return fmt.Errorf("received error \"%v\" while updating peer on interface %s with settings: allowed ips %s, endpoint %s", err, w.Name, allowedIps, endpoint.String())
}
return nil
}
// AddAllowedIP adds a prefix to the allowed IPs list of peer
func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error {
w.mu.Lock()
defer w.mu.Unlock()
log.Debugf("adding allowed IP to interface %s and peer %s: allowed IP %s ", w.Name, peerKey, allowedIP)
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
UpdateOnly: true,
ReplaceAllowedIPs: false,
AllowedIPs: []net.IPNet{*ipNet},
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
err = w.configureDevice(config)
if err != nil {
return fmt.Errorf("received error \"%v\" while adding allowed Ip to peer on interface %s with settings: allowed ips %s", err, w.Name, allowedIP)
}
return nil
}
// RemoveAllowedIP removes a prefix from the allowed IPs list of peer
func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
w.mu.Lock()
defer w.mu.Unlock()
log.Debugf("removing allowed IP from interface %s and peer %s: allowed IP %s ", w.Name, peerKey, allowedIP)
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
existingPeer, err := getPeer(w.Name, peerKey)
if err != nil {
return err
}
newAllowedIPs := existingPeer.AllowedIPs
for i, existingAllowedIP := range existingPeer.AllowedIPs {
if existingAllowedIP.String() == ipNet.String() {
newAllowedIPs = append(existingPeer.AllowedIPs[:i], existingPeer.AllowedIPs[i+1:]...)
break
}
}
if err != nil {
return err
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
UpdateOnly: true,
ReplaceAllowedIPs: true,
AllowedIPs: newAllowedIPs,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
err = w.configureDevice(config)
if err != nil {
return fmt.Errorf("received error \"%v\" while removing allowed IP from peer on interface %s with settings: allowed ips %s", err, w.Name, allowedIP)
}
return nil
}
func getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) {
wg, err := wgctrl.New()
if err != nil {
return wgtypes.Peer{}, err
}
defer func() {
err = wg.Close()
if err != nil {
log.Errorf("got error while closing wgctl: %v", err)
}
}()
wgDevice, err := wg.Device(ifaceName)
if err != nil {
return wgtypes.Peer{}, err
}
for _, peer := range wgDevice.Peers {
if peer.PublicKey.String() == peerPubKey {
return peer, nil
}
}
return wgtypes.Peer{}, fmt.Errorf("peer not found")
}
// RemovePeer removes a Wireguard Peer from the interface iface
func (w *WGIface) RemovePeer(peerKey string) error {
w.mu.Lock()
defer w.mu.Unlock()
log.Debugf("Removing peer %s from interface %s ", peerKey, w.Name)
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
Remove: true,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
err = w.configureDevice(config)
if err != nil {
return fmt.Errorf("received error \"%v\" while removing peer %s from interface %s", err, peerKey, w.Name)
}
return nil
}

View File

@ -3,9 +3,12 @@ package iface
import ( import (
"fmt" "fmt"
"net" "net"
"os"
"runtime"
"sync" "sync"
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
const ( const (
@ -13,83 +16,276 @@ const (
DefaultWgPort = 51820 DefaultWgPort = 51820
) )
// WGIface represents a interface instance
type WGIface struct {
Name string
Port int
MTU int
Address WGAddress
Interface NetInterface
mu sync.Mutex
}
// WGAddress Wireguard parsed address
type WGAddress struct {
IP net.IP
Network *net.IPNet
}
func (addr *WGAddress) String() string {
maskSize, _ := addr.Network.Mask.Size()
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
}
// NetInterface represents a generic network tunnel interface // NetInterface represents a generic network tunnel interface
type NetInterface interface { type NetInterface interface {
Close() error Close() error
} }
// WGIface represents a interface instance
type WGIface struct {
name string
address WGAddress
mtu int
netInterface NetInterface
mu sync.Mutex
}
// NewWGIFace Creates a new Wireguard interface instance // NewWGIFace Creates a new Wireguard interface instance
func NewWGIFace(iface string, address string, mtu int) (*WGIface, error) { func NewWGIFace(iface string, address string, mtu int) (*WGIface, error) {
wgIface := &WGIface{ wgIface := &WGIface{
Name: iface, name: iface,
MTU: mtu, mtu: mtu,
mu: sync.Mutex{}, mu: sync.Mutex{},
} }
wgAddress, err := parseAddress(address) wgAddress, err := parseWGAddress(address)
if err != nil { if err != nil {
return wgIface, err return wgIface, err
} }
wgIface.Address = wgAddress wgIface.address = wgAddress
return wgIface, nil return wgIface, nil
} }
// parseAddress parse a string ("1.2.3.4/24") address to WG Address // Name returns the interface name
func parseAddress(address string) (WGAddress, error) { func (w *WGIface) Name() string {
ip, network, err := net.ParseCIDR(address) return w.name
if err != nil {
return WGAddress{}, err
}
return WGAddress{
IP: ip,
Network: network,
}, nil
} }
// Close closes the tunnel interface // Address returns the interface address
func (w *WGIface) Close() error { func (w *WGIface) Address() WGAddress {
return w.address
}
// Configure configures a Wireguard interface
// The interface must exist before calling this method (e.g. call interface.Create() before)
func (w *WGIface) Configure(privateKey string, port int) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
if w.Interface == nil {
return nil log.Debugf("configuring Wireguard interface %s", w.name)
log.Debugf("adding Wireguard private key")
key, err := wgtypes.ParseKey(privateKey)
if err != nil {
return err
} }
err := w.Interface.Close() fwmark := 0
config := wgtypes.Config{
PrivateKey: &key,
ReplacePeers: true,
FirewallMark: &fwmark,
ListenPort: &port,
}
err = w.configureDevice(config)
if err != nil {
return fmt.Errorf(`received error "%w" while configuring interface %s with port %d`, err, w.name, port)
}
return nil
}
// UpdateAddr updates address of the interface
func (w *WGIface) UpdateAddr(newAddr string) error {
w.mu.Lock()
defer w.mu.Unlock()
addr, err := parseWGAddress(newAddr)
if err != nil { if err != nil {
return err return err
} }
if runtime.GOOS != "windows" { w.address = addr
sockPath := "/var/run/wireguard/" + w.Name + ".sock" return w.assignAddr()
if _, statErr := os.Stat(sockPath); statErr == nil { }
statErr = os.Remove(sockPath)
if statErr != nil { // UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
return statErr // Endpoint is optional
} func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
w.mu.Lock()
defer w.mu.Unlock()
log.Debugf("updating interface %s peer %s: endpoint %s ", w.name, peerKey, endpoint)
//parse allowed ips
_, ipNet, err := net.ParseCIDR(allowedIps)
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: true,
AllowedIPs: []net.IPNet{*ipNet},
PersistentKeepaliveInterval: &keepAlive,
PresharedKey: preSharedKey,
Endpoint: endpoint,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
err = w.configureDevice(config)
if err != nil {
return fmt.Errorf(`received error "%w" while updating peer on interface %s with settings: allowed ips %s, endpoint %s`, err, w.name, allowedIps, endpoint.String())
}
return nil
}
// AddAllowedIP adds a prefix to the allowed IPs list of peer
func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error {
w.mu.Lock()
defer w.mu.Unlock()
log.Debugf("adding allowed IP to interface %s and peer %s: allowed IP %s ", w.name, peerKey, allowedIP)
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
UpdateOnly: true,
ReplaceAllowedIPs: false,
AllowedIPs: []net.IPNet{*ipNet},
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
err = w.configureDevice(config)
if err != nil {
return fmt.Errorf(`received error "%w" while adding allowed Ip to peer on interface %s with settings: allowed ips %s`, err, w.name, allowedIP)
}
return nil
}
// RemoveAllowedIP removes a prefix from the allowed IPs list of peer
func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
w.mu.Lock()
defer w.mu.Unlock()
log.Debugf("removing allowed IP from interface %s and peer %s: allowed IP %s ", w.name, peerKey, allowedIP)
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
existingPeer, err := getPeer(w.name, peerKey)
if err != nil {
return err
}
newAllowedIPs := existingPeer.AllowedIPs
for i, existingAllowedIP := range existingPeer.AllowedIPs {
if existingAllowedIP.String() == ipNet.String() {
newAllowedIPs = append(existingPeer.AllowedIPs[:i], existingPeer.AllowedIPs[i+1:]...)
break
} }
} }
if err != nil {
return err
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
UpdateOnly: true,
ReplaceAllowedIPs: true,
AllowedIPs: newAllowedIPs,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
err = w.configureDevice(config)
if err != nil {
return fmt.Errorf(`received error "%w" while removing allowed IP from peer on interface %s with settings: allowed ips %s`, err, w.name, allowedIP)
}
return nil return nil
} }
// RemovePeer removes a Wireguard Peer from the interface iface
func (w *WGIface) RemovePeer(peerKey string) error {
w.mu.Lock()
defer w.mu.Unlock()
log.Debugf("Removing peer %s from interface %s ", peerKey, w.name)
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
Remove: true,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
err = w.configureDevice(config)
if err != nil {
return fmt.Errorf(`received error "%w" while removing peer %s from interface %s`, err, peerKey, w.name)
}
return nil
}
func getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) {
wg, err := wgctrl.New()
if err != nil {
return wgtypes.Peer{}, err
}
defer func() {
err = wg.Close()
if err != nil {
log.Errorf("got error while closing wgctl: %v", err)
}
}()
wgDevice, err := wg.Device(ifaceName)
if err != nil {
return wgtypes.Peer{}, err
}
for _, peer := range wgDevice.Peers {
if peer.PublicKey.String() == peerPubKey {
return peer, nil
}
}
return wgtypes.Peer{}, fmt.Errorf("peer not found")
}
// configureDevice configures the wireguard device
func (w *WGIface) configureDevice(config wgtypes.Config) error {
wg, err := wgctrl.New()
if err != nil {
return err
}
defer wg.Close()
// validate if device with name exists
_, err = wg.Device(w.name)
if err != nil {
return err
}
log.Debugf("got Wireguard device %s", w.name)
return wg.ConfigureDevice(w.name, config)
}

View File

@ -1,8 +1,9 @@
package iface package iface
import ( import (
log "github.com/sirupsen/logrus"
"os/exec" "os/exec"
log "github.com/sirupsen/logrus"
) )
// Create Creates a new Wireguard interface, sets a given IP and brings it up. // Create Creates a new Wireguard interface, sets a given IP and brings it up.
@ -15,26 +16,17 @@ func (w *WGIface) Create() error {
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided // assignAddr Adds IP address to the tunnel interface and network route based on the range provided
func (w *WGIface) assignAddr() error { func (w *WGIface) assignAddr() error {
//mask,_ := w.Address.Network.Mask.Size() cmd := exec.Command("ifconfig", w.name, "inet", w.address.IP.String(), w.address.IP.String())
//
//address := fmt.Sprintf("%s/%d",w.Address.IP.String() , mask)
cmd := exec.Command("ifconfig", w.Name, "inet", w.Address.IP.String(), w.Address.IP.String())
if out, err := cmd.CombinedOutput(); err != nil { if out, err := cmd.CombinedOutput(); err != nil {
log.Infof("adding addreess command \"%v\" failed with output %s and error: ", cmd.String(), out) log.Infof(`adding addreess command "%v" failed with output %s and error: `, cmd.String(), out)
return err return err
} }
routeCmd := exec.Command("route", "add", "-net", w.Address.Network.String(), "-interface", w.Name) routeCmd := exec.Command("route", "add", "-net", w.address.Network.String(), "-interface", w.name)
if out, err := routeCmd.CombinedOutput(); err != nil { if out, err := routeCmd.CombinedOutput(); err != nil {
log.Printf("adding route command \"%v\" failed with output %s and error: ", routeCmd.String(), out) log.Printf(`adding route command "%v" failed with output %s and error: `, routeCmd.String(), out)
return err return err
} }
return nil return nil
} }
// WireguardModuleIsLoaded check if we can load wireguard mod (linux only)
func WireguardModuleIsLoaded() bool {
return false
}

View File

@ -2,15 +2,12 @@ package iface
import ( import (
"fmt" "fmt"
"os"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
"os"
) )
type NativeLink struct {
Link *netlink.Link
}
// Create creates a new Wireguard interface, sets a given IP and brings it up. // Create creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one. // Will reuse an existing one.
func (w *WGIface) Create() error { func (w *WGIface) Create() error {
@ -33,10 +30,10 @@ func (w *WGIface) Create() error {
// Works for Linux and offers much better network performance // Works for Linux and offers much better network performance
func (w *WGIface) createWithKernel() error { func (w *WGIface) createWithKernel() error {
link := newWGLink(w.Name) link := newWGLink(w.name)
// check if interface exists // check if interface exists
l, err := netlink.LinkByName(w.Name) l, err := netlink.LinkByName(w.name)
if err != nil { if err != nil {
switch err.(type) { switch err.(type) {
case netlink.LinkNotFoundError: case netlink.LinkNotFoundError:
@ -54,15 +51,15 @@ func (w *WGIface) createWithKernel() error {
} }
} }
log.Debugf("adding device: %s", w.Name) log.Debugf("adding device: %s", w.name)
err = netlink.LinkAdd(link) err = netlink.LinkAdd(link)
if os.IsExist(err) { if os.IsExist(err) {
log.Infof("interface %s already exists. Will reuse.", w.Name) log.Infof("interface %s already exists. Will reuse.", w.name)
} else if err != nil { } else if err != nil {
return err return err
} }
w.Interface = link w.netInterface = link
err = w.assignAddr() err = w.assignAddr()
if err != nil { if err != nil {
@ -70,17 +67,17 @@ func (w *WGIface) createWithKernel() error {
} }
// todo do a discovery // todo do a discovery
log.Debugf("setting MTU: %d interface: %s", w.MTU, w.Name) log.Debugf("setting MTU: %d interface: %s", w.mtu, w.name)
err = netlink.LinkSetMTU(link, w.MTU) err = netlink.LinkSetMTU(link, w.mtu)
if err != nil { if err != nil {
log.Errorf("error setting MTU on interface: %s", w.Name) log.Errorf("error setting MTU on interface: %s", w.name)
return err return err
} }
log.Debugf("bringing up interface: %s", w.Name) log.Debugf("bringing up interface: %s", w.name)
err = netlink.LinkSetUp(link) err = netlink.LinkSetUp(link)
if err != nil { if err != nil {
log.Errorf("error bringing up interface: %s", w.Name) log.Errorf("error bringing up interface: %s", w.name)
return err return err
} }
@ -89,7 +86,7 @@ func (w *WGIface) createWithKernel() error {
// assignAddr Adds IP address to the tunnel interface // assignAddr Adds IP address to the tunnel interface
func (w *WGIface) assignAddr() error { func (w *WGIface) assignAddr() error {
link := newWGLink(w.Name) link := newWGLink(w.name)
//delete existing addresses //delete existing addresses
list, err := netlink.AddrList(link, 0) list, err := netlink.AddrList(link, 0)
@ -105,11 +102,11 @@ func (w *WGIface) assignAddr() error {
} }
} }
log.Debugf("adding address %s to interface: %s", w.Address.String(), w.Name) log.Debugf("adding address %s to interface: %s", w.address.String(), w.name)
addr, _ := netlink.ParseAddr(w.Address.String()) addr, _ := netlink.ParseAddr(w.address.String())
err = netlink.AddrAdd(link, addr) err = netlink.AddrAdd(link, addr)
if os.IsExist(err) { if os.IsExist(err) {
log.Infof("interface %s already has the address: %s", w.Name, w.Address.String()) log.Infof("interface %s already has the address: %s", w.name, w.address.String())
} else if err != nil { } else if err != nil {
return err return err
} }

View File

@ -46,11 +46,11 @@ func TestWGIface_UpdateAddr(t *testing.T) {
t.Error(err) t.Error(err)
} }
}() }()
port, err := iface.GetListenPort() port, err := getListenPortByName(ifaceName)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = iface.Configure(key, *port) err = iface.Configure(key, port)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -164,11 +164,11 @@ func Test_ConfigureInterface(t *testing.T) {
} }
}() }()
port, err := iface.GetListenPort() port, err := getListenPortByName(ifaceName)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = iface.Configure(key, *port) err = iface.Configure(key, port)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -210,11 +210,11 @@ func Test_UpdatePeer(t *testing.T) {
t.Error(err) t.Error(err)
} }
}() }()
port, err := iface.GetListenPort() port, err := getListenPortByName(ifaceName)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = iface.Configure(key, *port) err = iface.Configure(key, port)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -269,11 +269,11 @@ func Test_RemovePeer(t *testing.T) {
t.Error(err) t.Error(err)
} }
}() }()
port, err := iface.GetListenPort() port, err := getListenPortByName(ifaceName)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = iface.Configure(key, *port) err = iface.Configure(key, port)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -298,12 +298,10 @@ func Test_ConnectPeers(t *testing.T) {
peer1ifaceName := fmt.Sprintf("utun%d", WgIntNumber+400) peer1ifaceName := fmt.Sprintf("utun%d", WgIntNumber+400)
peer1wgIP := "10.99.99.17/30" peer1wgIP := "10.99.99.17/30"
peer1Key, _ := wgtypes.GeneratePrivateKey() peer1Key, _ := wgtypes.GeneratePrivateKey()
//peer1Port := WgPort + 4
peer2ifaceName := fmt.Sprintf("utun%d", 500) peer2ifaceName := "utun500"
peer2wgIP := "10.99.99.18/30" peer2wgIP := "10.99.99.18/30"
peer2Key, _ := wgtypes.GeneratePrivateKey() peer2Key, _ := wgtypes.GeneratePrivateKey()
//peer2Port := WgPort + 5
keepAlive := 1 * time.Second keepAlive := 1 * time.Second
@ -315,11 +313,11 @@ func Test_ConnectPeers(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
peer1Port, err := iface1.GetListenPort() peer1Port, err := getListenPortByName(peer1ifaceName)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
peer1endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", *peer1Port)) peer1endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", peer1Port))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -332,11 +330,11 @@ func Test_ConnectPeers(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
peer2Port, err := iface2.GetListenPort() peer2Port, err := getListenPortByName(peer2ifaceName)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
peer2endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", *peer2Port)) peer2endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", peer2Port))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -351,11 +349,11 @@ func Test_ConnectPeers(t *testing.T) {
} }
}() }()
err = iface1.Configure(peer1Key.String(), *peer1Port) err = iface1.Configure(peer1Key.String(), peer1Port)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = iface2.Configure(peer2Key.String(), *peer2Port) err = iface2.Configure(peer2Key.String(), peer2Port)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -388,3 +386,18 @@ func Test_ConnectPeers(t *testing.T) {
} }
} }
func getListenPortByName(name string) (int, error) {
wg, err := wgctrl.New()
if err != nil {
return 0, err
}
defer wg.Close()
d, err := wg.Device(name)
if err != nil {
return 0, err
}
return d.ListenPort, nil
}

View File

@ -4,23 +4,53 @@
package iface package iface
import ( import (
"net"
"os"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc" "golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
"net"
) )
// createWithUserspace Creates a new Wireguard interface, using wireguard-go userspace implementation // GetInterfaceGUIDString returns an interface GUID. This is useful on Windows only
func (w *WGIface) createWithUserspace() error { func (w *WGIface) GetInterfaceGUIDString() (string, error) {
return "", nil
}
tunIface, err := tun.CreateTUN(w.Name, w.MTU) // Close closes the tunnel interface
func (w *WGIface) Close() error {
w.mu.Lock()
defer w.mu.Unlock()
if w.netInterface == nil {
return nil
}
err := w.netInterface.Close()
if err != nil { if err != nil {
return err return err
} }
w.Interface = tunIface sockPath := "/var/run/wireguard/" + w.name + ".sock"
if _, statErr := os.Stat(sockPath); statErr == nil {
statErr = os.Remove(sockPath)
if statErr != nil {
return statErr
}
}
return nil
}
// createWithUserspace Creates a new Wireguard interface, using wireguard-go userspace implementation
func (w *WGIface) createWithUserspace() error {
tunIface, err := tun.CreateTUN(w.name, w.mtu)
if err != nil {
return err
}
w.netInterface = tunIface
// We need to create a wireguard-go device and listen to configuration requests // We need to create a wireguard-go device and listen to configuration requests
tunDevice := device.NewDevice(tunIface, conn.NewDefaultBind(), device.NewLogger(device.LogLevelSilent, "[wiretrustee] ")) tunDevice := device.NewDevice(tunIface, conn.NewDefaultBind(), device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
@ -28,7 +58,7 @@ func (w *WGIface) createWithUserspace() error {
if err != nil { if err != nil {
return err return err
} }
uapi, err := getUAPI(w.Name) uapi, err := getUAPI(w.name)
if err != nil { if err != nil {
return err return err
} }
@ -61,22 +91,3 @@ func getUAPI(iface string) (net.Listener, error) {
} }
return ipc.UAPIListen(iface, tunSock) return ipc.UAPIListen(iface, tunSock)
} }
// UpdateAddr updates address of the interface
func (w *WGIface) UpdateAddr(newAddr string) error {
w.mu.Lock()
defer w.mu.Unlock()
addr, err := parseAddress(newAddr)
if err != nil {
return err
}
w.Address = addr
return w.assignAddr()
}
// GetInterfaceGUIDString returns an interface GUID. This is useful on Windows only
func (w *WGIface) GetInterfaceGUIDString() (string, error) {
return "", nil
}

View File

@ -2,11 +2,11 @@ package iface
import ( import (
"fmt" "fmt"
"net"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/windows/driver" "golang.zx2c4.com/wireguard/windows/driver"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"net"
) )
// Create Creates a new Wireguard interface, sets a given IP and brings it up. // Create Creates a new Wireguard interface, sets a given IP and brings it up.
@ -15,55 +15,27 @@ func (w *WGIface) Create() error {
defer w.mu.Unlock() defer w.mu.Unlock()
WintunStaticRequestedGUID, _ := windows.GenerateGUID() WintunStaticRequestedGUID, _ := windows.GenerateGUID()
adapter, err := driver.CreateAdapter(w.Name, "WireGuard", &WintunStaticRequestedGUID) adapter, err := driver.CreateAdapter(w.name, "WireGuard", &WintunStaticRequestedGUID)
if err != nil { if err != nil {
err = fmt.Errorf("error creating adapter: %w", err) err = fmt.Errorf("error creating adapter: %w", err)
return err return err
} }
w.Interface = adapter w.netInterface = adapter
luid := adapter.LUID()
err = adapter.SetAdapterState(driver.AdapterStateUp) err = adapter.SetAdapterState(driver.AdapterStateUp)
if err != nil { if err != nil {
return err return err
} }
state, _ := luid.GUID() state, _ := adapter.LUID().GUID()
log.Debugln("device guid: ", state.String()) log.Debugln("device guid: ", state.String())
return w.assignAddr(luid) return w.assignAddr()
}
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
func (w *WGIface) assignAddr(luid winipcfg.LUID) error {
log.Debugf("adding address %s to interface: %s", w.Address.IP, w.Name)
err := luid.SetIPAddresses([]net.IPNet{{w.Address.IP, w.Address.Network.Mask}})
if err != nil {
return err
}
return nil
}
// UpdateAddr updates address of the interface
func (w *WGIface) UpdateAddr(newAddr string) error {
w.mu.Lock()
defer w.mu.Unlock()
luid := w.Interface.(*driver.Adapter).LUID()
addr, err := parseAddress(newAddr)
if err != nil {
return err
}
w.Address = addr
return w.assignAddr(luid)
} }
// GetInterfaceGUIDString returns an interface GUID string // GetInterfaceGUIDString returns an interface GUID string
func (w *WGIface) GetInterfaceGUIDString() (string, error) { func (w *WGIface) GetInterfaceGUIDString() (string, error) {
if w.Interface == nil { if w.netInterface == nil {
return "", fmt.Errorf("interface has not been initialized yet") return "", fmt.Errorf("interface has not been initialized yet")
} }
windowsDevice := w.Interface.(*driver.Adapter) windowsDevice := w.netInterface.(*driver.Adapter)
luid := windowsDevice.LUID() luid := windowsDevice.LUID()
guid, err := luid.GUID() guid, err := luid.GUID()
if err != nil { if err != nil {
@ -72,7 +44,26 @@ func (w *WGIface) GetInterfaceGUIDString() (string, error) {
return guid.String(), nil return guid.String(), nil
} }
// WireguardModuleIsLoaded check if we can load wireguard mod (linux only) // Close closes the tunnel interface
func WireguardModuleIsLoaded() bool { func (w *WGIface) Close() error {
return false w.mu.Lock()
defer w.mu.Unlock()
if w.netInterface == nil {
return nil
}
return w.netInterface.Close()
}
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
func (w *WGIface) assignAddr() error {
luid := w.netInterface.(*driver.Adapter).LUID()
log.Debugf("adding address %s to interface: %s", w.address.IP, w.name)
err := luid.SetIPAddresses([]net.IPNet{{w.address.IP, w.address.Network.Mask}})
if err != nil {
return err
}
return nil
} }

9
iface/module.go Normal file
View File

@ -0,0 +1,9 @@
//go:build !linux
// +build !linux
package iface
// WireguardModuleIsLoaded check if we can load wireguard mod (linux only)
func WireguardModuleIsLoaded() bool {
return false
}