mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-22 05:49:12 +01:00
188 lines
4.2 KiB
Go
188 lines
4.2 KiB
Go
package relay
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/pion/stun/v2"
|
|
"github.com/pion/turn/v3"
|
|
log "github.com/sirupsen/logrus"
|
|
|
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
|
nbnet "github.com/netbirdio/netbird/util/net"
|
|
)
|
|
|
|
// ProbeResult holds the info about the result of a relay probe request.
// ProbeAll produces one ProbeResult per probed relay URI.
type ProbeResult struct {
	// URI is the string form of the relay URI that was probed.
	URI string
	// Err is the error returned by the probe function, nil on success.
	Err error
	// Addr is the address returned by the probe function.
	Addr string
}
|
|
|
|
// ProbeSTUN tries binding to the given STUN uri and acquiring an address
|
|
func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
|
|
defer func() {
|
|
if probeErr != nil {
|
|
log.Debugf("stun probe error from %s: %s", uri, probeErr)
|
|
}
|
|
}()
|
|
|
|
net, err := stdnet.NewNet(nil)
|
|
if err != nil {
|
|
probeErr = fmt.Errorf("new net: %w", err)
|
|
return
|
|
}
|
|
|
|
client, err := stun.DialURI(uri, &stun.DialConfig{
|
|
Net: net,
|
|
})
|
|
if err != nil {
|
|
probeErr = fmt.Errorf("dial: %w", err)
|
|
return
|
|
}
|
|
|
|
defer func() {
|
|
if err := client.Close(); err != nil && probeErr == nil {
|
|
probeErr = fmt.Errorf("close: %w", err)
|
|
}
|
|
}()
|
|
|
|
done := make(chan struct{})
|
|
if err = client.Start(stun.MustBuild(stun.TransactionID, stun.BindingRequest), func(res stun.Event) {
|
|
if res.Error != nil {
|
|
probeErr = fmt.Errorf("request: %w", err)
|
|
return
|
|
}
|
|
|
|
var xorAddr stun.XORMappedAddress
|
|
if getErr := xorAddr.GetFrom(res.Message); getErr != nil {
|
|
probeErr = fmt.Errorf("get xor addr: %w", err)
|
|
return
|
|
}
|
|
|
|
log.Debugf("stun probe received address from %s: %s", uri, xorAddr)
|
|
addr = xorAddr.String()
|
|
|
|
done <- struct{}{}
|
|
}); err != nil {
|
|
probeErr = fmt.Errorf("client: %w", err)
|
|
return
|
|
}
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
probeErr = fmt.Errorf("stun request: %w", ctx.Err())
|
|
return
|
|
case <-done:
|
|
}
|
|
|
|
return addr, nil
|
|
}
|
|
|
|
// ProbeTURN tries allocating a session from the given TURN URI
|
|
func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
|
|
defer func() {
|
|
if probeErr != nil {
|
|
log.Debugf("turn probe error from %s: %s", uri, probeErr)
|
|
}
|
|
}()
|
|
|
|
turnServerAddr := fmt.Sprintf("%s:%d", uri.Host, uri.Port)
|
|
|
|
var conn net.PacketConn
|
|
switch uri.Proto {
|
|
case stun.ProtoTypeUDP:
|
|
var err error
|
|
conn, err = nbnet.NewListener().ListenPacket(ctx, "udp", "")
|
|
if err != nil {
|
|
probeErr = fmt.Errorf("listen: %w", err)
|
|
return
|
|
}
|
|
case stun.ProtoTypeTCP:
|
|
tcpConn, err := nbnet.NewDialer().DialContext(ctx, "tcp", turnServerAddr)
|
|
if err != nil {
|
|
probeErr = fmt.Errorf("dial: %w", err)
|
|
return
|
|
}
|
|
conn = turn.NewSTUNConn(tcpConn)
|
|
default:
|
|
probeErr = fmt.Errorf("conn: unknown proto: %s", uri.Proto)
|
|
return
|
|
}
|
|
|
|
defer func() {
|
|
if err := conn.Close(); err != nil && probeErr == nil {
|
|
probeErr = fmt.Errorf("conn close: %w", err)
|
|
}
|
|
}()
|
|
|
|
net, err := stdnet.NewNet(nil)
|
|
if err != nil {
|
|
probeErr = fmt.Errorf("new net: %w", err)
|
|
return
|
|
}
|
|
cfg := &turn.ClientConfig{
|
|
STUNServerAddr: turnServerAddr,
|
|
TURNServerAddr: turnServerAddr,
|
|
Conn: conn,
|
|
Username: uri.Username,
|
|
Password: uri.Password,
|
|
Net: net,
|
|
}
|
|
client, err := turn.NewClient(cfg)
|
|
if err != nil {
|
|
probeErr = fmt.Errorf("create client: %w", err)
|
|
return
|
|
}
|
|
defer client.Close()
|
|
|
|
if err := client.Listen(); err != nil {
|
|
probeErr = fmt.Errorf("client listen: %w", err)
|
|
return
|
|
}
|
|
|
|
relayConn, err := client.Allocate()
|
|
if err != nil {
|
|
probeErr = fmt.Errorf("allocate: %w", err)
|
|
return
|
|
}
|
|
defer func() {
|
|
if err := relayConn.Close(); err != nil && probeErr == nil {
|
|
probeErr = fmt.Errorf("close relay conn: %w", err)
|
|
}
|
|
}()
|
|
|
|
log.Debugf("turn probe relay address from %s: %s", uri, relayConn.LocalAddr())
|
|
|
|
return relayConn.LocalAddr().String(), nil
|
|
}
|
|
|
|
// ProbeAll probes all given servers asynchronously and returns the results
|
|
func ProbeAll(
|
|
ctx context.Context,
|
|
fn func(ctx context.Context, uri *stun.URI) (addr string, probeErr error),
|
|
relays []*stun.URI,
|
|
) []ProbeResult {
|
|
results := make([]ProbeResult, len(relays))
|
|
|
|
var wg sync.WaitGroup
|
|
for i, uri := range relays {
|
|
ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
|
defer cancel()
|
|
|
|
wg.Add(1)
|
|
go func(res *ProbeResult, stunURI *stun.URI) {
|
|
defer wg.Done()
|
|
res.URI = stunURI.String()
|
|
res.Addr, res.Err = fn(ctx, stunURI)
|
|
}(&results[i], uri)
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
return results
|
|
}
|