diff --git a/.github/workflows/golang-test.yml b/.github/workflows/golang-test.yml index 0e1ee94f8..36f929de7 100644 --- a/.github/workflows/golang-test.yml +++ b/.github/workflows/golang-test.yml @@ -18,7 +18,7 @@ jobs: - name: Checkout code uses: actions/checkout@v2 - name: Test - run: GOBIN=$(which go) && sudo --preserve-env=GOROOT $GOBIN test ./... + run: GOBIN=$(which go) && sudo --preserve-env=GOROOT $GOBIN test -p 1 ./... test_build: strategy: diff --git a/.gitignore b/.gitignore index 29b636a48..c181e07b2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ .idea -*.iml \ No newline at end of file +*.iml +dist/ \ No newline at end of file diff --git a/cmd/service_test.go b/cmd/service_test.go index 622c4566e..af9c2a3e1 100644 --- a/cmd/service_test.go +++ b/cmd/service_test.go @@ -4,6 +4,7 @@ import ( "bytes" "github.com/kardianos/service" "io/ioutil" + "os" "testing" ) @@ -51,6 +52,13 @@ func Test_ServiceStartCMD(t *testing.T) { } func Test_ServiceRunCMD(t *testing.T) { + configFilePath := "/tmp/config.json" + if _, err := os.Stat(configFilePath); err == nil { + e := os.Remove(configFilePath) + if e != nil { + t.Fatal(err) + } + } rootCmd.SetArgs([]string{ "init", "--stunURLs", @@ -64,7 +72,7 @@ func Test_ServiceRunCMD(t *testing.T) { "--wgLocalAddr", "10.100.100.1/24", "--config", - "/tmp/config.json", + configFilePath, }) err := rootCmd.Execute() if err != nil { diff --git a/connection/connection.go b/connection/connection.go index 7a95c766d..76bbb56ab 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -16,6 +16,14 @@ var ( DefaultWgKeepAlive = 20 * time.Second ) +type Status string + +const ( + StatusConnected Status = "Connected" + StatusConnecting Status = "Connecting" + StatusDisconnected Status = "Disconnected" +) + // ConnConfig Connection configuration struct type ConnConfig struct { // Local Wireguard listening address e.g. 127.0.0.1:51820 @@ -66,6 +74,8 @@ type Connection struct { closeCond *Cond remoteAuthCond sync.Once + + Status Status } // NewConnection Creates a new connection and sets handling functions for signal protocol @@ -85,6 +95,7 @@ func NewConnection(config ConnConfig, connected: NewCond(), agent: nil, wgProxy: NewWgProxy(config.WgIface, config.RemoteWgKey.String(), config.WgAllowedIPs, config.WgListenAddr), + Status: StatusDisconnected, } } @@ -126,6 +137,7 @@ func (conn *Connection) Open(timeout time.Duration) error { return err } + conn.Status = StatusConnecting log.Infof("trying to connect to peer %s", conn.Config.RemoteWgKey.String()) // wait until credentials have been sent from the remote peer (will arrive via a signal server) @@ -164,17 +176,23 @@ func (conn *Connection) Open(timeout time.Duration) error { } } + conn.Status = StatusConnected log.Infof("opened connection to peer %s", conn.Config.RemoteWgKey.String()) + case <-conn.closeCond.C: + conn.Status = StatusDisconnected + return fmt.Errorf("connection to peer %s has been closed", conn.Config.RemoteWgKey.String()) case <-time.After(timeout): err := conn.Close() if err != nil { log.Warnf("error while closing connection to peer %s -> %s", conn.Config.RemoteWgKey.String(), err.Error()) } + conn.Status = StatusDisconnected return fmt.Errorf("timeout of %vs exceeded while waiting for the remote peer %s", timeout.Seconds(), conn.Config.RemoteWgKey.String()) } // wait until connection has been closed <-conn.closeCond.C + conn.Status = StatusDisconnected return fmt.Errorf("connection to peer %s has been closed", conn.Config.RemoteWgKey.String()) } diff --git a/connection/engine.go b/connection/engine.go index 80928e0a3..57592db5c 100644 --- a/connection/engine.go +++ b/connection/engine.go @@ -9,9 +9,14 @@ import ( "github.com/wiretrustee/wiretrustee/signal" sProto "github.com/wiretrustee/wiretrustee/signal/proto" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "sync" "time" ) +// PeerConnectionTimeout is a timeout of an initial connection attempt to a remote peer. +// E.g. this peer will wait PeerConnectionTimeout for the remote peer to respond, if not successful then it will retry the connection attempt. +const PeerConnectionTimeout = 60 * time.Second + // Engine is an instance of the Connection Engine type Engine struct { // a list of STUN and TURN servers @@ -26,6 +31,8 @@ type Engine struct { wgIP string // Network Interfaces to ignore iFaceBlackList map[string]struct{} + // PeerMux is used to sync peer operations (e.g. open connection, peer removal) + PeerMux *sync.Mutex } // Peer is an instance of the Connection Peer @@ -44,6 +51,7 @@ func NewEngine(signal *signal.Client, stunsTurns []*ice.URL, wgIface string, wgA wgIP: wgAddr, conns: map[string]*Connection{}, iFaceBlackList: iFaceBlackList, + PeerMux: &sync.Mutex{}, } } @@ -71,42 +79,76 @@ func (e *Engine) Start(myKey wgtypes.Key, peers []Peer) error { e.receiveSignal() - // initialize peer agents for _, peer := range peers { - peer := peer - go func() { - var backOff = &backoff.ExponentialBackOff{ - InitialInterval: backoff.DefaultInitialInterval, - RandomizationFactor: backoff.DefaultRandomizationFactor, - Multiplier: backoff.DefaultMultiplier, - MaxInterval: 5 * time.Second, - MaxElapsedTime: time.Duration(0), //never stop - Stop: backoff.Stop, - Clock: backoff.SystemClock, - } - operation := func() error { - _, err := e.openPeerConnection(*wgPort, myKey, peer) - if err != nil { - log.Warnln("retrying connection because of error: ", err.Error()) - e.conns[peer.WgPubKey] = nil - return err - } - backOff.Reset() - return nil - } - - err = backoff.Retry(operation, backOff) - if err != nil { - // should actually never happen - panic(err) - } - }() + go e.InitializePeer(*wgPort, myKey, peer) } return nil } +// InitializePeer peer agent attempt to open connection +func (e *Engine) InitializePeer(wgPort int, myKey wgtypes.Key, peer Peer) { + var backOff = &backoff.ExponentialBackOff{ + InitialInterval: backoff.DefaultInitialInterval, + RandomizationFactor: backoff.DefaultRandomizationFactor, + Multiplier: backoff.DefaultMultiplier, + MaxInterval: 5 * time.Second, + MaxElapsedTime: time.Duration(0), //never stop + Stop: backoff.Stop, + Clock: backoff.SystemClock, + } + operation := func() error { + _, err := e.openPeerConnection(wgPort, myKey, peer) + e.PeerMux.Lock() + defer e.PeerMux.Unlock() + if _, ok := e.conns[peer.WgPubKey]; !ok { + log.Infof("removing connection attempt with Peer: %v, not retrying", peer.WgPubKey) + return nil + } + + if err != nil { + log.Warnln(err) + log.Warnln("retrying connection because of error: ", err.Error()) + return err + } + return nil + } + + err := backoff.Retry(operation, backOff) + if err != nil { + // should actually never happen + panic(err) + } +} + +// RemovePeerConnection closes existing peer connection and removes peer +func (e *Engine) RemovePeerConnection(peer Peer) error { + e.PeerMux.Lock() + defer e.PeerMux.Unlock() + conn, exists := e.conns[peer.WgPubKey] + if exists && conn != nil { + delete(e.conns, peer.WgPubKey) + return conn.Close() + } + return nil +} + +// GetPeerConnectionStatus returns a connection Status or nil if peer connection wasn't found +func (e *Engine) GetPeerConnectionStatus(peerKey string) *Status { + e.PeerMux.Lock() + defer e.PeerMux.Unlock() + + conn, exists := e.conns[peerKey] + if exists && conn != nil { + return &conn.Status + } + + return nil +} + +// opens a new peer connection func (e *Engine) openPeerConnection(wgPort int, myKey wgtypes.Key, peer Peer) (*Connection, error) { + e.PeerMux.Lock() remoteKey, _ := wgtypes.ParseKey(peer.WgPubKey) connConfig := &ConnConfig{ @@ -130,11 +172,12 @@ func (e *Engine) openPeerConnection(wgPort int, myKey wgtypes.Key, peer Peer) (* signalCandidate := func(candidate ice.Candidate) error { return signalCandidate(candidate, myKey, remoteKey, e.signal) } - conn := NewConnection(*connConfig, signalCandidate, signalOffer, signalAnswer) e.conns[remoteKey.String()] = conn + e.PeerMux.Unlock() + // blocks until the connection is open (or timeout) - err := conn.Open(60 * time.Second) + err := conn.Open(PeerConnectionTimeout) if err != nil { return nil, err } diff --git a/connection/engine_test.go b/connection/engine_test.go new file mode 100644 index 000000000..07c39da83 --- /dev/null +++ b/connection/engine_test.go @@ -0,0 +1,163 @@ +package connection + +import ( + "context" + "fmt" + ice "github.com/pion/ice/v2" + log "github.com/sirupsen/logrus" + "github.com/wiretrustee/wiretrustee/iface" + sig "github.com/wiretrustee/wiretrustee/signal" + "golang.zx2c4.com/wireguard/wgctrl" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "testing" + "time" +) + +var engine *Engine +var testKey wgtypes.Key +var testPeer Peer + +const ifaceName = "utun9991" + +func Test_Start(t *testing.T) { + level, _ := log.ParseLevel("Debug") + log.SetLevel(level) + + var err error + testKey, err = wgtypes.GenerateKey() + if err != nil { + t.Fatal(err) + } + + ctx := context.Background() + + iceUrl, err := ice.ParseURL("stun:stun.wiretrustee.com:3468") + if err != nil { + t.Fatal(err) + } + var stunURLs = []*ice.URL{iceUrl} + + iFaceBlackList := make(map[string]struct{}) + + signalClient, err := sig.NewClient(ctx, "signal.wiretrustee.com:10000", testKey) + if err != nil { + t.Fatal(err) + } + + engine = NewEngine(signalClient, stunURLs, ifaceName, "10.99.91.1/24", iFaceBlackList) + + var emptyPeer []Peer + err = engine.Start(testKey, emptyPeer) + if err != nil { + t.Fatal(err) + } + wg, err := wgctrl.New() + if err != nil { + t.Fatal(err) + } + defer wg.Close() + + _, err = wg.Device(ifaceName) + if err != nil { + t.Fatal(err) + } +} + +func TestEngine_InitializePeerWithoutRemote(t *testing.T) { + tmpKey, err := wgtypes.GenerateKey() + if err != nil { + t.Fatal(err) + } + testPeer = Peer{ + tmpKey.PublicKey().String(), + "10.99.91.2/32", + } + go engine.InitializePeer(iface.WgPort, testKey, testPeer) + // Let the connections initialize + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + for { + status := engine.GetPeerConnectionStatus(testPeer.WgPubKey) + err = ctx.Err() + if (status != nil && *status == StatusConnecting) || err != nil { + if err != nil { + t.Fatal(err) + } + //success + break + } + } +} + +func TestEngine_Initialize2PeersWithoutRemote(t *testing.T) { + tmpKey1, err := wgtypes.GenerateKey() + if err != nil { + t.Fatal(err) + } + tmpKey2, err := wgtypes.GenerateKey() + if err != nil { + t.Fatal(err) + } + testPeer1 := Peer{ + tmpKey1.PublicKey().String(), + "10.99.91.2/32", + } + testPeer2 := Peer{ + tmpKey2.PublicKey().String(), + "10.99.91.3/32", + } + go engine.InitializePeer(iface.WgPort, testKey, testPeer1) + go engine.InitializePeer(iface.WgPort, testKey, testPeer2) + // Let the connections initialize + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + for { + status1 := engine.GetPeerConnectionStatus(testPeer1.WgPubKey) + status2 := engine.GetPeerConnectionStatus(testPeer2.WgPubKey) + err = ctx.Err() + if (status1 != nil && status2 != nil) || err != nil { + if err != nil { + t.Fatal(err) + } + if *status1 == StatusConnecting && *status2 == StatusConnecting { + //success + break + } + } + } +} + +func TestEngine_RemovePeerConnectionWithoutRemote(t *testing.T) { + + // Let the connections initialize + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + for { + status := engine.GetPeerConnectionStatus(testPeer.WgPubKey) + err := ctx.Err() + if (status != nil && *status == StatusConnecting) || err != nil { + if err != nil { + t.Fatal(err) + } + break + } + } + + // Let the connections close + err := engine.RemovePeerConnection(testPeer) + if err != nil { + t.Fatal(err) + } + + status := engine.GetPeerConnectionStatus(testPeer.WgPubKey) + if status != nil { + t.Fatal(fmt.Errorf("wrong status %v", status)) + } +} + +func Test_CloseInterface(t *testing.T) { + err := iface.Close() + if err != nil { + t.Fatal(err) + } +} diff --git a/connection/wgproxy.go b/connection/wgproxy.go index 489b421cb..153aa0824 100644 --- a/connection/wgproxy.go +++ b/connection/wgproxy.go @@ -38,10 +38,15 @@ func (p *WgProxy) Close() error { return err } } + err := iface.RemovePeer(p.iface, p.remoteKey) + if err != nil { + return err + } return nil } +// StartLocal configure the interface with a peer using a direct IP:Port endpoint to the remote host func (p *WgProxy) StartLocal(host string) error { err := iface.UpdatePeer(p.iface, p.remoteKey, p.allowedIps, DefaultWgKeepAlive, host) if err != nil { diff --git a/iface/iface.go b/iface/iface.go index 774c6d0b6..4f54cf657 100644 --- a/iface/iface.go +++ b/iface/iface.go @@ -41,8 +41,8 @@ func CreateWithUserspace(iface string, address string) error { for { uapiConn, err := uapi.Accept() if err != nil { - log.Debugln(err) - return + log.Debugln("uapi Accept failed with error: ", err) + continue } go tunDevice.IpcHandle(uapiConn) } @@ -57,13 +57,21 @@ func CreateWithUserspace(iface string, address string) error { return nil } -// ConfigureWithKeyGen Extends the functionality of Configure(iface string, privateKey string) by generating a new Wireguard private key -func ConfigureWithKeyGen(iface string) (*wgtypes.Key, error) { - key, err := wgtypes.GeneratePrivateKey() +// configure peer for the wireguard device +func configureDevice(iface string, config wgtypes.Config) error { + wg, err := wgctrl.New() if err != nil { - return nil, err + return err } - return &key, Configure(iface, key.String()) + defer wg.Close() + + _, err = wg.Device(iface) + if err != nil { + return err + } + log.Debugf("got Wireguard device %s", iface) + + return wg.ConfigureDevice(iface, config) } // Configure configures a Wireguard interface @@ -71,11 +79,6 @@ func ConfigureWithKeyGen(iface string) (*wgtypes.Key, error) { func Configure(iface string, privateKey string) error { log.Debugf("configuring Wireguard interface %s", iface) - wg, err := wgctrl.New() - if err != nil { - return err - } - defer wg.Close() log.Debugf("adding Wireguard private key") key, err := wgtypes.ParseKey(privateKey) @@ -84,18 +87,14 @@ func Configure(iface string, privateKey string) error { } fwmark := 0 p := WgPort - cfg := wgtypes.Config{ + config := wgtypes.Config{ PrivateKey: &key, ReplacePeers: false, FirewallMark: &fwmark, ListenPort: &p, } - err = wg.ConfigureDevice(iface, cfg) - if err != nil { - return err - } - return nil + return configureDevice(iface, config) } // GetListenPort returns the listening port of the Wireguard endpoint @@ -118,55 +117,12 @@ func GetListenPort(iface string) (*int, error) { return &d.ListenPort, nil } -// UpdateListenPort updates a Wireguard interface listen port -func UpdateListenPort(iface string, newPort int) error { - log.Debugf("updating Wireguard listen port of interface %s, new port %d", iface, newPort) - - //discover Wireguard current configuration - wg, err := wgctrl.New() - if err != nil { - return err - } - defer wg.Close() - - _, err = wg.Device(iface) - if err != nil { - return err - } - log.Debugf("got Wireguard device %s", iface) - - config := wgtypes.Config{ - ListenPort: &newPort, - ReplacePeers: false, - } - err = wg.ConfigureDevice(iface, config) - if err != nil { - return err - } - - log.Debugf("updated Wireguard listen port of interface %s, new port %d", iface, newPort) - - return nil -} - // UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist // Endpoint is optional func UpdatePeer(iface string, peerKey string, allowedIps string, keepAlive time.Duration, endpoint string) error { log.Debugf("updating interface %s peer %s: endpoint %s ", iface, peerKey, endpoint) - wg, err := wgctrl.New() - if err != nil { - return err - } - defer wg.Close() - - _, err = wg.Device(iface) - if err != nil { - return err - } - log.Debugf("got Wireguard device %s", iface) - //parse allowed ips _, ipNet, err := net.ParseCIDR(allowedIps) if err != nil { @@ -177,20 +133,18 @@ func UpdatePeer(iface string, peerKey string, allowedIps string, keepAlive time. if err != nil { return err } - peers := make([]wgtypes.PeerConfig, 0) peer := wgtypes.PeerConfig{ PublicKey: peerKeyParsed, ReplaceAllowedIPs: true, AllowedIPs: []net.IPNet{*ipNet}, PersistentKeepaliveInterval: &keepAlive, } - peers = append(peers, peer) config := wgtypes.Config{ - ReplacePeers: false, - Peers: peers, + Peers: []wgtypes.PeerConfig{peer}, } - err = wg.ConfigureDevice(iface, config) + + err = configureDevice(iface, config) if err != nil { return err } @@ -208,18 +162,6 @@ func UpdatePeerEndpoint(iface string, peerKey string, newEndpoint string) error log.Debugf("updating peer %s endpoint %s ", peerKey, newEndpoint) - wg, err := wgctrl.New() - if err != nil { - return err - } - defer wg.Close() - - _, err = wg.Device(iface) - if err != nil { - return err - } - log.Debugf("got Wireguard device %s", iface) - peerAddr, err := net.ResolveUDPAddr("udp4", newEndpoint) if err != nil { return err @@ -231,23 +173,41 @@ func UpdatePeerEndpoint(iface string, peerKey string, newEndpoint string) error if err != nil { return err } - peers := make([]wgtypes.PeerConfig, 0) + peer := wgtypes.PeerConfig{ PublicKey: peerKeyParsed, ReplaceAllowedIPs: false, UpdateOnly: true, Endpoint: peerAddr, } - peers = append(peers, peer) - config := wgtypes.Config{ - ReplacePeers: false, - Peers: peers, + Peers: []wgtypes.PeerConfig{peer}, } - err = wg.ConfigureDevice(iface, config) + return configureDevice(iface, config) +} + +// RemovePeer removes a Wireguard Peer from the interface iface +func RemovePeer(iface string, peerKey string) error { + log.Debugf("Removing peer %s from interface %s ", peerKey, iface) + + peerKeyParsed, err := wgtypes.ParseKey(peerKey) if err != nil { return err } - return nil + peer := wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + Remove: true, + } + + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peer}, + } + + return configureDevice(iface, config) +} + +// Closes the User Space tunnel interface +func CloseWithUserspace() error { + return tunIface.Close() } diff --git a/iface/iface_darwin.go b/iface/iface_darwin.go index c3c86050b..85a8825e2 100644 --- a/iface/iface_darwin.go +++ b/iface/iface_darwin.go @@ -37,3 +37,8 @@ func addRoute(iface string, ipNet *net.IPNet) error { } return nil } + +// Closes the tunnel interface +func Close() error { + return CloseWithUserspace() +} diff --git a/iface/iface_linux.go b/iface/iface_linux.go index 86c670785..05116380c 100644 --- a/iface/iface_linux.go +++ b/iface/iface_linux.go @@ -1,8 +1,10 @@ package iface import ( + "fmt" log "github.com/sirupsen/logrus" "github.com/vishvananda/netlink" + "golang.zx2c4.com/wireguard/wgctrl" "os" ) @@ -36,14 +38,6 @@ func CreateWithKernel(iface string, address string) error { return err } - log.Debugf("adding address %s to interface: %s", address, iface) - addr, _ := netlink.ParseAddr(address) - err = netlink.AddrAdd(&link, addr) - if os.IsExist(err) { - log.Infof("interface %s already has the address: %s", iface, address) - } else if err != nil { - return err - } err = assignAddr(address, iface) if err != nil { return err @@ -103,3 +97,38 @@ func (w *wgLink) Attrs() *netlink.LinkAttrs { func (w *wgLink) Type() string { return "wireguard" } + +// Closes the tunnel interface +func Close() error { + + if tunIface != nil { + return CloseWithUserspace() + } else { + var iface = "" + wg, err := wgctrl.New() + if err != nil { + return err + } + defer wg.Close() + devList, err := wg.Devices() + if err != nil { + return err + } + for _, wgDev := range devList { + if wgDev.ListenPort == WgPort { + iface = wgDev.Name + break + } + } + if iface == "" { + return fmt.Errorf("Wireguard Interface not found") + } + attrs := netlink.NewLinkAttrs() + attrs.Name = iface + + link := wgLink{ + attrs: &attrs, + } + return netlink.LinkDel(&link) + } +} diff --git a/iface/iface_test.go b/iface/iface_test.go new file mode 100644 index 000000000..b21f41e57 --- /dev/null +++ b/iface/iface_test.go @@ -0,0 +1,148 @@ +package iface + +import ( + "fmt" + log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/wgctrl" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "net" + "testing" + "time" +) + +// keep darwin compability +const ( + ifaceName = "utun999" + key = "0PMI6OkB5JmB+Jj/iWWHekuQRx+bipZirWCWKFXexHc=" + peerPubKey = "Ok0mC0qlJyXEPKh2UFIpsI2jG0L7LRpC3sLAusSJ5CQ=" +) + +func Test_CreateInterface(t *testing.T) { + level, _ := log.ParseLevel("Debug") + log.SetLevel(level) + wgIP := "10.99.99.1/24" + err := Create(ifaceName, wgIP) + if err != nil { + t.Fatal(err) + } + wg, err := wgctrl.New() + if err != nil { + t.Fatal(err) + } + defer wg.Close() + + _, err = wg.Device(ifaceName) + if err != nil { + t.Fatal(err) + } +} + +func Test_ConfigureInterface(t *testing.T) { + err := Configure(ifaceName, key) + if err != nil { + t.Fatal(err) + } + + wg, err := wgctrl.New() + if err != nil { + t.Fatal(err) + } + defer wg.Close() + + wgDevice, err := wg.Device(ifaceName) + if err != nil { + t.Fatal(err) + } + if wgDevice.PrivateKey.String() != key { + t.Fatalf("Private keys don't match after configure: %s != %s", key, wgDevice.PrivateKey.String()) + } +} + +func Test_UpdatePeer(t *testing.T) { + keepAlive := 15 * time.Second + allowedIP := "10.99.99.2/32" + endpoint := "127.0.0.1:9900" + err := UpdatePeer(ifaceName, peerPubKey, allowedIP, keepAlive, endpoint) + if err != nil { + t.Fatal(err) + } + peer, err := getPeer() + if err != nil { + t.Fatal(err) + } + if peer.PersistentKeepaliveInterval != keepAlive { + t.Fatal("configured peer with mismatched keepalive interval value") + } + + resolvedEndpoint, err := net.ResolveUDPAddr("udp", endpoint) + if err != nil { + t.Fatal(err) + } + if peer.Endpoint.String() != resolvedEndpoint.String() { + t.Fatal("configured peer with mismatched endpoint") + } + + var foundAllowedIP bool + for _, aip := range peer.AllowedIPs { + if aip.String() == allowedIP { + foundAllowedIP = true + break + } + } + if !foundAllowedIP { + t.Fatal("configured peer with mismatched Allowed IPs") + } +} + +func Test_UpdatePeerEndpoint(t *testing.T) { + newEndpoint := "127.0.0.1:9999" + err := UpdatePeerEndpoint(ifaceName, peerPubKey, newEndpoint) + if err != nil { + t.Fatal(err) + } + + peer, err := getPeer() + if err != nil { + t.Fatal(err) + } + + if peer.Endpoint.String() != newEndpoint { + t.Fatal("configured peer with mismatched endpoint") + } +} + +func Test_RemovePeer(t *testing.T) { + err := RemovePeer(ifaceName, peerPubKey) + if err != nil { + t.Fatal(err) + } + _, err = getPeer() + if err.Error() != "peer not found" { + t.Fatal(err) + } +} +func Test_Close(t *testing.T) { + err := Close() + if err != nil { + t.Fatal(err) + } +} +func getPeer() (wgtypes.Peer, error) { + emptyPeer := wgtypes.Peer{} + wg, err := wgctrl.New() + if err != nil { + return emptyPeer, err + } + defer wg.Close() + + wgDevice, err := wg.Device(ifaceName) + if err != nil { + return emptyPeer, err + } + for _, peer := range wgDevice.Peers { + if peer.PublicKey.String() == peerPubKey { + return peer, nil + } + } + return emptyPeer, fmt.Errorf("peer not found") +} diff --git a/iface/iface_windows.go b/iface/iface_windows.go index 46966b23d..ddf279af8 100644 --- a/iface/iface_windows.go +++ b/iface/iface_windows.go @@ -39,3 +39,8 @@ func assignAddr(address string, ifaceName string) error { func getUAPI(iface string) (net.Listener, error) { return ipc.UAPIListen(iface) } + +// Closes the tunnel interface +func Close() error { + return CloseWithUserspace() +}