mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-29 11:33:48 +01:00
feat: first project flow
This commit is contained in:
commit
6b3ba0feaf
69
cmd/wiretrustee/cmd_start.go
Normal file
69
cmd/wiretrustee/cmd_start.go
Normal file
@ -0,0 +1,69 @@
|
||||
package wiretrustee
|
||||
|
||||
import (
|
||||
"context"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/wiretrustee/wiretrustee/signal"
|
||||
"github.com/wiretrustee/wiretrustee/signal/proto"
|
||||
"os"
|
||||
)
|
||||
|
||||
const (
|
||||
ExitSetupFailed = 1
|
||||
)
|
||||
|
||||
func init() {
|
||||
runCmd := &cobra.Command{
|
||||
Use: "start",
|
||||
Short: "start wiretrustee",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
config, err := ReadConfig("config.yml")
|
||||
if err != nil {
|
||||
log.Fatal("failed to load config")
|
||||
os.Exit(ExitSetupFailed)
|
||||
}
|
||||
|
||||
//todo print config
|
||||
|
||||
//todo connect to signal
|
||||
ctx := context.Background()
|
||||
signalClient, err := signal.NewClient(config.SignalAddr, ctx)
|
||||
if err != nil {
|
||||
log.Errorf("error while connecting to the Signal Exchange Service %s: %s", config.SignalAddr, err)
|
||||
os.Exit(ExitSetupFailed)
|
||||
}
|
||||
//todo proper close handling
|
||||
defer func() { signalClient.Close() }()
|
||||
|
||||
signalClient.WaitConnected()
|
||||
|
||||
select {}
|
||||
},
|
||||
}
|
||||
rootCmd.AddCommand(runCmd)
|
||||
}
|
||||
|
||||
func ReadConfig(path string) (*Config, error) {
|
||||
/*f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
bs, err := ioutil.ReadAll(f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
|
||||
err = yaml.Unmarshal(bs, &cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &cfg, nil*/
|
||||
|
||||
return &Config{}, nil
|
||||
}
|
12
cmd/wiretrustee/config.go
Normal file
12
cmd/wiretrustee/config.go
Normal file
@ -0,0 +1,12 @@
|
||||
package wiretrustee
|
||||
|
||||
import "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
type Config struct {
|
||||
// Wireguard private key of local peer
|
||||
PrivateKey wgtypes.Key
|
||||
// configured remote peers (Wireguard public keys)
|
||||
Peers string
|
||||
// host:port of the signal server
|
||||
SignalAddr string
|
||||
}
|
248
engine/agent.go
Normal file
248
engine/agent.go
Normal file
@ -0,0 +1,248 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/pion/ice/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/wiretrustee/wiretrustee/signal"
|
||||
sProto "github.com/wiretrustee/wiretrustee/signal/proto"
|
||||
"net"
|
||||
)
|
||||
|
||||
// PeerAgent is responsible for establishing and maintaining of the connection between two peers (local and remote)
|
||||
// It uses underlying ice.Agent and ice.Conn
|
||||
type PeerAgent struct {
|
||||
// a Wireguard public key of the peer
|
||||
LocalKey string
|
||||
// a Wireguard public key of the remote peer
|
||||
RemoteKey string
|
||||
// ICE iceAgent that actually negotiates and maintains peer-to-peer connection
|
||||
iceAgent *ice.Agent
|
||||
// Actual peer-to-peer connection
|
||||
conn *ice.Conn
|
||||
// a signal.Client to negotiate initial connection
|
||||
signal signal.Client
|
||||
// a connection to a local Wireguard instance to proxy data
|
||||
wgConn net.Conn
|
||||
// an address of local Wireguard instance
|
||||
wgAddr string
|
||||
}
|
||||
|
||||
// NewPeerAgent creates a new PeerAgent with give local and remote Wireguard public keys and initializes an ICE Agent
|
||||
func NewPeerAgent(localKey string, remoteKey string, stunTurnURLS []*ice.URL, wgAddr string) (*PeerAgent, error) {
|
||||
|
||||
// init ICE Agent
|
||||
iceAgent, err := ice.NewAgent(&ice.AgentConfig{
|
||||
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4},
|
||||
Urls: stunTurnURLS,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peerAgent := &PeerAgent{
|
||||
LocalKey: localKey,
|
||||
RemoteKey: remoteKey,
|
||||
iceAgent: iceAgent,
|
||||
wgAddr: wgAddr,
|
||||
conn: nil,
|
||||
wgConn: nil,
|
||||
}
|
||||
|
||||
err = peerAgent.onConnectionStateChange()
|
||||
if err != nil {
|
||||
//todo close agent
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = peerAgent.onCandidate()
|
||||
if err != nil {
|
||||
log.Errorf("failed listening on ICE connection state changes %s", err)
|
||||
//todo close agent
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return peerAgent, nil
|
||||
|
||||
}
|
||||
|
||||
// proxyToRemotePeer proxies everything from Wireguard to the remote peer
|
||||
// blocks
|
||||
func (pa *PeerAgent) proxyToRemotePeer() {
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
n, err := pa.wgConn.Read(buf)
|
||||
if err != nil {
|
||||
log.Warnln("Error reading from peer: ", err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
n, err = pa.conn.Write(buf[:n])
|
||||
if err != nil {
|
||||
log.Warnln("Error writing to remote peer: ", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// proxyToLocalWireguard proxies everything from the remote peer to local Wireguard
|
||||
// blocks
|
||||
func (pa *PeerAgent) proxyToLocalWireguard() {
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
n, err := pa.conn.Read(buf)
|
||||
if err != nil {
|
||||
log.Errorf("failed reading from remote connection %s", err)
|
||||
}
|
||||
|
||||
n, err = pa.wgConn.Write(buf[:n])
|
||||
if err != nil {
|
||||
log.Errorf("failed writing to local Wireguard instance %s", err)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// OpenConnection opens connection to remote peer. Flow:
|
||||
// 1. start gathering connection candidates
|
||||
// 2. if the peer was an initiator then it dials to the remote peer
|
||||
// 3. if the peer wasn't an initiator then it waits for incoming connection from the remote peer
|
||||
// 4. after connection has been established peer starts to:
|
||||
// - proxy all local Wireguard's packets to the remote peer
|
||||
// - proxy all incoming data from the remote peer to local Wireguard
|
||||
// The returned connection address can be used to be set as Wireguard's remote peer endpoint
|
||||
func (pa *PeerAgent) OpenConnection(initiator bool) (net.Conn, error) {
|
||||
// start gathering candidates
|
||||
err := pa.iceAgent.GatherCandidates()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// by that time it should be already set
|
||||
frag, pwd, err := pa.iceAgent.GetRemoteUserCredentials()
|
||||
if err != nil {
|
||||
log.Errorf("remote credentials are not set for remote peer %s", pa.RemoteKey)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// initiate remote connection
|
||||
// will block until connection was established
|
||||
var conn *ice.Conn = nil
|
||||
if initiator {
|
||||
conn, err = pa.iceAgent.Dial(context.TODO(), frag, pwd)
|
||||
} else {
|
||||
conn, err = pa.iceAgent.Accept(context.TODO(), frag, pwd)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Fatalf("failed listening on local port %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Infof("Local addr %s, remote addr %s", conn.LocalAddr(), conn.RemoteAddr())
|
||||
pa.conn = conn
|
||||
|
||||
// connect to local Wireguard instance
|
||||
wgConn, err := net.Dial("udp", pa.wgAddr)
|
||||
if err != nil {
|
||||
log.Fatalf("failed dialing to local Wireguard port %s", err)
|
||||
return nil, err
|
||||
}
|
||||
pa.wgConn = wgConn
|
||||
|
||||
go func() {
|
||||
pa.proxyToRemotePeer()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
pa.proxyToLocalWireguard()
|
||||
}()
|
||||
|
||||
return wgConn, nil
|
||||
}
|
||||
|
||||
func (pa *PeerAgent) OnAnswer(msg *sProto.Message) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pa *PeerAgent) OnRemoteCandidate(msg *sProto.Message) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// signalCandidate sends a message with a local ice.Candidate details to the remote peer via signal server
|
||||
func (pa *PeerAgent) signalCandidate(c ice.Candidate) error {
|
||||
err := pa.signal.Send(&sProto.Message{
|
||||
Type: sProto.Message_CANDIDATE,
|
||||
Key: pa.LocalKey,
|
||||
RemoteKey: pa.RemoteKey,
|
||||
Body: c.Marshal(),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// onCandidate detects new local ice.Candidate and sends it to the remote peer via signal server
|
||||
func (pa *PeerAgent) onCandidate() error {
|
||||
return pa.iceAgent.OnCandidate(func(candidate ice.Candidate) {
|
||||
if candidate != nil {
|
||||
err := pa.signalCandidate(candidate)
|
||||
if err != nil {
|
||||
log.Errorf("failed signaling candidate to the remote peer %s %s", pa.RemoteKey, err)
|
||||
//todo ??
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// onConnectionStateChange listens on ice.Agent connection state change events and once connected checks a Candidate pair
|
||||
// the ice.Conn was established with
|
||||
// Mostly used for debugging purposes (e.g. connection time, etc)
|
||||
func (pa *PeerAgent) onConnectionStateChange() error {
|
||||
return pa.iceAgent.OnConnectionStateChange(func(state ice.ConnectionState) {
|
||||
log.Debugf("ICE Connection State has changed: %s", state.String())
|
||||
if state == ice.ConnectionStateConnected {
|
||||
// once the connection has been established we can check the selected candidate pair
|
||||
pair, err := pa.iceAgent.GetSelectedCandidatePair()
|
||||
if err != nil {
|
||||
log.Errorf("failed selecting active ICE candidate pair %s", err)
|
||||
return
|
||||
}
|
||||
log.Debugf("connected to peer %s via selected candidate pair %s", pa.RemoteKey, pair)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// authenticate sets the signal.Credential of the remote peer
|
||||
// and sends local signal.Credential to teh remote peer via signal server
|
||||
func (pa *PeerAgent) Authenticate(credential *signal.Credential) error {
|
||||
|
||||
err := pa.iceAgent.SetRemoteCredentials(credential.UFrag, credential.Pwd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
localUFrag, localPwd, err := pa.iceAgent.GetLocalUserCredentials()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// notify the remote peer about our credentials
|
||||
answer := signal.MarshalCredential(pa.LocalKey, pa.RemoteKey, &signal.Credential{
|
||||
UFrag: localUFrag,
|
||||
Pwd: localPwd,
|
||||
}, sProto.Message_ANSWER)
|
||||
|
||||
//notify the remote peer of our credentials
|
||||
err = pa.signal.Send(answer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
126
engine/engine.go
Normal file
126
engine/engine.go
Normal file
@ -0,0 +1,126 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/pion/ice/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/wiretrustee/wiretrustee/iface"
|
||||
signal "github.com/wiretrustee/wiretrustee/signal"
|
||||
sProto "github.com/wiretrustee/wiretrustee/signal/proto"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Engine struct {
|
||||
// a list of STUN and TURN servers
|
||||
stunsTurns []*ice.URL
|
||||
// signal server client
|
||||
signal *signal.Client
|
||||
// peer agents indexed by local public key of the remote peers
|
||||
agents map[string]*PeerAgent
|
||||
// Wireguard interface
|
||||
wgIface string
|
||||
// Wireguard local address
|
||||
wgAddr string
|
||||
}
|
||||
|
||||
func NewEngine(signal *signal.Client, stunsTurns []*ice.URL) *Engine {
|
||||
return &Engine{
|
||||
stunsTurns: stunsTurns,
|
||||
signal: signal,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) Start(localKey string, peers []string) error {
|
||||
|
||||
// setup wireguard
|
||||
myKey, err := wgtypes.ParseKey(localKey)
|
||||
if err != nil {
|
||||
log.Errorf("error parsing Wireguard key %s: [%s]", localKey, err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
err = iface.Create(e.wgIface, e.wgIface)
|
||||
if err != nil {
|
||||
log.Errorf("error while creating interface %s: [%s]", e.wgIface, err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
err = iface.Configure(e.wgIface, myKey.String())
|
||||
if err != nil {
|
||||
log.Errorf("error while configuring Wireguard interface [%s]: %s", e.wgIface, err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
wgPort, err := iface.GetListenPort(e.wgIface)
|
||||
if err != nil {
|
||||
log.Errorf("error while getting Wireguard interface port [%s]: %s", e.wgIface, err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
// initialize peer agents
|
||||
for _, peer := range peers {
|
||||
peerAgent, err := NewPeerAgent(localKey, peer, e.stunsTurns, fmt.Sprintf("127.0.0.1:%d", *wgPort))
|
||||
if err != nil {
|
||||
log.Fatalf("failed creating peer agent for pair %s - %s", localKey, peer)
|
||||
return err
|
||||
}
|
||||
e.agents[localKey] = peerAgent
|
||||
}
|
||||
|
||||
e.receiveSignal(localKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) receiveSignal(localKey string) {
|
||||
// connect to a stream of messages coming from the signal server
|
||||
e.signal.Receive(localKey, func(msg *sProto.Message) error {
|
||||
|
||||
// check if this is our "buddy" peer
|
||||
peerAgent := e.agents[msg.Key]
|
||||
if peerAgent == nil {
|
||||
return fmt.Errorf("unknown peer %s", msg.Key)
|
||||
}
|
||||
|
||||
// the one who send offer (expects answer) is the initiator of teh connection
|
||||
initiator := msg.Type == sProto.Message_ANSWER
|
||||
|
||||
switch msg.Type {
|
||||
case sProto.Message_OFFER:
|
||||
case sProto.Message_ANSWER:
|
||||
remoteCred, err := signal.UnMarshalCredential(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = peerAgent.Authenticate(remoteCred)
|
||||
if err != nil {
|
||||
log.Errorf("error authenticating remote peer %s", msg.Key)
|
||||
return err
|
||||
}
|
||||
|
||||
conn, err := peerAgent.OpenConnection(initiator)
|
||||
if err != nil {
|
||||
log.Errorf("error opening connection ot remote peer %s", msg.Key)
|
||||
return err
|
||||
}
|
||||
|
||||
err = iface.UpdatePeer(e.wgIface, peerAgent.RemoteKey, "0.0.0.0/0", 15*time.Second, conn.LocalAddr().String())
|
||||
if err != nil {
|
||||
log.Errorf("error while configuring Wireguard peer [%s] %s", peerAgent.RemoteKey, err.Error())
|
||||
return err
|
||||
}
|
||||
case sProto.Message_CANDIDATE:
|
||||
err := peerAgent.OnRemoteCandidate(msg)
|
||||
if err != nil {
|
||||
log.Errorf("error handling CANDIATE from %s", msg.Key)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
e.signal.WaitConnected()
|
||||
}
|
21
go.mod
Normal file
21
go.mod
Normal file
@ -0,0 +1,21 @@
|
||||
module github.com/wiretrustee/wiretrustee
|
||||
|
||||
go 1.16
|
||||
|
||||
require (
|
||||
github.com/golang/protobuf v1.4.2
|
||||
github.com/google/nftables v0.0.0-20201230142148-715e31cb3c31
|
||||
github.com/pion/ice/v2 v2.0.17
|
||||
github.com/pion/logging v0.2.2
|
||||
github.com/pion/stun v0.3.5
|
||||
github.com/pion/turn/v2 v2.0.5
|
||||
github.com/sirupsen/logrus v1.7.0
|
||||
github.com/spf13/cobra v1.1.3
|
||||
github.com/vishvananda/netlink v1.1.0
|
||||
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df
|
||||
github.com/wiretrustee/wiretrustee-signal v0.0.14
|
||||
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2
|
||||
golang.zx2c4.com/wireguard v0.0.20201118
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20200609130330-bd2cb7843e1b
|
||||
google.golang.org/grpc v1.32.0
|
||||
)
|
82
iface/iface.go
Normal file
82
iface/iface.go
Normal file
@ -0,0 +1,82 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
//log "github.com/sirupsen/logrus"
|
||||
"errors"
|
||||
"fmt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/ipc"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"net"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultMTU = 1280
|
||||
interfaceLimit = 10 // can be higher. Need to check different OS limits
|
||||
)
|
||||
|
||||
// Saves tun device object - is it required?
|
||||
var tunIface tun.Device
|
||||
|
||||
// Create Creates a new Wireguard interface, sets a given IP and brings it up.
|
||||
// Will reuse an existing one.
|
||||
func Create(iface string, address string) error {
|
||||
var err error
|
||||
|
||||
tunIface, err = tun.CreateTUN(iface, defaultMTU)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// We need to create a wireguard-go device and listen to configuration requests
|
||||
tunDevice := device.NewDevice(tunIface, device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
|
||||
tunDevice.Up()
|
||||
tunSock, err := ipc.UAPIOpen(iface)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
uapi, err := ipc.UAPIListen(iface, tunSock)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := uapi.Accept()
|
||||
if err != nil {
|
||||
log.Debugln(err)
|
||||
return
|
||||
}
|
||||
go tunDevice.IpcHandle(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
log.Debugln("UAPI listener started")
|
||||
|
||||
err = assignAddr(iface, address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Deletes an existing Wireguard interface
|
||||
func Delete() error {
|
||||
return tunIface.Close()
|
||||
}
|
||||
|
||||
// GetIfaceName loops through the OS' interfaceLimit and returns the first available interface name based on
|
||||
// interface prefixes and index
|
||||
func GetIfaceName() (string, error) {
|
||||
for i := 0; i < interfaceLimit; i++ {
|
||||
_, err := net.InterfaceByName(interfacePrefix + strconv.Itoa(i))
|
||||
if err != nil {
|
||||
if err.Error() != "no such network interface" {
|
||||
return interfacePrefix + strconv.Itoa(i), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return "none", errors.New(fmt.Sprintf("Couldn't find an available interface index within the limit of: %d", interfaceLimit))
|
||||
}
|
38
iface/iface_darwin.go
Normal file
38
iface/iface_darwin.go
Normal file
@ -0,0 +1,38 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
interfacePrefix = "utun"
|
||||
)
|
||||
|
||||
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
|
||||
func assignAddr(iface string, address string) error {
|
||||
ip := strings.Split(address, "/")
|
||||
cmd := exec.Command("ifconfig", iface, "inet", address, ip[0])
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
log.Infoln("Command: %v failed with output %s and error: ", cmd.String(), out)
|
||||
return err
|
||||
}
|
||||
_, resolvedNet, err := net.ParseCIDR(address)
|
||||
err = addRoute(iface, resolvedNet)
|
||||
if err != nil {
|
||||
log.Infoln("Adding route failed with error:", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// addRoute Adds network route based on the range provided
|
||||
func addRoute(iface string, ipNet *net.IPNet) error {
|
||||
cmd := exec.Command("route", "add", "-net", ipNet.String(), "-interface", iface)
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
log.Printf("Command: %v failed with output %s and error: ", cmd.String(), out)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
33
iface/iface_linux.go
Normal file
33
iface/iface_linux.go
Normal file
@ -0,0 +1,33 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/vishvananda/netlink"
|
||||
"os"
|
||||
)
|
||||
|
||||
const (
|
||||
interfacePrefix = "wg"
|
||||
)
|
||||
|
||||
// assignAddr Adds IP address to the tunnel interface
|
||||
func assignAddr(iface string, address string) error {
|
||||
attrs := netlink.NewLinkAttrs()
|
||||
attrs.Name = iface
|
||||
|
||||
link := wgLink{
|
||||
attrs: &attrs,
|
||||
}
|
||||
|
||||
log.Debugf("adding address %s to interface: %s", address, iface)
|
||||
addr, _ := netlink.ParseAddr(address)
|
||||
err := netlink.AddrAdd(&link, addr)
|
||||
if os.IsExist(err) {
|
||||
log.Infof("interface %s already has the address: %s", iface, address)
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
// On linux, the link must be brought up
|
||||
err = netlink.LinkSetUp(&link)
|
||||
return err
|
||||
}
|
85
iface/nat_linux.go
Normal file
85
iface/nat_linux.go
Normal file
@ -0,0 +1,85 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/expr"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/vishvananda/netns"
|
||||
"io/ioutil"
|
||||
)
|
||||
|
||||
// Configure routing and IP masquerading
|
||||
//todo more docs on what exactly happens here and why it is needed
|
||||
func ConfigureNAT(primaryIface string) error {
|
||||
log.Debugf("adding NAT / IP masquerading using nftables")
|
||||
ns, err := netns.Get()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn := nftables.Conn{NetNS: int(ns)}
|
||||
|
||||
log.Debugf("flushing nftable rulesets")
|
||||
conn.FlushRuleset()
|
||||
|
||||
log.Debugf("setting up nftable rules for ip masquerading")
|
||||
|
||||
nat := conn.AddTable(&nftables.Table{
|
||||
Family: nftables.TableFamilyIPv4,
|
||||
Name: "nat",
|
||||
})
|
||||
|
||||
conn.AddChain(&nftables.Chain{
|
||||
Name: "prerouting",
|
||||
Table: nat,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
})
|
||||
|
||||
post := conn.AddChain(&nftables.Chain{
|
||||
Name: "postrouting",
|
||||
Table: nat,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
Hooknum: nftables.ChainHookPostrouting,
|
||||
Priority: nftables.ChainPriorityNATSource,
|
||||
})
|
||||
|
||||
conn.AddRule(&nftables.Rule{
|
||||
Table: nat,
|
||||
Chain: post,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(primaryIface),
|
||||
},
|
||||
&expr.Masq{},
|
||||
},
|
||||
})
|
||||
|
||||
if err := conn.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Enables IP forwarding system property.
|
||||
// Mostly used when you setup one peer as a VPN server.
|
||||
func EnableIPForward() error {
|
||||
f := "/proc/sys/net/ipv4/ip_forward"
|
||||
|
||||
content, err := ioutil.ReadFile(f)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if string(content) == "0\n" {
|
||||
log.Info("enabling IP Forward")
|
||||
return ioutil.WriteFile(f, []byte("1"), 0600)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
210
iface/wgctl.go
Normal file
210
iface/wgctl.go
Normal file
@ -0,0 +1,210 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Extends the functionality of Configure(iface string, privateKey string) by generating a new Wireguard private key
|
||||
func ConfigureWithKeyGen(iface string) (*wgtypes.Key, error) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &key, Configure(iface, key.String())
|
||||
}
|
||||
|
||||
// Configures a Wireguard interface
|
||||
// The interface must exist before calling this method (e.g. call interface.Create() before)
|
||||
func Configure(iface string, privateKey string) error {
|
||||
|
||||
log.Debugf("configuring Wireguard interface %s", iface)
|
||||
wg, err := wgctrl.New()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer wg.Close()
|
||||
|
||||
log.Debugf("adding Wireguard private key")
|
||||
key, err := wgtypes.ParseKey(privateKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fwmark := 0
|
||||
cfg := wgtypes.Config{
|
||||
PrivateKey: &key,
|
||||
ReplacePeers: false,
|
||||
FirewallMark: &fwmark,
|
||||
}
|
||||
err = wg.ConfigureDevice(iface, cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetListenPort(iface string) (*int, error) {
|
||||
log.Debugf("getting Wireguard listen port of interface %s", iface)
|
||||
|
||||
//discover Wireguard current configuration
|
||||
wg, err := wgctrl.New()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer wg.Close()
|
||||
|
||||
d, err := wg.Device(iface)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Debugf("got Wireguard device listen port %s, %d", iface, &d.ListenPort)
|
||||
|
||||
return &d.ListenPort, nil
|
||||
}
|
||||
|
||||
// Updates a Wireguard interface listen port
|
||||
func UpdateListenPort(iface string, newPort int) error {
|
||||
log.Debugf("updating Wireguard listen port of interface %s, new port %d", iface, newPort)
|
||||
|
||||
//discover Wireguard current configuration
|
||||
wg, err := wgctrl.New()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer wg.Close()
|
||||
|
||||
_, err = wg.Device(iface)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debugf("got Wireguard device %s", iface)
|
||||
|
||||
config := wgtypes.Config{
|
||||
ListenPort: &newPort,
|
||||
ReplacePeers: false,
|
||||
}
|
||||
err = wg.ConfigureDevice(iface, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debugf("updated Wireguard listen port of interface %s, new port %d", iface, newPort)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ifname(n string) []byte {
|
||||
b := make([]byte, 16)
|
||||
copy(b, []byte(n+"\x00"))
|
||||
return b
|
||||
}
|
||||
|
||||
// Updates existing Wireguard Peer or creates a new one if doesn't exist
|
||||
// Endpoint is optional
|
||||
func UpdatePeer(iface string, peerKey string, allowedIps string, keepAlive time.Duration, endpoint string) error {
|
||||
wg, err := wgctrl.New()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer wg.Close()
|
||||
|
||||
_, err = wg.Device(iface)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debugf("got Wireguard device %s", iface)
|
||||
|
||||
//parse allowed ips
|
||||
ipNet, err := netlink.ParseIPNet(allowedIps)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||
|
||||
peers := make([]wgtypes.PeerConfig, 0)
|
||||
peer := wgtypes.PeerConfig{
|
||||
PublicKey: peerKeyParsed,
|
||||
ReplaceAllowedIPs: true,
|
||||
AllowedIPs: []net.IPNet{*ipNet},
|
||||
PersistentKeepaliveInterval: &keepAlive,
|
||||
}
|
||||
peers = append(peers, peer)
|
||||
|
||||
config := wgtypes.Config{
|
||||
ReplacePeers: false,
|
||||
Peers: peers,
|
||||
}
|
||||
err = wg.ConfigureDevice(iface, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if endpoint != "" {
|
||||
return UpdatePeerEndpoint(iface, peerKey, endpoint)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Updates a Wireguard interface Peer with the new endpoint
|
||||
// Used when NAT hole punching was successful and an update of the remote peer endpoint is required
|
||||
func UpdatePeerEndpoint(iface string, peerKey string, newEndpoint string) error {
|
||||
|
||||
wg, err := wgctrl.New()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer wg.Close()
|
||||
|
||||
_, err = wg.Device(iface)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debugf("got Wireguard device %s", iface)
|
||||
|
||||
peerAddr, err := net.ResolveUDPAddr("udp4", newEndpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debugf("parsed peer endpoint [%s]", peerAddr.String())
|
||||
|
||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||
peers := make([]wgtypes.PeerConfig, 0)
|
||||
peer := wgtypes.PeerConfig{
|
||||
PublicKey: peerKeyParsed,
|
||||
ReplaceAllowedIPs: false,
|
||||
UpdateOnly: true,
|
||||
Endpoint: peerAddr,
|
||||
}
|
||||
peers = append(peers, peer)
|
||||
|
||||
config := wgtypes.Config{
|
||||
ReplacePeers: false,
|
||||
Peers: peers,
|
||||
}
|
||||
err = wg.ConfigureDevice(iface, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type wgLink struct {
|
||||
attrs *netlink.LinkAttrs
|
||||
}
|
||||
|
||||
func (w *wgLink) Attrs() *netlink.LinkAttrs {
|
||||
return w.attrs
|
||||
}
|
||||
|
||||
func (w *wgLink) Type() string {
|
||||
return "wireguard"
|
||||
}
|
317
nat/discovery.go
Normal file
317
nat/discovery.go
Normal file
@ -0,0 +1,317 @@
|
||||
package nat
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/pion/stun"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Most of the code of this file is taken from the https://github.com/pion/stun/tree/master/cmd/stun-nat-behaviour package
|
||||
// Copyright 2018 Pion LLC
|
||||
|
||||
const (
|
||||
messageHeaderSize = 20
|
||||
)
|
||||
|
||||
//taken from https://github.com/pion/stun/tree/master/cmd/stun-nat-behaviour
|
||||
var (
|
||||
errResponseMessage = errors.New("error reading from response message channel")
|
||||
errTimedOut = errors.New("timed out waiting for response")
|
||||
errNoOtherAddress = errors.New("no OTHER-ADDRESS in the STUN response message")
|
||||
)
|
||||
|
||||
type Discovery struct {
|
||||
stunAddr string
|
||||
// a STUN server connection timeout
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
func NewDiscovery(stunAddr string, timeout time.Duration) *Discovery {
|
||||
return &Discovery{
|
||||
stunAddr: stunAddr,
|
||||
timeout: timeout,
|
||||
}
|
||||
}
|
||||
|
||||
type Candidate struct {
|
||||
Ip net.IP
|
||||
Port int
|
||||
// a type of the candidate [host, srflx, prflx, relay] - see WebRTC spec
|
||||
Type string
|
||||
}
|
||||
|
||||
type Behaviour struct {
|
||||
// indicates whether NAT is hard - address dependent or address and port dependent
|
||||
IsStrict bool
|
||||
// a list of external addresses (IP:port) received from the STUN server while testing NAT
|
||||
// these can be used for the Wireguard connection in case IsStrict = false
|
||||
Candidates []*Candidate
|
||||
|
||||
LocalPort int
|
||||
}
|
||||
|
||||
//taken from https://github.com/pion/stun/tree/master/cmd/stun-nat-behaviour
|
||||
type stunServerConn struct {
|
||||
conn net.PacketConn
|
||||
LocalAddr net.Addr
|
||||
RemoteAddr *net.UDPAddr
|
||||
OtherAddr *net.UDPAddr
|
||||
messageChan chan *stun.Message
|
||||
}
|
||||
|
||||
func (c *stunServerConn) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
// Discovers connection candidates and NAT behaviour by probing STUN server.
|
||||
// For proper NAT behaviour it is required for the The STUN server to have multiple IPs (for probing different destinations).
|
||||
// See https://github.com/pion/stun/tree/master/cmd/stun-nat-behaviour and https://tools.ietf.org/html/rfc5780 for details.
|
||||
// In case the returned Behaviour.IsStrict = false the Behaviour.LocalPort and any of the Probes can be used for the Wireguard communication
|
||||
// since the hole has been already punched.
|
||||
// When Behaviour.IsStrict = true the hole punching requires extra actions.
|
||||
func (d *Discovery) Discover() (*Behaviour, error) {
|
||||
|
||||
// get a local address (candidate)
|
||||
localConn, err := net.Dial("udp", "8.8.8.8:53")
|
||||
if err != nil {
|
||||
log.Errorf("Error getting local address: %s\n", err.Error())
|
||||
return nil, err
|
||||
}
|
||||
log.Infof("Local address %s", localConn.LocalAddr().String())
|
||||
err = localConn.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
lAddr, err := net.ResolveUDPAddr("udp4", localConn.LocalAddr().String())
|
||||
|
||||
mapTestConn, err := connect(d.stunAddr, lAddr)
|
||||
if err != nil {
|
||||
log.Errorf("Error creating STUN connection: %s\n", err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer mapTestConn.Close()
|
||||
|
||||
var candidates = []*Candidate{{Ip: lAddr.IP, Port: lAddr.Port, Type: "host"}}
|
||||
|
||||
// Test I: Regular binding request
|
||||
log.Info("Mapping Test I: Regular binding request")
|
||||
request := stun.MustBuild(stun.TransactionID, stun.BindingRequest)
|
||||
|
||||
resp, err := mapTestConn.roundTrip(request, mapTestConn.RemoteAddr, d.timeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse response message for XOR-MAPPED-ADDRESS and make sure OTHER-ADDRESS valid
|
||||
resps1 := parse(resp)
|
||||
if resps1.xorAddr == nil || resps1.otherAddr == nil {
|
||||
log.Warn("Error: NAT discovery feature not supported by this STUN server")
|
||||
return nil, errNoOtherAddress
|
||||
}
|
||||
addr, err := net.ResolveUDPAddr("udp4", resps1.otherAddr.String())
|
||||
if err != nil {
|
||||
log.Errorf("Failed resolving OTHER-ADDRESS: %v\n", resps1.otherAddr)
|
||||
return nil, err
|
||||
}
|
||||
mapTestConn.OtherAddr = addr
|
||||
log.Infof("Received XOR-MAPPED-ADDRESS: %v\n", resps1.xorAddr)
|
||||
|
||||
candidates = append(candidates, &Candidate{resps1.xorAddr.IP, resps1.xorAddr.Port, "srflx"})
|
||||
|
||||
// Assert mapping behavior
|
||||
if resps1.xorAddr.String() == mapTestConn.LocalAddr.String() {
|
||||
log.Info("=> NAT mapping behavior: endpoint independent (no NAT)")
|
||||
return &Behaviour{
|
||||
IsStrict: false,
|
||||
Candidates: candidates,
|
||||
LocalPort: mapTestConn.LocalAddr.(*net.UDPAddr).Port,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Test II: Send binding request to the other address but primary port
|
||||
log.Info("Mapping Test II: Send binding request to the other address but primary port")
|
||||
oaddr := *mapTestConn.OtherAddr
|
||||
oaddr.Port = mapTestConn.RemoteAddr.Port
|
||||
resp, err = mapTestConn.roundTrip(request, &oaddr, d.timeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resps2 := parse(resp)
|
||||
candidates = append(candidates, &Candidate{resps2.xorAddr.IP, resps2.xorAddr.Port, "srflx"})
|
||||
log.Infof("Received XOR-MAPPED-ADDRESS: %v\n", resps2.xorAddr)
|
||||
|
||||
// Assert mapping behavior
|
||||
if resps2.xorAddr.String() == resps1.xorAddr.String() {
|
||||
log.Info("=> NAT mapping behavior: endpoint independent")
|
||||
return &Behaviour{
|
||||
IsStrict: false,
|
||||
Candidates: candidates,
|
||||
LocalPort: mapTestConn.LocalAddr.(*net.UDPAddr).Port,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Test III: Send binding request to the other address and port
|
||||
log.Info("Mapping Test III: Send binding request to the other address and port")
|
||||
resp, err = mapTestConn.roundTrip(request, mapTestConn.OtherAddr, d.timeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resps3 := parse(resp)
|
||||
candidates = append(candidates, &Candidate{resps3.xorAddr.IP, resps3.xorAddr.Port, "srflx"})
|
||||
log.Infof("Received XOR-MAPPED-ADDRESS: %v\n", resps3.xorAddr)
|
||||
|
||||
// Assert mapping behavior
|
||||
if resps3.xorAddr.String() == resps2.xorAddr.String() {
|
||||
log.Info("=> NAT mapping behavior: address dependent")
|
||||
} else {
|
||||
log.Info("=> NAT mapping behavior: address and port dependent")
|
||||
}
|
||||
|
||||
return &Behaviour{
|
||||
IsStrict: true,
|
||||
Candidates: candidates,
|
||||
LocalPort: mapTestConn.LocalAddr.(*net.UDPAddr).Port,
|
||||
}, nil
|
||||
}
|
||||
|
||||
//taken from https://github.com/pion/stun/tree/master/cmd/stun-nat-behaviour
|
||||
func connect(stunAddr string, lAddr *net.UDPAddr) (*stunServerConn, error) {
|
||||
log.Debugf("connecting to STUN server: %s\n", stunAddr)
|
||||
addr, err := net.ResolveUDPAddr("udp4", stunAddr)
|
||||
if err != nil {
|
||||
log.Errorf("Error resolving address: %s\n", err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c, err := net.ListenUDP("udp4", lAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Debugf("Local address: %s\n", c.LocalAddr())
|
||||
log.Debugf("Remote address: %s\n", addr.String())
|
||||
|
||||
mChan := listen(c)
|
||||
|
||||
return &stunServerConn{
|
||||
conn: c,
|
||||
LocalAddr: c.LocalAddr(),
|
||||
RemoteAddr: addr,
|
||||
messageChan: mChan,
|
||||
}, nil
|
||||
}
|
||||
|
||||
//taken from https://github.com/pion/stun/tree/master/cmd/stun-nat-behaviour
|
||||
func listen(conn *net.UDPConn) (messages chan *stun.Message) {
|
||||
messages = make(chan *stun.Message)
|
||||
go func() {
|
||||
for {
|
||||
buf := make([]byte, 1024)
|
||||
|
||||
n, addr, err := conn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
close(messages)
|
||||
return
|
||||
}
|
||||
log.Debugf("Response from %v: (%v bytes)\n", addr, n)
|
||||
buf = buf[:n]
|
||||
|
||||
m := new(stun.Message)
|
||||
m.Raw = buf
|
||||
err = m.Decode()
|
||||
if err != nil {
|
||||
log.Debugf("Error decoding message: %v\n", err)
|
||||
close(messages)
|
||||
return
|
||||
}
|
||||
|
||||
messages <- m
|
||||
}
|
||||
}()
|
||||
return
|
||||
}
|
||||
|
||||
// Send request and wait for response or timeout
|
||||
//taken from https://github.com/pion/stun/tree/master/cmd/stun-nat-behaviour
|
||||
func (c *stunServerConn) roundTrip(msg *stun.Message, addr net.Addr, timeout time.Duration) (*stun.Message, error) {
|
||||
_ = msg.NewTransactionID()
|
||||
log.Debugf("Sending to %v: (%v bytes)\n", addr, msg.Length+messageHeaderSize)
|
||||
log.Debugf("%v\n", msg)
|
||||
for _, attr := range msg.Attributes {
|
||||
log.Debugf("\t%v (l=%v)\n", attr, attr.Length)
|
||||
}
|
||||
_, err := c.conn.WriteTo(msg.Raw, addr)
|
||||
if err != nil {
|
||||
log.Errorf("Error sending request to %v\n", addr)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Wait for response or timeout
|
||||
select {
|
||||
case m, ok := <-c.messageChan:
|
||||
if !ok {
|
||||
return nil, errResponseMessage
|
||||
}
|
||||
return m, nil
|
||||
//todo configure timeout
|
||||
case <-time.After(timeout):
|
||||
log.Warnf("Timed out waiting for response from server %v\n", addr)
|
||||
return nil, errTimedOut
|
||||
}
|
||||
}
|
||||
|
||||
// Parse a STUN message
|
||||
//taken from https://github.com/pion/stun/tree/master/cmd/stun-nat-behaviour
|
||||
func parse(msg *stun.Message) (ret struct {
|
||||
xorAddr *stun.XORMappedAddress
|
||||
otherAddr *stun.OtherAddress
|
||||
//respOrigin *stun.ResponseOrigin
|
||||
mappedAddr *stun.MappedAddress
|
||||
software *stun.Software
|
||||
}) {
|
||||
ret.mappedAddr = &stun.MappedAddress{}
|
||||
ret.xorAddr = &stun.XORMappedAddress{}
|
||||
//ret.respOrigin = &stun.ResponseOrigin{}
|
||||
ret.otherAddr = &stun.OtherAddress{}
|
||||
ret.software = &stun.Software{}
|
||||
if ret.xorAddr.GetFrom(msg) != nil {
|
||||
ret.xorAddr = nil
|
||||
}
|
||||
if ret.otherAddr.GetFrom(msg) != nil {
|
||||
ret.otherAddr = nil
|
||||
}
|
||||
/*if ret.respOrigin.GetFrom(msg) != nil {
|
||||
ret.respOrigin = nil
|
||||
}*/
|
||||
if ret.mappedAddr.GetFrom(msg) != nil {
|
||||
ret.mappedAddr = nil
|
||||
}
|
||||
if ret.software.GetFrom(msg) != nil {
|
||||
ret.software = nil
|
||||
}
|
||||
log.Debugf("%v\n", msg)
|
||||
log.Debugf("\tMAPPED-ADDRESS: %v\n", ret.mappedAddr)
|
||||
log.Debugf("\tXOR-MAPPED-ADDRESS: %v\n", ret.xorAddr)
|
||||
//log.Debugf("\tRESPONSE-ORIGIN: %v\n", ret.respOrigin)
|
||||
log.Debugf("\tOTHER-ADDRESS: %v\n", ret.otherAddr)
|
||||
log.Debugf("\tSOFTWARE: %v\n", ret.software)
|
||||
for _, attr := range msg.Attributes {
|
||||
switch attr.Type {
|
||||
case
|
||||
stun.AttrXORMappedAddress,
|
||||
stun.AttrOtherAddress,
|
||||
//stun.AttrResponseOrigin,
|
||||
stun.AttrMappedAddress,
|
||||
stun.AttrSoftware:
|
||||
break //nolint: staticcheck
|
||||
default:
|
||||
log.Debugf("\t%v (l=%v)\n", attr, attr.Length)
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
83
nat/nat.go
Normal file
83
nat/nat.go
Normal file
@ -0,0 +1,83 @@
|
||||
package nat
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/pion/webrtc/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/wiretrustee/wiretrustee-signal/proto"
|
||||
"time"
|
||||
)
|
||||
|
||||
// A set of tools to punch a UDP hole in NAT
|
||||
|
||||
// Uses WebRTC to probe the Network and gather connection Candidates.
|
||||
// It is important to request this method with multiple STUN server URLs because NAT type can be detected out of the multiple Probes (candidates)
|
||||
func PunchHole(stuns []string) ([]*proto.Candidate, error) {
|
||||
log.Debugf("starting to punch a NAT hole...")
|
||||
|
||||
pConn, err := newPeerConnection(stuns)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer pConn.Close()
|
||||
|
||||
var candidates []*proto.Candidate
|
||||
pConn.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) {
|
||||
log.Debugf("ICE Connection State has changed: %s\n", connectionState.String())
|
||||
})
|
||||
pConn.OnICECandidate(func(candidate *webrtc.ICECandidate) {
|
||||
log.Debugf("got new ICE candidate: %s", candidate)
|
||||
if candidate != nil {
|
||||
candidates = append(candidates, &proto.Candidate{
|
||||
//Address: fmt.Sprintf("%s:%d", candidate.Address, candidate.Port),
|
||||
Proto: candidate.Protocol.String(),
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
// Create an offer to send to the other process
|
||||
offer, err := pConn.CreateOffer(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Sets the LocalDescription, and starts our UDP listeners
|
||||
// Note: this will start the gathering of ICE candidates
|
||||
if err = pConn.SetLocalDescription(offer); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
gatherComplete := webrtc.GatheringCompletePromise(pConn)
|
||||
//wait for all the ICE candidates to be collected
|
||||
select {
|
||||
case <-gatherComplete:
|
||||
|
||||
log.Debugf("collected %d candidates", len(candidates))
|
||||
|
||||
return candidates, nil
|
||||
case <-time.After(time.Duration(10) * time.Second): //todo better timeout handling, or no timeout at all?
|
||||
return nil, errors.New(fmt.Sprintf("timeout of %v seconds reached while waiting for hole punching", 10))
|
||||
}
|
||||
}
|
||||
|
||||
func newPeerConnection(stuns []string) (*webrtc.PeerConnection, error) {
|
||||
|
||||
log.Debugf("creating new peer connection ...")
|
||||
config := webrtc.Configuration{
|
||||
ICEServers: []webrtc.ICEServer{
|
||||
{
|
||||
URLs: stuns,
|
||||
},
|
||||
},
|
||||
}
|
||||
settingEngine := webrtc.SettingEngine{}
|
||||
settingEngine.SetNetworkTypes([]webrtc.NetworkType{
|
||||
webrtc.NetworkTypeUDP4,
|
||||
})
|
||||
api := webrtc.NewAPI(
|
||||
webrtc.WithSettingEngine(settingEngine),
|
||||
)
|
||||
log.Debugf("created new peer connection")
|
||||
|
||||
return api.NewPeerConnection(config)
|
||||
}
|
159
relay/client.go
Normal file
159
relay/client.go
Normal file
@ -0,0 +1,159 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/pion/logging"
|
||||
"github.com/pion/turn/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net"
|
||||
)
|
||||
|
||||
//Client has no doc yet
|
||||
type Client struct {
|
||||
TurnC *turn.Client
|
||||
// remote peer to reply to
|
||||
peerAddr net.Addr
|
||||
// local Wireguard connection
|
||||
localWgConn net.Conn
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
c.TurnC.Close()
|
||||
if c.localWgConn != nil {
|
||||
err := c.localWgConn.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewClient(turnAddr string, user string, pwd string) (*Client, error) {
|
||||
// a local UDP proxy to forward Wireguard's packets to the relay server
|
||||
// This endpoint should be specified in the Peer's Wireguard config
|
||||
proxyConn, err := net.ListenPacket("udp4", "0.0.0.0:0")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfg := &turn.ClientConfig{
|
||||
STUNServerAddr: turnAddr,
|
||||
TURNServerAddr: turnAddr,
|
||||
Conn: proxyConn,
|
||||
Username: user,
|
||||
Password: pwd,
|
||||
Realm: "wiretrustee.com",
|
||||
LoggerFactory: logging.NewDefaultLoggerFactory(),
|
||||
}
|
||||
|
||||
client, err := turn.NewClient(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Both, client and peer needs to listen to Turn packets
|
||||
err = client.Listen()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Infof("local address %s", proxyConn.LocalAddr().String())
|
||||
return &Client{
|
||||
TurnC: client,
|
||||
}, err
|
||||
}
|
||||
|
||||
// Start relaying packets:
|
||||
// Incoming traffic from the relay sent by the other peer will be forwarded to local Wireguard
|
||||
// Outgoing traffic from local Wireguard will be intercepted and forwarded back to relayed connection
|
||||
// returns a relayed address (turn) to be used on the other side (peer)
|
||||
func (c *Client) Start(remoteAddr string, wgPort int) (*net.UDPAddr, *net.UDPAddr, error) {
|
||||
|
||||
udpRemoteAddr, err := net.ResolveUDPAddr("udp", remoteAddr)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Allocate a relay socket on the TURN server
|
||||
relayConn, err := c.TurnC.Allocate()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
// create a connection to a local Wireguard port to forward traffic to
|
||||
c.localWgConn, err = net.Dial("udp", fmt.Sprintf("127.0.0.1:%d", +wgPort))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
log.Infof("allocated a new relay address [%s]", relayConn.LocalAddr().String())
|
||||
|
||||
// read from relay and write to local Wireguard
|
||||
c.relayPeerToLocalDst(relayConn, c.localWgConn)
|
||||
// read from local Wireguard and write to relay
|
||||
c.relayLocalDstToPeer(c.localWgConn, relayConn)
|
||||
|
||||
// Punch a UDP hole for the relayConn by sending a data to the udpRemoteAddr.
|
||||
// This will trigger a TURN client to generate a permission request to the
|
||||
// TURN server. After this, packets from the IP address will be accepted by
|
||||
// the TURN server.
|
||||
_, err = relayConn.WriteTo([]byte("Hello"), udpRemoteAddr)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
log.Infof("Punched a hole on [%s:%s]", udpRemoteAddr.IP, udpRemoteAddr.Port)
|
||||
|
||||
relayAddr, err := net.ResolveUDPAddr("udp", relayConn.LocalAddr().String())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
wgAddr, err := net.ResolveUDPAddr("udp", c.localWgConn.LocalAddr().String())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return relayAddr, wgAddr, nil
|
||||
}
|
||||
|
||||
func (c *Client) relayPeerToLocalDst(relayConn net.PacketConn, localConn net.Conn) {
|
||||
go func() {
|
||||
buf := make([]byte, 1500)
|
||||
var n int
|
||||
var err error
|
||||
for {
|
||||
n, c.peerAddr, err = relayConn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
// log.Warnln("Error reading from peer: ", err.Error())
|
||||
continue
|
||||
}
|
||||
n, err = localConn.Write(buf[:n])
|
||||
if err != nil {
|
||||
log.Warnln("Error writing to local destination: ", err.Error())
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (c *Client) relayLocalDstToPeer(localConn net.Conn, relayConn net.PacketConn) {
|
||||
go func() {
|
||||
buf := make([]byte, 1500)
|
||||
var n int
|
||||
var err error
|
||||
for {
|
||||
n, err = localConn.Read(buf)
|
||||
if err != nil {
|
||||
// log.Warnln("Error reading from local destination: ", err.Error())
|
||||
continue
|
||||
}
|
||||
if c.peerAddr == nil {
|
||||
log.Warnln("We didn't received any peer connection yet")
|
||||
continue
|
||||
}
|
||||
// log.Infoln("Received message from Local: ", string(buf[:n]))
|
||||
_, err = relayConn.WriteTo(buf[:n], c.peerAddr)
|
||||
if err != nil {
|
||||
log.Warnln("Error writing to peer: ", err.Error())
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
18
signal/README.md
Normal file
18
signal/README.md
Normal file
@ -0,0 +1,18 @@
|
||||
This is a Wiretrustee signal-exchange server and client library to exchange connection information between Wiretrustee Trusted Device and Wiretrustee Hub
|
||||
|
||||
The project uses gRPC library and defines service in protobuf file located in:
|
||||
```proto/signal_exchange.proto```
|
||||
|
||||
To build the project you have to do the following things.
|
||||
|
||||
Install protobuf version 3 (by default v3 is installed on ubuntu 20.04. On previous versions it is proto 2):
|
||||
```
|
||||
sudo apt install protoc-gen-go
|
||||
sudo apt install golang-goprotobuf-dev
|
||||
```
|
||||
|
||||
Generate gRPC code:
|
||||
```
|
||||
protoc -I proto/ proto/signalexchange.proto --go_out=plugins=grpc:proto
|
||||
|
||||
```
|
179
signal/client.go
Normal file
179
signal/client.go
Normal file
@ -0,0 +1,179 @@
|
||||
package signal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/wiretrustee/wiretrustee/signal/proto"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// A set of tools to exchange connection details (Wireguard endpoints) with the remote peer.
|
||||
|
||||
// Wraps the Signal Exchange Service gRpc client
|
||||
type Client struct {
|
||||
realClient proto.SignalExchangeClient
|
||||
signalConn *grpc.ClientConn
|
||||
ctx context.Context
|
||||
stream proto.SignalExchange_ConnectStreamClient
|
||||
//waiting group to notify once stream is connected
|
||||
connWg sync.WaitGroup //todo use a channel instead??
|
||||
}
|
||||
|
||||
// Closes underlying connections to the Signal Exchange
|
||||
func (client *Client) Close() error {
|
||||
return client.signalConn.Close()
|
||||
}
|
||||
|
||||
func NewClient(addr string, ctx context.Context) (*Client, error) {
|
||||
|
||||
conn, err := grpc.DialContext(
|
||||
ctx,
|
||||
addr,
|
||||
grpc.WithInsecure(),
|
||||
grpc.WithBlock(),
|
||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: 30 * time.Second,
|
||||
Timeout: 10 * time.Second,
|
||||
}))
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("failed to connect to the signalling server %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Client{
|
||||
realClient: proto.NewSignalExchangeClient(conn),
|
||||
ctx: ctx,
|
||||
signalConn: conn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Connects to the Signal Exchange message stream and starts receiving messages.
|
||||
// The messages will be handled by msgHandler function provided.
|
||||
// This function runs a goroutine underneath and reconnects to the Signal Exchange if errors occur (e.g. Exchange restart)
|
||||
// The key is the identifier of our Peer (could be Wireguard public key)
|
||||
func (client *Client) Receive(key string, msgHandler func(msg *proto.Message) error) {
|
||||
client.connWg.Add(1)
|
||||
go func() {
|
||||
err := Retry(15, time.Second, func() error {
|
||||
return client.connect(key, msgHandler)
|
||||
}, func(err error) {
|
||||
log.Warnf("disconnected from the Signal Exchange due to an error %s. Retrying ... ", err)
|
||||
client.connWg.Add(1)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("error while communicating with the Signal Exchange %s ", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (client *Client) connect(key string, msgHandler func(msg *proto.Message) error) error {
|
||||
client.stream = nil
|
||||
|
||||
// add key fingerprint to the request header to be identified on the server side
|
||||
md := metadata.New(map[string]string{proto.HeaderId: key})
|
||||
ctx := metadata.NewOutgoingContext(client.ctx, md)
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
stream, err := client.realClient.ConnectStream(ctx)
|
||||
|
||||
client.stream = stream
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
//connection established we are good to go
|
||||
client.connWg.Done()
|
||||
|
||||
log.Infof("connected to the Signal Exchange Stream")
|
||||
|
||||
return client.receive(stream, msgHandler)
|
||||
}
|
||||
|
||||
// Waits until the client is connected to the message stream
|
||||
func (client *Client) WaitConnected() {
|
||||
client.connWg.Wait()
|
||||
}
|
||||
|
||||
// Sends a message to the remote Peer through the Signal Exchange.
|
||||
// The Client.Receive method must be called before sending messages to establish initial connection to the Signal Exchange
|
||||
// Client.connWg can be used to wait
|
||||
func (client *Client) Send(msg *proto.Message) error {
|
||||
if client.stream == nil {
|
||||
return fmt.Errorf("connection to the Signal Exchnage has not been established yet. Please call Client.Receive before sending messages")
|
||||
}
|
||||
|
||||
err := client.stream.Send(msg)
|
||||
if err != nil {
|
||||
log.Errorf("error while sending message to peer [%s] [error: %v]", msg.RemoteKey, err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Receives messages from other peers coming through the Signal Exchange
|
||||
func (client *Client) receive(stream proto.SignalExchange_ConnectStreamClient,
|
||||
msgHandler func(msg *proto.Message) error) error {
|
||||
|
||||
for {
|
||||
msg, err := stream.Recv()
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled {
|
||||
log.Warnf("stream canceled (usually indicates shutdown)")
|
||||
return err
|
||||
} else if s.Code() == codes.Unavailable {
|
||||
log.Warnf("server has been stopped")
|
||||
return err
|
||||
} else if err == io.EOF {
|
||||
log.Warnf("stream closed by server")
|
||||
return err
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debugf("received a new message from Peer [fingerprint: %s] [type %s]", msg.Key, msg.Type)
|
||||
|
||||
//todo decrypt
|
||||
err = msgHandler(msg)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("error while handling message of Peer [fingerprint: %s] error: [%s]", msg.Key, err.Error())
|
||||
//todo send something??
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func UnMarshalCredential(msg *proto.Message) (*Credential, error) {
|
||||
credential := strings.Split(msg.Body, ":")
|
||||
if len(credential) != 2 {
|
||||
return nil, fmt.Errorf("error parsing message body %s", msg.Body)
|
||||
}
|
||||
return &Credential{
|
||||
UFrag: credential[0],
|
||||
Pwd: credential[1],
|
||||
}, nil
|
||||
}
|
||||
|
||||
func MarshalCredential(ourKey string, remoteKey string, credential *Credential, t proto.Message_Type) *proto.Message {
|
||||
return &proto.Message{
|
||||
Type: t,
|
||||
Key: ourKey,
|
||||
RemoteKey: remoteKey,
|
||||
Body: fmt.Sprintf("%s:%s", credential.UFrag, credential.Pwd),
|
||||
}
|
||||
}
|
||||
|
||||
type Credential struct {
|
||||
UFrag string
|
||||
Pwd string
|
||||
}
|
52
signal/encryption.go
Normal file
52
signal/encryption.go
Normal file
@ -0,0 +1,52 @@
|
||||
package signal
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"golang.org/x/crypto/nacl/box"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
// As set of tools to encrypt/decrypt messages being sent through the Signal Exchange Service.
|
||||
// We want to make sure that the Connection Candidates and other irrelevant (to the Signal Exchange) information can't be read anywhere else but the Peer the message is being sent to.
|
||||
// These tools use Golang crypto package (Curve25519, XSalsa20 and Poly1305 to encrypt and authenticate)
|
||||
// Wireguard keys are used for encryption
|
||||
|
||||
// Encrypts a message using local Wireguard private key and remote peer's public key.
|
||||
func EncryptMessage(msg []byte, privateKey wgtypes.Key, remotePubKey wgtypes.Key) ([]byte, error) {
|
||||
nonce, err := genNonce()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return box.Seal(nil, msg, nonce, toByte32(remotePubKey), toByte32(privateKey)), nil
|
||||
}
|
||||
|
||||
// Decrypts a message that has been encrypted by the remote peer using Wireguard private key and remote peer's public key.
|
||||
func DecryptMessage(encryptedMsg []byte, privateKey wgtypes.Key, remotePubKey wgtypes.Key) ([]byte, error) {
|
||||
nonce, err := genNonce()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
opened, ok := box.Open(nil, encryptedMsg, nonce, toByte32(remotePubKey), toByte32(privateKey))
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to decrypt message from peer %s", remotePubKey.String())
|
||||
}
|
||||
|
||||
return opened, nil
|
||||
}
|
||||
|
||||
// Generates nonce of size 24
|
||||
func genNonce() (*[24]byte, error) {
|
||||
var nonce [24]byte
|
||||
if _, err := rand.Read(nonce[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &nonce, nil
|
||||
}
|
||||
|
||||
// Converts Wireguard key to byte array of size 32 (a format used by the golang crypto package)
|
||||
func toByte32(key wgtypes.Key) *[32]byte {
|
||||
return (*[32]byte)(&key)
|
||||
}
|
18
signal/fingerprint.go
Normal file
18
signal/fingerprint.go
Normal file
@ -0,0 +1,18 @@
|
||||
package signal
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
const (
|
||||
HexTable = "0123456789abcdef"
|
||||
)
|
||||
|
||||
// Generates a SHA256 Fingerprint of the string
|
||||
func FingerPrint(key string) string {
|
||||
hasher := sha256.New()
|
||||
hasher.Write([]byte(key))
|
||||
sha := hasher.Sum(nil)
|
||||
return hex.EncodeToString(sha)
|
||||
}
|
54
signal/peer/peer.go
Normal file
54
signal/peer/peer.go
Normal file
@ -0,0 +1,54 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/wiretrustee/wiretrustee/signal/proto"
|
||||
)
|
||||
|
||||
// Representation of a connected Peer
|
||||
type Peer struct {
|
||||
// a unique id of the Peer (e.g. sha256 fingerprint of the Wireguard public key)
|
||||
Id string
|
||||
|
||||
//a gRpc connection stream to the Peer
|
||||
Stream proto.SignalExchange_ConnectStreamServer
|
||||
}
|
||||
|
||||
func NewPeer(id string, stream proto.SignalExchange_ConnectStreamServer) *Peer {
|
||||
return &Peer{
|
||||
Id: id,
|
||||
Stream: stream,
|
||||
}
|
||||
}
|
||||
|
||||
// registry that holds all currently connected Peers
|
||||
type Registry struct {
|
||||
// Peer.key -> Peer
|
||||
Peers map[string]*Peer
|
||||
}
|
||||
|
||||
func NewRegistry() *Registry {
|
||||
return &Registry{
|
||||
Peers: make(map[string]*Peer),
|
||||
}
|
||||
}
|
||||
|
||||
// Registers peer in the registry
|
||||
func (reg *Registry) Register(peer *Peer) {
|
||||
if _, exists := reg.Peers[peer.Id]; exists {
|
||||
log.Warnf("peer [%s] has been already registered", peer.Id)
|
||||
} else {
|
||||
log.Printf("registering new peer [%s]", peer.Id)
|
||||
}
|
||||
//replace Peer even if exists
|
||||
//todo should we really replace?
|
||||
reg.Peers[peer.Id] = peer
|
||||
}
|
||||
|
||||
// Deregister Peer from the Registry (usually once it disconnects)
|
||||
func (reg *Registry) DeregisterHub(peer *Peer) {
|
||||
if _, ok := reg.Peers[peer.Id]; ok {
|
||||
delete(reg.Peers, peer.Id)
|
||||
log.Printf("deregistered peer [%s]", peer.Id)
|
||||
}
|
||||
}
|
4
signal/proto/constants.go
Normal file
4
signal/proto/constants.go
Normal file
@ -0,0 +1,4 @@
|
||||
package proto
|
||||
|
||||
// protocol constants, field names that can be used by both client and server
|
||||
const HeaderId = "x-wiretrustee-peer-id"
|
301
signal/proto/signalexchange.pb.go
Normal file
301
signal/proto/signalexchange.pb.go
Normal file
@ -0,0 +1,301 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// source: signalexchange.proto
|
||||
|
||||
package proto
|
||||
|
||||
import (
|
||||
context "context"
|
||||
fmt "fmt"
|
||||
proto "github.com/golang/protobuf/proto"
|
||||
_ "github.com/golang/protobuf/protoc-gen-go/descriptor"
|
||||
grpc "google.golang.org/grpc"
|
||||
codes "google.golang.org/grpc/codes"
|
||||
status "google.golang.org/grpc/status"
|
||||
math "math"
|
||||
)
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ = proto.Marshal
|
||||
var _ = fmt.Errorf
|
||||
var _ = math.Inf
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the proto package it is being compiled against.
|
||||
// A compilation error at this line likely means your copy of the
|
||||
// proto package needs to be updated.
|
||||
const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
|
||||
|
||||
// Message type
|
||||
type Message_Type int32
|
||||
|
||||
const (
|
||||
Message_OFFER Message_Type = 0
|
||||
Message_ANSWER Message_Type = 1
|
||||
Message_CANDIDATE Message_Type = 2
|
||||
)
|
||||
|
||||
var Message_Type_name = map[int32]string{
|
||||
0: "OFFER",
|
||||
1: "ANSWER",
|
||||
2: "CANDIDATE",
|
||||
}
|
||||
|
||||
var Message_Type_value = map[string]int32{
|
||||
"OFFER": 0,
|
||||
"ANSWER": 1,
|
||||
"CANDIDATE": 2,
|
||||
}
|
||||
|
||||
func (x Message_Type) String() string {
|
||||
return proto.EnumName(Message_Type_name, int32(x))
|
||||
}
|
||||
|
||||
func (Message_Type) EnumDescriptor() ([]byte, []int) {
|
||||
return fileDescriptor_bf680d70b8e3473f, []int{0, 0}
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Type Message_Type `protobuf:"varint,1,opt,name=type,proto3,enum=signalexchange.Message_Type" json:"type,omitempty"`
|
||||
// a sha256 fingerprint of the Wireguard public key
|
||||
Key string `protobuf:"bytes,2,opt,name=key,proto3" json:"key,omitempty"`
|
||||
// a sha256 fingerprint of the Wireguard public key of the remote peer to connect to
|
||||
RemoteKey string `protobuf:"bytes,3,opt,name=remoteKey,proto3" json:"remoteKey,omitempty"`
|
||||
Body string `protobuf:"bytes,4,opt,name=body,proto3" json:"body,omitempty"`
|
||||
XXX_NoUnkeyedLiteral struct{} `json:"-"`
|
||||
XXX_unrecognized []byte `json:"-"`
|
||||
XXX_sizecache int32 `json:"-"`
|
||||
}
|
||||
|
||||
func (m *Message) Reset() { *m = Message{} }
|
||||
func (m *Message) String() string { return proto.CompactTextString(m) }
|
||||
func (*Message) ProtoMessage() {}
|
||||
func (*Message) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_bf680d70b8e3473f, []int{0}
|
||||
}
|
||||
|
||||
func (m *Message) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_Message.Unmarshal(m, b)
|
||||
}
|
||||
func (m *Message) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
|
||||
return xxx_messageInfo_Message.Marshal(b, m, deterministic)
|
||||
}
|
||||
func (m *Message) XXX_Merge(src proto.Message) {
|
||||
xxx_messageInfo_Message.Merge(m, src)
|
||||
}
|
||||
func (m *Message) XXX_Size() int {
|
||||
return xxx_messageInfo_Message.Size(m)
|
||||
}
|
||||
func (m *Message) XXX_DiscardUnknown() {
|
||||
xxx_messageInfo_Message.DiscardUnknown(m)
|
||||
}
|
||||
|
||||
var xxx_messageInfo_Message proto.InternalMessageInfo
|
||||
|
||||
func (m *Message) GetType() Message_Type {
|
||||
if m != nil {
|
||||
return m.Type
|
||||
}
|
||||
return Message_OFFER
|
||||
}
|
||||
|
||||
func (m *Message) GetKey() string {
|
||||
if m != nil {
|
||||
return m.Key
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *Message) GetRemoteKey() string {
|
||||
if m != nil {
|
||||
return m.RemoteKey
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *Message) GetBody() string {
|
||||
if m != nil {
|
||||
return m.Body
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func init() {
|
||||
proto.RegisterEnum("signalexchange.Message_Type", Message_Type_name, Message_Type_value)
|
||||
proto.RegisterType((*Message)(nil), "signalexchange.Message")
|
||||
}
|
||||
|
||||
func init() { proto.RegisterFile("signalexchange.proto", fileDescriptor_bf680d70b8e3473f) }
|
||||
|
||||
var fileDescriptor_bf680d70b8e3473f = []byte{
|
||||
// 272 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x9c, 0x50, 0xcd, 0x4a, 0xf3, 0x40,
|
||||
0x14, 0xed, 0xb4, 0xf9, 0x1a, 0x72, 0xa1, 0x21, 0x5c, 0x3e, 0x30, 0x94, 0x2e, 0x42, 0x56, 0x59,
|
||||
0x48, 0x5a, 0xea, 0x52, 0x5c, 0xc4, 0x36, 0x15, 0x11, 0x2b, 0x24, 0x05, 0xc1, 0x5d, 0x9a, 0x5e,
|
||||
0xc7, 0x62, 0x9b, 0x09, 0x93, 0x11, 0x9c, 0x37, 0xf1, 0x25, 0x7c, 0x47, 0xe9, 0x34, 0x20, 0x0a,
|
||||
0xdd, 0xb8, 0x9a, 0xc3, 0xf9, 0x9b, 0xc3, 0x85, 0xff, 0xcd, 0x96, 0x57, 0xc5, 0x8e, 0xde, 0xcb,
|
||||
0x97, 0xa2, 0xe2, 0x14, 0xd7, 0x52, 0x28, 0x81, 0xee, 0x4f, 0x76, 0x18, 0x70, 0x21, 0xf8, 0x8e,
|
||||
0xc6, 0x46, 0x5d, 0xbf, 0x3d, 0x8f, 0x37, 0xd4, 0x94, 0x72, 0x5b, 0x2b, 0x21, 0x8f, 0x89, 0xf0,
|
||||
0x93, 0x81, 0x7d, 0x4f, 0x4d, 0x53, 0x70, 0xc2, 0x09, 0x58, 0x4a, 0xd7, 0xe4, 0xb3, 0x80, 0x45,
|
||||
0xee, 0x74, 0x14, 0xff, 0xfa, 0xa2, 0xb5, 0xc5, 0x2b, 0x5d, 0x53, 0x66, 0x9c, 0xe8, 0x41, 0xef,
|
||||
0x95, 0xb4, 0xdf, 0x0d, 0x58, 0xe4, 0x64, 0x07, 0x88, 0x23, 0x70, 0x24, 0xed, 0x85, 0xa2, 0x3b,
|
||||
0xd2, 0x7e, 0xcf, 0xf0, 0xdf, 0x04, 0x22, 0x58, 0x6b, 0xb1, 0xd1, 0xbe, 0x65, 0x04, 0x83, 0xc3,
|
||||
0x73, 0xb0, 0x0e, 0x8d, 0xe8, 0xc0, 0xbf, 0x87, 0xc5, 0x22, 0xcd, 0xbc, 0x0e, 0x02, 0xf4, 0x93,
|
||||
0x65, 0xfe, 0x98, 0x66, 0x1e, 0xc3, 0x01, 0x38, 0xb3, 0x64, 0x39, 0xbf, 0x9d, 0x27, 0xab, 0xd4,
|
||||
0xeb, 0x4e, 0x3f, 0x18, 0xb8, 0xb9, 0xd9, 0x95, 0xb6, 0xbb, 0xf0, 0x0a, 0xec, 0x99, 0xa8, 0x2a,
|
||||
0x2a, 0x15, 0x9e, 0x9d, 0xd8, 0x3c, 0x3c, 0x25, 0x84, 0x1d, 0xbc, 0x81, 0x41, 0x1b, 0xcf, 0x95,
|
||||
0xa4, 0x62, 0xff, 0x97, 0x92, 0x88, 0x4d, 0xd8, 0xb5, 0xf3, 0x64, 0xc7, 0x97, 0xc7, 0x4b, 0xf7,
|
||||
0xcd, 0x73, 0xf1, 0x15, 0x00, 0x00, 0xff, 0xff, 0x20, 0x13, 0xc1, 0xe1, 0xa6, 0x01, 0x00, 0x00,
|
||||
}
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ context.Context
|
||||
var _ grpc.ClientConn
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the grpc package it is being compiled against.
|
||||
const _ = grpc.SupportPackageIsVersion4
|
||||
|
||||
// SignalExchangeClient is the client API for SignalExchange service.
|
||||
//
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.
|
||||
type SignalExchangeClient interface {
|
||||
// Synchronously connect to the Signal Exchange service offering connection candidates and waiting for connection candidates from the other party (remote peer)
|
||||
Connect(ctx context.Context, in *Message, opts ...grpc.CallOption) (*Message, error)
|
||||
// Connect to the Signal Exchange service offering connection candidates and maintain a channel for receiving candidates from the other party (remote peer)
|
||||
ConnectStream(ctx context.Context, opts ...grpc.CallOption) (SignalExchange_ConnectStreamClient, error)
|
||||
}
|
||||
|
||||
type signalExchangeClient struct {
|
||||
cc *grpc.ClientConn
|
||||
}
|
||||
|
||||
func NewSignalExchangeClient(cc *grpc.ClientConn) SignalExchangeClient {
|
||||
return &signalExchangeClient{cc}
|
||||
}
|
||||
|
||||
func (c *signalExchangeClient) Connect(ctx context.Context, in *Message, opts ...grpc.CallOption) (*Message, error) {
|
||||
out := new(Message)
|
||||
err := c.cc.Invoke(ctx, "/signalexchange.SignalExchange/Connect", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *signalExchangeClient) ConnectStream(ctx context.Context, opts ...grpc.CallOption) (SignalExchange_ConnectStreamClient, error) {
|
||||
stream, err := c.cc.NewStream(ctx, &_SignalExchange_serviceDesc.Streams[0], "/signalexchange.SignalExchange/ConnectStream", opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
x := &signalExchangeConnectStreamClient{stream}
|
||||
return x, nil
|
||||
}
|
||||
|
||||
type SignalExchange_ConnectStreamClient interface {
|
||||
Send(*Message) error
|
||||
Recv() (*Message, error)
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
type signalExchangeConnectStreamClient struct {
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
func (x *signalExchangeConnectStreamClient) Send(m *Message) error {
|
||||
return x.ClientStream.SendMsg(m)
|
||||
}
|
||||
|
||||
func (x *signalExchangeConnectStreamClient) Recv() (*Message, error) {
|
||||
m := new(Message)
|
||||
if err := x.ClientStream.RecvMsg(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// SignalExchangeServer is the server API for SignalExchange service.
|
||||
type SignalExchangeServer interface {
|
||||
// Synchronously connect to the Signal Exchange service offering connection candidates and waiting for connection candidates from the other party (remote peer)
|
||||
Connect(context.Context, *Message) (*Message, error)
|
||||
// Connect to the Signal Exchange service offering connection candidates and maintain a channel for receiving candidates from the other party (remote peer)
|
||||
ConnectStream(SignalExchange_ConnectStreamServer) error
|
||||
}
|
||||
|
||||
// UnimplementedSignalExchangeServer can be embedded to have forward compatible implementations.
|
||||
type UnimplementedSignalExchangeServer struct {
|
||||
}
|
||||
|
||||
func (*UnimplementedSignalExchangeServer) Connect(ctx context.Context, req *Message) (*Message, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method Connect not implemented")
|
||||
}
|
||||
func (*UnimplementedSignalExchangeServer) ConnectStream(srv SignalExchange_ConnectStreamServer) error {
|
||||
return status.Errorf(codes.Unimplemented, "method ConnectStream not implemented")
|
||||
}
|
||||
|
||||
func RegisterSignalExchangeServer(s *grpc.Server, srv SignalExchangeServer) {
|
||||
s.RegisterService(&_SignalExchange_serviceDesc, srv)
|
||||
}
|
||||
|
||||
func _SignalExchange_Connect_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(Message)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(SignalExchangeServer).Connect(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/signalexchange.SignalExchange/Connect",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(SignalExchangeServer).Connect(ctx, req.(*Message))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _SignalExchange_ConnectStream_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||
return srv.(SignalExchangeServer).ConnectStream(&signalExchangeConnectStreamServer{stream})
|
||||
}
|
||||
|
||||
type SignalExchange_ConnectStreamServer interface {
|
||||
Send(*Message) error
|
||||
Recv() (*Message, error)
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
type signalExchangeConnectStreamServer struct {
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
func (x *signalExchangeConnectStreamServer) Send(m *Message) error {
|
||||
return x.ServerStream.SendMsg(m)
|
||||
}
|
||||
|
||||
func (x *signalExchangeConnectStreamServer) Recv() (*Message, error) {
|
||||
m := new(Message)
|
||||
if err := x.ServerStream.RecvMsg(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
var _SignalExchange_serviceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "signalexchange.SignalExchange",
|
||||
HandlerType: (*SignalExchangeServer)(nil),
|
||||
Methods: []grpc.MethodDesc{
|
||||
{
|
||||
MethodName: "Connect",
|
||||
Handler: _SignalExchange_Connect_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{
|
||||
{
|
||||
StreamName: "ConnectStream",
|
||||
Handler: _SignalExchange_ConnectStream_Handler,
|
||||
ServerStreams: true,
|
||||
ClientStreams: true,
|
||||
},
|
||||
},
|
||||
Metadata: "signalexchange.proto",
|
||||
}
|
34
signal/proto/signalexchange.proto
Normal file
34
signal/proto/signalexchange.proto
Normal file
@ -0,0 +1,34 @@
|
||||
syntax = "proto3";
|
||||
|
||||
import "google/protobuf/descriptor.proto";
|
||||
|
||||
option go_package = ".;proto";
|
||||
|
||||
package signalexchange;
|
||||
|
||||
service SignalExchange {
|
||||
// Synchronously connect to the Signal Exchange service offering connection candidates and waiting for connection candidates from the other party (remote peer)
|
||||
rpc Connect(Message) returns (Message) {}
|
||||
// Connect to the Signal Exchange service offering connection candidates and maintain a channel for receiving candidates from the other party (remote peer)
|
||||
rpc ConnectStream(stream Message) returns (stream Message) {}
|
||||
}
|
||||
|
||||
message Message {
|
||||
|
||||
// Message type
|
||||
enum Type {
|
||||
OFFER = 0;
|
||||
ANSWER = 1;
|
||||
CANDIDATE = 2;
|
||||
}
|
||||
|
||||
Type type = 1;
|
||||
|
||||
// a sha256 fingerprint of the Wireguard public key
|
||||
string key = 2;
|
||||
|
||||
// a sha256 fingerprint of the Wireguard public key of the remote peer to connect to
|
||||
string remoteKey = 3;
|
||||
|
||||
string body = 4;
|
||||
}
|
32
signal/retry.go
Normal file
32
signal/retry.go
Normal file
@ -0,0 +1,32 @@
|
||||
package signal
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Retries a given toExec function calling onError on failed attempts
|
||||
// onError shouldn be a lightweight function and shouldn't be blocking
|
||||
func Retry(attempts int, sleep time.Duration, toExec func() error, onError func(e error)) error {
|
||||
if err := toExec(); err != nil {
|
||||
if s, ok := err.(stop); ok {
|
||||
return s.error
|
||||
}
|
||||
|
||||
if attempts--; attempts > 0 {
|
||||
jitter := time.Duration(rand.Int63n(int64(sleep)))
|
||||
sleep = sleep + jitter/2
|
||||
|
||||
onError(err)
|
||||
time.Sleep(sleep)
|
||||
return Retry(attempts, 2*sleep, toExec, onError)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type stop struct {
|
||||
error
|
||||
}
|
113
signal/signal.go
Normal file
113
signal/signal.go
Normal file
@ -0,0 +1,113 @@
|
||||
package signal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/wiretrustee/wiretrustee/signal/peer"
|
||||
"github.com/wiretrustee/wiretrustee/signal/proto"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
var (
|
||||
port = flag.Int("port", 10000, "The server port")
|
||||
)
|
||||
|
||||
type SignalExchangeServer struct {
|
||||
registry *peer.Registry
|
||||
}
|
||||
|
||||
func NewServer() *SignalExchangeServer {
|
||||
return &SignalExchangeServer{
|
||||
registry: peer.NewRegistry(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SignalExchangeServer) Connect(context.Context, *proto.Message) (*proto.Message, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method Receive not implemented")
|
||||
}
|
||||
|
||||
func (s *SignalExchangeServer) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) error {
|
||||
p, err := s.connectPeer(stream)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("peer [%s] has successfully connected", p.Id)
|
||||
|
||||
for {
|
||||
msg, err := stream.Recv()
|
||||
if err == io.EOF {
|
||||
break
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debugf("received a new message from peer [%s] to peer [%s]", p.Id, msg.RemoteKey)
|
||||
// lookup the target peer where the message is going to
|
||||
if dstPeer, found := s.registry.Peers[msg.RemoteKey]; found {
|
||||
//forward the message to the target peer
|
||||
err := dstPeer.Stream.Send(msg)
|
||||
if err != nil {
|
||||
log.Errorf("error while forwarding message from peer [%s] to peer [%s]", p.Id, msg.RemoteKey)
|
||||
//todo respond to the sender?
|
||||
}
|
||||
} else {
|
||||
log.Warnf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", p.Id, msg.RemoteKey)
|
||||
//todo respond to the sender?
|
||||
}
|
||||
|
||||
}
|
||||
<-stream.Context().Done()
|
||||
return stream.Context().Err()
|
||||
}
|
||||
|
||||
func copyMessage(msg *proto.Message) *proto.Message {
|
||||
return &proto.Message{
|
||||
Type: msg.Type,
|
||||
Key: msg.Key,
|
||||
RemoteKey: msg.RemoteKey,
|
||||
}
|
||||
}
|
||||
|
||||
// Handles initial Peer connection.
|
||||
// Each connection must provide an ID header.
|
||||
// At this moment the connecting Peer will be registered in the peer.Registry
|
||||
func (s SignalExchangeServer) connectPeer(stream proto.SignalExchange_ConnectStreamServer) (*peer.Peer, error) {
|
||||
if meta, hasMeta := metadata.FromIncomingContext(stream.Context()); hasMeta {
|
||||
if id, found := meta[proto.HeaderId]; found {
|
||||
p := peer.NewPeer(id[0], stream)
|
||||
s.registry.Register(p)
|
||||
return p, nil
|
||||
} else {
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "missing connection header: "+proto.HeaderId)
|
||||
}
|
||||
} else {
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "missing connection stream meta")
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
lis, err := net.Listen("tcp", fmt.Sprintf(":%d", *port))
|
||||
if err != nil {
|
||||
log.Fatalf("failed to listen: %v", err)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Fatalf("failed to listen: %v", err)
|
||||
}
|
||||
var opts []grpc.ServerOption
|
||||
grpcServer := grpc.NewServer(opts...)
|
||||
proto.RegisterSignalExchangeServer(grpcServer, NewServer())
|
||||
log.Printf("started server: localhost:%v", *port)
|
||||
if err := grpcServer.Serve(lis); err != nil {
|
||||
log.Fatalf("failed to serve: %v", err)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user