mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-15 03:11:02 +01:00
172 lines
3.8 KiB
Go
172 lines
3.8 KiB
Go
|
package relay
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"fmt"
|
||
|
"net"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
"github.com/pion/stun/v2"
|
||
|
"github.com/pion/turn/v3"
|
||
|
log "github.com/sirupsen/logrus"
|
||
|
)
|
||
|
|
||
|
// ProbeResult holds the info about the result of a relay probe request
|
||
|
type ProbeResult struct {
|
||
|
URI *stun.URI
|
||
|
Err error
|
||
|
Addr string
|
||
|
}
|
||
|
|
||
|
// ProbeSTUN tries binding to the given STUN uri and acquiring an address
|
||
|
func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
|
||
|
defer func() {
|
||
|
if probeErr != nil {
|
||
|
log.Debugf("stun probe error from %s: %s", uri, probeErr)
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
client, err := stun.DialURI(uri, &stun.DialConfig{})
|
||
|
if err != nil {
|
||
|
probeErr = fmt.Errorf("dial: %w", err)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
defer func() {
|
||
|
if err := client.Close(); err != nil && probeErr == nil {
|
||
|
probeErr = fmt.Errorf("close: %w", err)
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
done := make(chan struct{})
|
||
|
if err = client.Start(stun.MustBuild(stun.TransactionID, stun.BindingRequest), func(res stun.Event) {
|
||
|
if res.Error != nil {
|
||
|
probeErr = fmt.Errorf("request: %w", err)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
var xorAddr stun.XORMappedAddress
|
||
|
if getErr := xorAddr.GetFrom(res.Message); getErr != nil {
|
||
|
probeErr = fmt.Errorf("get xor addr: %w", err)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
log.Debugf("stun probe received address from %s: %s", uri, xorAddr)
|
||
|
addr = xorAddr.String()
|
||
|
|
||
|
done <- struct{}{}
|
||
|
}); err != nil {
|
||
|
probeErr = fmt.Errorf("client: %w", err)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
select {
|
||
|
case <-ctx.Done():
|
||
|
probeErr = fmt.Errorf("stun request: %w", ctx.Err())
|
||
|
return
|
||
|
case <-done:
|
||
|
}
|
||
|
|
||
|
return addr, nil
|
||
|
}
|
||
|
|
||
|
// ProbeTURN tries allocating a session from the given TURN URI
|
||
|
func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
|
||
|
defer func() {
|
||
|
if probeErr != nil {
|
||
|
log.Debugf("turn probe error from %s: %s", uri, probeErr)
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
turnServerAddr := fmt.Sprintf("%s:%d", uri.Host, uri.Port)
|
||
|
|
||
|
var conn net.PacketConn
|
||
|
switch uri.Proto {
|
||
|
case stun.ProtoTypeUDP:
|
||
|
var err error
|
||
|
conn, err = net.ListenPacket("udp", "")
|
||
|
if err != nil {
|
||
|
probeErr = fmt.Errorf("listen: %w", err)
|
||
|
return
|
||
|
}
|
||
|
case stun.ProtoTypeTCP:
|
||
|
dialer := net.Dialer{}
|
||
|
tcpConn, err := dialer.DialContext(ctx, "tcp", turnServerAddr)
|
||
|
if err != nil {
|
||
|
probeErr = fmt.Errorf("dial: %w", err)
|
||
|
return
|
||
|
}
|
||
|
conn = turn.NewSTUNConn(tcpConn)
|
||
|
default:
|
||
|
probeErr = fmt.Errorf("conn: unknown proto: %s", uri.Proto)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
defer func() {
|
||
|
if err := conn.Close(); err != nil && probeErr == nil {
|
||
|
probeErr = fmt.Errorf("conn close: %w", err)
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
cfg := &turn.ClientConfig{
|
||
|
STUNServerAddr: turnServerAddr,
|
||
|
TURNServerAddr: turnServerAddr,
|
||
|
Conn: conn,
|
||
|
Username: uri.Username,
|
||
|
Password: uri.Password,
|
||
|
}
|
||
|
client, err := turn.NewClient(cfg)
|
||
|
if err != nil {
|
||
|
probeErr = fmt.Errorf("create client: %w", err)
|
||
|
return
|
||
|
}
|
||
|
defer client.Close()
|
||
|
|
||
|
if err := client.Listen(); err != nil {
|
||
|
probeErr = fmt.Errorf("client listen: %w", err)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
relayConn, err := client.Allocate()
|
||
|
if err != nil {
|
||
|
probeErr = fmt.Errorf("allocate: %w", err)
|
||
|
return
|
||
|
}
|
||
|
defer func() {
|
||
|
if err := relayConn.Close(); err != nil && probeErr == nil {
|
||
|
probeErr = fmt.Errorf("close relay conn: %w", err)
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
log.Debugf("turn probe relay address from %s: %s", uri, relayConn.LocalAddr())
|
||
|
|
||
|
return relayConn.LocalAddr().String(), nil
|
||
|
}
|
||
|
|
||
|
// ProbeAll probes all given servers asynchronously and returns the results
|
||
|
func ProbeAll(
|
||
|
ctx context.Context,
|
||
|
fn func(ctx context.Context, uri *stun.URI) (addr string, probeErr error),
|
||
|
relays []*stun.URI,
|
||
|
) []ProbeResult {
|
||
|
results := make([]ProbeResult, len(relays))
|
||
|
|
||
|
var wg sync.WaitGroup
|
||
|
for i, uri := range relays {
|
||
|
ctx, cancel := context.WithTimeout(ctx, 1*time.Second)
|
||
|
defer cancel()
|
||
|
|
||
|
wg.Add(1)
|
||
|
go func(res *ProbeResult, stunURI *stun.URI) {
|
||
|
defer wg.Done()
|
||
|
res.URI = stunURI
|
||
|
res.Addr, res.Err = fn(ctx, stunURI)
|
||
|
}(&results[i], uri)
|
||
|
}
|
||
|
|
||
|
wg.Wait()
|
||
|
|
||
|
return results
|
||
|
}
|