mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-19 03:16:58 +02:00
[client,relay] Add QUIC support (#2962)
This commit is contained in:
@@ -10,6 +10,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
||||
"github.com/netbirdio/netbird/relay/client/dialer"
|
||||
"github.com/netbirdio/netbird/relay/client/dialer/quic"
|
||||
"github.com/netbirdio/netbird/relay/client/dialer/ws"
|
||||
"github.com/netbirdio/netbird/relay/healthcheck"
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
@@ -95,8 +97,6 @@ func (cc *connContainer) writeMsg(msg Msg) {
|
||||
msg.Free()
|
||||
default:
|
||||
msg.Free()
|
||||
cc.log.Infof("message queue is full")
|
||||
// todo consider to close the connection
|
||||
}
|
||||
}
|
||||
|
||||
@@ -179,8 +179,7 @@ func (c *Client) Connect() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := c.connect()
|
||||
if err != nil {
|
||||
if err := c.connect(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -264,14 +263,14 @@ func (c *Client) Close() error {
|
||||
}
|
||||
|
||||
func (c *Client) connect() error {
|
||||
conn, err := ws.Dial(c.connectionURL)
|
||||
rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{})
|
||||
conn, err := rd.Dial()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.relayConn = conn
|
||||
|
||||
err = c.handShake()
|
||||
if err != nil {
|
||||
if err = c.handShake(); err != nil {
|
||||
cErr := conn.Close()
|
||||
if cErr != nil {
|
||||
c.log.Errorf("failed to close connection: %s", cErr)
|
||||
@@ -345,7 +344,7 @@ func (c *Client) readLoop(relayConn net.Conn) {
|
||||
c.log.Infof("start to Relay read loop exit")
|
||||
c.mu.Lock()
|
||||
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
|
||||
c.log.Debugf("failed to read message from relay server: %s", errExit)
|
||||
c.log.Errorf("failed to read message from relay server: %s", errExit)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
c.bufPool.Put(bufPtr)
|
||||
|
7
relay/client/dialer/net/err.go
Normal file
7
relay/client/dialer/net/err.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package net
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrClosedByServer = errors.New("closed by server")
|
||||
)
|
97
relay/client/dialer/quic/conn.go
Normal file
97
relay/client/dialer/quic/conn.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
netErr "github.com/netbirdio/netbird/relay/client/dialer/net"
|
||||
)
|
||||
|
||||
const (
|
||||
Network = "quic"
|
||||
)
|
||||
|
||||
type Addr struct {
|
||||
addr string
|
||||
}
|
||||
|
||||
func (a Addr) Network() string {
|
||||
return Network
|
||||
}
|
||||
|
||||
func (a Addr) String() string {
|
||||
return a.addr
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
session quic.Connection
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func NewConn(session quic.Connection) net.Conn {
|
||||
return &Conn{
|
||||
session: session,
|
||||
ctx: context.Background(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
dgram, err := c.session.ReceiveDatagram(c.ctx)
|
||||
if err != nil {
|
||||
return 0, c.remoteCloseErrHandling(err)
|
||||
}
|
||||
|
||||
n = copy(b, dgram)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *Conn) Write(b []byte) (int, error) {
|
||||
err := c.session.SendDatagram(b)
|
||||
if err != nil {
|
||||
err = c.remoteCloseErrHandling(err)
|
||||
log.Errorf("failed to write to QUIC stream: %v", err)
|
||||
return 0, err
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.session.RemoteAddr()
|
||||
}
|
||||
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
if c.session != nil {
|
||||
return c.session.LocalAddr()
|
||||
}
|
||||
return Addr{addr: "unknown"}
|
||||
}
|
||||
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
return fmt.Errorf("SetReadDeadline is not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
return fmt.Errorf("SetWriteDeadline is not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
return c.session.CloseWithError(0, "normal closure")
|
||||
}
|
||||
|
||||
func (c *Conn) remoteCloseErrHandling(err error) error {
|
||||
var appErr *quic.ApplicationError
|
||||
if errors.As(err, &appErr) && appErr.ErrorCode == 0x0 {
|
||||
return netErr.ErrClosedByServer
|
||||
}
|
||||
return err
|
||||
}
|
71
relay/client/dialer/quic/quic.go
Normal file
71
relay/client/dialer/quic/quic.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
quictls "github.com/netbirdio/netbird/relay/tls"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type Dialer struct {
|
||||
}
|
||||
|
||||
func (d Dialer) Protocol() string {
|
||||
return Network
|
||||
}
|
||||
|
||||
func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
|
||||
quicURL, err := prepareURL(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
quicConfig := &quic.Config{
|
||||
KeepAlivePeriod: 30 * time.Second,
|
||||
MaxIdleTimeout: 4 * time.Minute,
|
||||
EnableDatagrams: true,
|
||||
}
|
||||
|
||||
udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
log.Errorf("failed to listen on UDP: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", quicURL)
|
||||
if err != nil {
|
||||
log.Errorf("failed to resolve UDP address: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session, err := quic.Dial(ctx, udpConn, udpAddr, quictls.ClientQUICTLSConfig(), quicConfig)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, err
|
||||
}
|
||||
log.Errorf("failed to dial to Relay server via QUIC '%s': %s", quicURL, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn := NewConn(session)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func prepareURL(address string) (string, error) {
|
||||
if !strings.HasPrefix(address, "rel://") && !strings.HasPrefix(address, "rels://") {
|
||||
return "", fmt.Errorf("unsupported scheme: %s", address)
|
||||
}
|
||||
|
||||
if strings.HasPrefix(address, "rels://") {
|
||||
return address[7:], nil
|
||||
}
|
||||
return address[6:], nil
|
||||
}
|
96
relay/client/dialer/race_dialer.go
Normal file
96
relay/client/dialer/race_dialer.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
connectionTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
type DialeFn interface {
|
||||
Dial(ctx context.Context, address string) (net.Conn, error)
|
||||
Protocol() string
|
||||
}
|
||||
|
||||
type dialResult struct {
|
||||
Conn net.Conn
|
||||
Protocol string
|
||||
Err error
|
||||
}
|
||||
|
||||
type RaceDial struct {
|
||||
log *log.Entry
|
||||
serverURL string
|
||||
dialerFns []DialeFn
|
||||
}
|
||||
|
||||
func NewRaceDial(log *log.Entry, serverURL string, dialerFns ...DialeFn) *RaceDial {
|
||||
return &RaceDial{
|
||||
log: log,
|
||||
serverURL: serverURL,
|
||||
dialerFns: dialerFns,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RaceDial) Dial() (net.Conn, error) {
|
||||
connChan := make(chan dialResult, len(r.dialerFns))
|
||||
winnerConn := make(chan net.Conn, 1)
|
||||
abortCtx, abort := context.WithCancel(context.Background())
|
||||
defer abort()
|
||||
|
||||
for _, dfn := range r.dialerFns {
|
||||
go r.dial(dfn, abortCtx, connChan)
|
||||
}
|
||||
|
||||
go r.processResults(connChan, winnerConn, abort)
|
||||
|
||||
conn, ok := <-winnerConn
|
||||
if !ok {
|
||||
return nil, errors.New("failed to dial to Relay server on any protocol")
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) {
|
||||
ctx, cancel := context.WithTimeout(abortCtx, connectionTimeout)
|
||||
defer cancel()
|
||||
|
||||
r.log.Infof("dialing Relay server via %s", dfn.Protocol())
|
||||
conn, err := dfn.Dial(ctx, r.serverURL)
|
||||
connChan <- dialResult{Conn: conn, Protocol: dfn.Protocol(), Err: err}
|
||||
}
|
||||
|
||||
func (r *RaceDial) processResults(connChan chan dialResult, winnerConn chan net.Conn, abort context.CancelFunc) {
|
||||
var hasWinner bool
|
||||
for i := 0; i < len(r.dialerFns); i++ {
|
||||
dr := <-connChan
|
||||
if dr.Err != nil {
|
||||
if errors.Is(dr.Err, context.Canceled) {
|
||||
r.log.Infof("connection attempt aborted via: %s", dr.Protocol)
|
||||
} else {
|
||||
r.log.Errorf("failed to dial via %s: %s", dr.Protocol, dr.Err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if hasWinner {
|
||||
if cerr := dr.Conn.Close(); cerr != nil {
|
||||
r.log.Warnf("failed to close connection via %s: %s", dr.Protocol, cerr)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
r.log.Infof("successfully dialed via: %s", dr.Protocol)
|
||||
|
||||
abort()
|
||||
hasWinner = true
|
||||
winnerConn <- dr.Conn
|
||||
}
|
||||
close(winnerConn)
|
||||
}
|
252
relay/client/dialer/race_dialer_test.go
Normal file
252
relay/client/dialer/race_dialer_test.go
Normal file
@@ -0,0 +1,252 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type MockAddr struct {
|
||||
network string
|
||||
}
|
||||
|
||||
func (m *MockAddr) Network() string {
|
||||
return m.network
|
||||
}
|
||||
|
||||
func (m *MockAddr) String() string {
|
||||
return "1.2.3.4"
|
||||
}
|
||||
|
||||
// MockDialer is a mock implementation of DialeFn
|
||||
type MockDialer struct {
|
||||
dialFunc func(ctx context.Context, address string) (net.Conn, error)
|
||||
protocolStr string
|
||||
}
|
||||
|
||||
func (m *MockDialer) Dial(ctx context.Context, address string) (net.Conn, error) {
|
||||
return m.dialFunc(ctx, address)
|
||||
}
|
||||
|
||||
func (m *MockDialer) Protocol() string {
|
||||
return m.protocolStr
|
||||
}
|
||||
|
||||
// MockConn implements net.Conn for testing
|
||||
type MockConn struct {
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
|
||||
func (m *MockConn) Read(b []byte) (n int, err error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *MockConn) Write(b []byte) (n int, err error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *MockConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockConn) LocalAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockConn) RemoteAddr() net.Addr {
|
||||
return m.remoteAddr
|
||||
}
|
||||
|
||||
func (m *MockConn) SetDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockConn) SetReadDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockConn) SetWriteDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestRaceDialEmptyDialers(t *testing.T) {
|
||||
logger := logrus.NewEntry(logrus.New())
|
||||
serverURL := "test.server.com"
|
||||
|
||||
rd := NewRaceDial(logger, serverURL)
|
||||
conn, err := rd.Dial()
|
||||
if err == nil {
|
||||
t.Errorf("Expected an error with empty dialers, got nil")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Errorf("Expected nil connection with empty dialers, got %v", conn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRaceDialSingleSuccessfulDialer(t *testing.T) {
|
||||
logger := logrus.NewEntry(logrus.New())
|
||||
serverURL := "test.server.com"
|
||||
proto := "test-protocol"
|
||||
|
||||
mockConn := &MockConn{
|
||||
remoteAddr: &MockAddr{network: proto},
|
||||
}
|
||||
|
||||
mockDialer := &MockDialer{
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
return mockConn, nil
|
||||
},
|
||||
protocolStr: proto,
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, serverURL, mockDialer)
|
||||
conn, err := rd.Dial()
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
if conn == nil {
|
||||
t.Errorf("Expected non-nil connection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) {
|
||||
logger := logrus.NewEntry(logrus.New())
|
||||
serverURL := "test.server.com"
|
||||
proto2 := "protocol2"
|
||||
|
||||
mockConn2 := &MockConn{
|
||||
remoteAddr: &MockAddr{network: proto2},
|
||||
}
|
||||
|
||||
mockDialer1 := &MockDialer{
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
return nil, errors.New("first dialer failed")
|
||||
},
|
||||
protocolStr: "proto1",
|
||||
}
|
||||
|
||||
mockDialer2 := &MockDialer{
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
return mockConn2, nil
|
||||
},
|
||||
protocolStr: "proto2",
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
|
||||
conn, err := rd.Dial()
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
if conn.RemoteAddr().Network() != proto2 {
|
||||
t.Errorf("Expected connection with protocol %s, got %s", proto2, conn.RemoteAddr().Network())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRaceDialTimeout(t *testing.T) {
|
||||
logger := logrus.NewEntry(logrus.New())
|
||||
serverURL := "test.server.com"
|
||||
|
||||
connectionTimeout = 3 * time.Second
|
||||
mockDialer := &MockDialer{
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
<-ctx.Done()
|
||||
return nil, ctx.Err()
|
||||
},
|
||||
protocolStr: "proto1",
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, serverURL, mockDialer)
|
||||
conn, err := rd.Dial()
|
||||
if err == nil {
|
||||
t.Errorf("Expected an error, got nil")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Errorf("Expected nil connection, got %v", conn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRaceDialAllDialersFail(t *testing.T) {
|
||||
logger := logrus.NewEntry(logrus.New())
|
||||
serverURL := "test.server.com"
|
||||
|
||||
mockDialer1 := &MockDialer{
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
return nil, errors.New("first dialer failed")
|
||||
},
|
||||
protocolStr: "protocol1",
|
||||
}
|
||||
|
||||
mockDialer2 := &MockDialer{
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
return nil, errors.New("second dialer failed")
|
||||
},
|
||||
protocolStr: "protocol2",
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
|
||||
conn, err := rd.Dial()
|
||||
if err == nil {
|
||||
t.Errorf("Expected an error, got nil")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Errorf("Expected nil connection, got %v", conn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) {
|
||||
logger := logrus.NewEntry(logrus.New())
|
||||
serverURL := "test.server.com"
|
||||
proto1 := "protocol1"
|
||||
proto2 := "protocol2"
|
||||
|
||||
mockConn1 := &MockConn{
|
||||
remoteAddr: &MockAddr{network: proto1},
|
||||
}
|
||||
|
||||
mockConn2 := &MockConn{
|
||||
remoteAddr: &MockAddr{network: proto2},
|
||||
}
|
||||
|
||||
mockDialer1 := &MockDialer{
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
time.Sleep(1 * time.Second)
|
||||
return mockConn1, nil
|
||||
},
|
||||
protocolStr: proto1,
|
||||
}
|
||||
|
||||
mock2err := make(chan error)
|
||||
mockDialer2 := &MockDialer{
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
<-ctx.Done()
|
||||
mock2err <- ctx.Err()
|
||||
return mockConn2, ctx.Err()
|
||||
},
|
||||
protocolStr: proto2,
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
|
||||
conn, err := rd.Dial()
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
if conn == nil {
|
||||
t.Errorf("Expected non-nil connection")
|
||||
}
|
||||
if conn != mockConn1 {
|
||||
t.Errorf("Expected first connection, got %v", conn)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Errorf("Timed out waiting for second dialer to finish")
|
||||
case err := <-mock2err:
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Errorf("Expected context.Canceled error, got %v", err)
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,11 +1,15 @@
|
||||
package ws
|
||||
|
||||
const (
|
||||
Network = "ws"
|
||||
)
|
||||
|
||||
type WebsocketAddr struct {
|
||||
addr string
|
||||
}
|
||||
|
||||
func (a WebsocketAddr) Network() string {
|
||||
return "websocket"
|
||||
return Network
|
||||
}
|
||||
|
||||
func (a WebsocketAddr) String() string {
|
||||
|
@@ -26,6 +26,7 @@ func NewConn(wsConn *websocket.Conn, serverAddress string) net.Conn {
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
t, ioReader, err := c.Conn.Reader(c.ctx)
|
||||
if err != nil {
|
||||
// todo use ErrClosedByServer
|
||||
return 0, err
|
||||
}
|
||||
|
||||
|
@@ -2,6 +2,7 @@ package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -15,7 +16,14 @@ import (
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func Dial(address string) (net.Conn, error) {
|
||||
type Dialer struct {
|
||||
}
|
||||
|
||||
func (d Dialer) Protocol() string {
|
||||
return "WS"
|
||||
}
|
||||
|
||||
func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
|
||||
wsURL, err := prepareURL(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -31,8 +39,11 @@ func Dial(address string) (net.Conn, error) {
|
||||
}
|
||||
parsedURL.Path = ws.URLPath
|
||||
|
||||
wsConn, resp, err := websocket.Dial(context.Background(), parsedURL.String(), opts)
|
||||
wsConn, resp, err := websocket.Dial(ctx, parsedURL.String(), opts)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, err
|
||||
}
|
||||
log.Errorf("failed to dial to Relay server '%s': %s", wsURL, err)
|
||||
return nil, err
|
||||
}
|
||||
|
Reference in New Issue
Block a user