[client,relay] Add QUIC support (#2962)

This commit is contained in:
Zoltan Papp
2025-01-15 16:28:19 +01:00
committed by GitHub
parent e4a25b6a60
commit 1ffa519387
25 changed files with 943 additions and 33 deletions

View File

@@ -10,6 +10,8 @@ import (
log "github.com/sirupsen/logrus"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
"github.com/netbirdio/netbird/relay/client/dialer"
"github.com/netbirdio/netbird/relay/client/dialer/quic"
"github.com/netbirdio/netbird/relay/client/dialer/ws"
"github.com/netbirdio/netbird/relay/healthcheck"
"github.com/netbirdio/netbird/relay/messages"
@@ -95,8 +97,6 @@ func (cc *connContainer) writeMsg(msg Msg) {
msg.Free()
default:
msg.Free()
cc.log.Infof("message queue is full")
// todo consider to close the connection
}
}
@@ -179,8 +179,7 @@ func (c *Client) Connect() error {
return nil
}
err := c.connect()
if err != nil {
if err := c.connect(); err != nil {
return err
}
@@ -264,14 +263,14 @@ func (c *Client) Close() error {
}
func (c *Client) connect() error {
conn, err := ws.Dial(c.connectionURL)
rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{})
conn, err := rd.Dial()
if err != nil {
return err
}
c.relayConn = conn
err = c.handShake()
if err != nil {
if err = c.handShake(); err != nil {
cErr := conn.Close()
if cErr != nil {
c.log.Errorf("failed to close connection: %s", cErr)
@@ -345,7 +344,7 @@ func (c *Client) readLoop(relayConn net.Conn) {
c.log.Infof("start to Relay read loop exit")
c.mu.Lock()
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
c.log.Debugf("failed to read message from relay server: %s", errExit)
c.log.Errorf("failed to read message from relay server: %s", errExit)
}
c.mu.Unlock()
c.bufPool.Put(bufPtr)

View File

@@ -0,0 +1,7 @@
package net
import "errors"
var (
	// ErrClosedByServer is the sentinel error used to report that the remote
	// side closed the connection. The QUIC connection wrapper returns it when
	// the session fails with an application error whose code is 0.
	ErrClosedByServer = errors.New("closed by server")
)

View File

@@ -0,0 +1,97 @@
package quic
import (
"context"
"errors"
"fmt"
"net"
"time"
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
netErr "github.com/netbirdio/netbird/relay/client/dialer/net"
)
const (
	// Network is the network name reported by Addr.Network and by
	// Dialer.Protocol in this package.
	Network = "quic"
)

// Addr is a simple address value exposing Network and String methods for
// QUIC connections.
type Addr struct {
	addr string
}

// Network returns the package-level Network constant.
func (a Addr) Network() string {
	return Network
}

// String returns the address text stored in a.
func (a Addr) String() string {
	return a.addr
}
// Conn adapts a QUIC connection to the net.Conn interface. Payloads are
// exchanged as QUIC datagrams (see Read and Write).
type Conn struct {
	session quic.Connection
	// ctx is the context passed to ReceiveDatagram on every Read.
	ctx context.Context
}

// NewConn wraps session in a Conn. The wrapper uses a background context for
// receiving datagrams.
func NewConn(session quic.Connection) net.Conn {
	c := new(Conn)
	c.session = session
	c.ctx = context.Background()
	return c
}
// Read receives the next datagram from the session and copies it into b.
// It returns the number of bytes copied; if the datagram is longer than b,
// only len(b) bytes are copied and the rest is discarded. Receive errors are
// passed through remoteCloseErrHandling before being returned.
func (c *Conn) Read(b []byte) (n int, err error) {
	payload, rcvErr := c.session.ReceiveDatagram(c.ctx)
	if rcvErr != nil {
		return 0, c.remoteCloseErrHandling(rcvErr)
	}

	return copy(b, payload), nil
}
// Write sends b as a single datagram on the session. On success it reports
// len(b) bytes written. On failure the error is passed through
// remoteCloseErrHandling, logged, and returned with a byte count of 0.
func (c *Conn) Write(b []byte) (int, error) {
	if sendErr := c.session.SendDatagram(b); sendErr != nil {
		mapped := c.remoteCloseErrHandling(sendErr)
		log.Errorf("failed to write to QUIC stream: %v", mapped)
		return 0, mapped
	}

	return len(b), nil
}
// RemoteAddr returns the remote address reported by the underlying session.
// NOTE(review): unlike LocalAddr, this does not guard against a nil session.
func (c *Conn) RemoteAddr() net.Addr {
	return c.session.RemoteAddr()
}
// LocalAddr returns the local address reported by the underlying session.
// When the Conn has no session it returns a placeholder Addr whose string
// form is "unknown".
func (c *Conn) LocalAddr() net.Addr {
	if c.session == nil {
		return Addr{addr: "unknown"}
	}

	return c.session.LocalAddr()
}
// SetReadDeadline is not supported by this connection type: it ignores t and
// always returns an error.
func (c *Conn) SetReadDeadline(t time.Time) error {
	return fmt.Errorf("SetReadDeadline is not implemented")
}
// SetWriteDeadline is not supported by this connection type: it ignores t and
// always returns an error.
func (c *Conn) SetWriteDeadline(t time.Time) error {
	return fmt.Errorf("SetWriteDeadline is not implemented")
}
// SetDeadline ignores t and always returns nil; no deadline is applied.
// NOTE(review): SetReadDeadline and SetWriteDeadline return a "not
// implemented" error instead — confirm the silent success here is intended.
func (c *Conn) SetDeadline(t time.Time) error {
	return nil
}
// Close closes the underlying session with application error code 0 and the
// reason "normal closure", returning whatever error the session reports.
func (c *Conn) Close() error {
	return c.session.CloseWithError(0, "normal closure")
}
// remoteCloseErrHandling translates a QUIC application error with code 0 into
// netErr.ErrClosedByServer. Any other error is returned unchanged.
func (c *Conn) remoteCloseErrHandling(err error) error {
	var appErr *quic.ApplicationError
	if !errors.As(err, &appErr) {
		return err
	}
	if appErr.ErrorCode != 0x0 {
		return err
	}

	return netErr.ErrClosedByServer
}

View File

@@ -0,0 +1,71 @@
package quic
import (
"context"
"errors"
"fmt"
"net"
"strings"
"time"
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
quictls "github.com/netbirdio/netbird/relay/tls"
nbnet "github.com/netbirdio/netbird/util/net"
)
// Dialer establishes Relay connections over QUIC. It carries no state.
type Dialer struct {
}

// Protocol returns the package-level Network constant as the protocol name.
func (d Dialer) Protocol() string {
	return Network
}
// Dial opens a QUIC session to the Relay server at address and wraps it in a
// net.Conn.
//
// address must use the "rel://" or "rels://" scheme; the remainder is
// resolved as a UDP "host:port". ctx is forwarded to quic.Dial, so cancelling
// it aborts the attempt; a context.Canceled failure is returned without
// logging.
//
// The local UDP socket opened for the session is closed again if the dial
// fails, so a failed attempt does not leak a file descriptor.
//
// NOTE(review): on success the socket is handed to quic.Dial, and Conn.Close
// only closes the QUIC session. Confirm whether the socket also needs to be
// released explicitly when the session ends.
func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
	quicURL, err := prepareURL(address)
	if err != nil {
		return nil, err
	}

	quicConfig := &quic.Config{
		KeepAlivePeriod: 30 * time.Second,
		MaxIdleTimeout:  4 * time.Minute,
		EnableDatagrams: true,
	}

	// Resolve the target before opening the local socket so that a
	// resolution failure leaves nothing to clean up.
	udpAddr, err := net.ResolveUDPAddr("udp", quicURL)
	if err != nil {
		log.Errorf("failed to resolve UDP address: %s", err)
		return nil, err
	}

	udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
	if err != nil {
		log.Errorf("failed to listen on UDP: %s", err)
		return nil, err
	}

	session, err := quic.Dial(ctx, udpConn, udpAddr, quictls.ClientQUICTLSConfig(), quicConfig)
	if err != nil {
		// The socket was opened by this function and is not usable by the
		// caller after a failed dial, so release it here.
		if cerr := udpConn.Close(); cerr != nil {
			log.Warnf("failed to close UDP socket after failed QUIC dial: %s", cerr)
		}
		if errors.Is(err, context.Canceled) {
			return nil, err
		}
		log.Errorf("failed to dial to Relay server via QUIC '%s': %s", quicURL, err)
		return nil, err
	}

	conn := NewConn(session)
	return conn, nil
}
// prepareURL strips the Relay scheme ("rel://" or "rels://") from address and
// returns the remainder. Any other input yields an "unsupported scheme" error.
func prepareURL(address string) (string, error) {
	for _, scheme := range []string{"rels://", "rel://"} {
		if strings.HasPrefix(address, scheme) {
			return strings.TrimPrefix(address, scheme), nil
		}
	}

	return "", fmt.Errorf("unsupported scheme: %s", address)
}

View File

@@ -0,0 +1,96 @@
package dialer
import (
"context"
"errors"
"net"
"time"
log "github.com/sirupsen/logrus"
)
var (
	// connectionTimeout bounds each individual dial attempt started by
	// RaceDial. It is a variable rather than a constant so tests can
	// override it.
	connectionTimeout = 30 * time.Second
)
// DialeFn is implemented by each transport that RaceDial can try.
type DialeFn interface {
	// Dial attempts to connect to address, using ctx for the attempt.
	Dial(ctx context.Context, address string) (net.Conn, error)
	// Protocol returns the transport name used in log messages.
	Protocol() string
}
// dialResult carries the outcome of one dial attempt from the dialing
// goroutine to processResults.
type dialResult struct {
	// Conn is the connection returned by the dialer.
	Conn net.Conn
	// Protocol is the dialer's Protocol() value.
	Protocol string
	// Err is the error returned by the dialer, nil on success.
	Err error
}
// RaceDial dials one server through several transports at once and keeps the
// first connection that succeeds.
type RaceDial struct {
	log       *log.Entry
	serverURL string
	dialerFns []DialeFn
}

// NewRaceDial builds a RaceDial that logs through log and dials serverURL
// with every dialer in dialerFns.
func NewRaceDial(log *log.Entry, serverURL string, dialerFns ...DialeFn) *RaceDial {
	rd := new(RaceDial)
	rd.log = log
	rd.serverURL = serverURL
	rd.dialerFns = dialerFns
	return rd
}
// Dial starts every configured dialer concurrently and returns the first
// connection that is established. Once a winner is chosen the shared context
// is cancelled so the remaining attempts are aborted. If no dialer succeeds
// (including when no dialers are configured) it returns an error.
func (r *RaceDial) Dial() (net.Conn, error) {
	abortCtx, abort := context.WithCancel(context.Background())
	defer abort()

	// One slot per dialer so no dialing goroutine blocks on reporting.
	results := make(chan dialResult, len(r.dialerFns))
	winner := make(chan net.Conn, 1)

	for i := range r.dialerFns {
		go r.dial(r.dialerFns[i], abortCtx, results)
	}
	go r.processResults(results, winner, abort)

	// processResults closes winner without sending when every attempt failed.
	if conn, ok := <-winner; ok {
		return conn, nil
	}
	return nil, errors.New("failed to dial to Relay server on any protocol")
}
// dial runs a single dial attempt with dfn, limited by connectionTimeout and
// by abortCtx, and reports the outcome on connChan.
func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) {
	attemptCtx, release := context.WithTimeout(abortCtx, connectionTimeout)
	defer release()

	r.log.Infof("dialing Relay server via %s", dfn.Protocol())

	var res dialResult
	res.Conn, res.Err = dfn.Dial(attemptCtx, r.serverURL)
	res.Protocol = dfn.Protocol()
	connChan <- res
}
// processResults collects exactly one result per configured dialer from
// connChan. The first successful connection is sent on winnerConn and abort
// is called to cancel the other attempts; any later successful connection is
// closed. Failed attempts are only logged. winnerConn is closed once all
// results have been consumed.
func (r *RaceDial) processResults(connChan chan dialResult, winnerConn chan net.Conn, abort context.CancelFunc) {
	won := false
	for pending := len(r.dialerFns); pending > 0; pending-- {
		res := <-connChan

		switch {
		case res.Err != nil && errors.Is(res.Err, context.Canceled):
			r.log.Infof("connection attempt aborted via: %s", res.Protocol)

		case res.Err != nil:
			r.log.Errorf("failed to dial via %s: %s", res.Protocol, res.Err)

		case won:
			// A winner already exists; this late connection is surplus.
			if cerr := res.Conn.Close(); cerr != nil {
				r.log.Warnf("failed to close connection via %s: %s", res.Protocol, cerr)
			}

		default:
			r.log.Infof("successfully dialed via: %s", res.Protocol)
			abort()
			won = true
			winnerConn <- res.Conn
		}
	}

	close(winnerConn)
}

View File

@@ -0,0 +1,252 @@
package dialer
import (
"context"
"errors"
"net"
"testing"
"time"
"github.com/sirupsen/logrus"
)
// MockAddr is a test address whose network name is configurable.
type MockAddr struct {
	network string
}

// Network returns the configured network name.
func (m *MockAddr) Network() string {
	return m.network
}

// String always returns the fixed address "1.2.3.4".
func (m *MockAddr) String() string {
	return "1.2.3.4"
}
// MockDialer is a mock implementation of DialeFn. Dial delegates to dialFunc
// and Protocol returns protocolStr.
type MockDialer struct {
	dialFunc    func(ctx context.Context, address string) (net.Conn, error)
	protocolStr string
}

// Dial forwards the call to the configured dialFunc.
func (m *MockDialer) Dial(ctx context.Context, address string) (net.Conn, error) {
	return m.dialFunc(ctx, address)
}

// Protocol returns the configured protocol name.
func (m *MockDialer) Protocol() string {
	return m.protocolStr
}
// MockConn implements net.Conn for testing. Every method is a no-op stub
// except RemoteAddr, which returns the configured address.
type MockConn struct {
	remoteAddr net.Addr
}

// Read is a stub that reports zero bytes read and no error.
func (m *MockConn) Read(b []byte) (n int, err error) {
	return 0, nil
}

// Write is a stub that reports zero bytes written and no error.
func (m *MockConn) Write(b []byte) (n int, err error) {
	return 0, nil
}

// Close is a stub that always succeeds.
func (m *MockConn) Close() error {
	return nil
}

// LocalAddr always returns nil.
func (m *MockConn) LocalAddr() net.Addr {
	return nil
}

// RemoteAddr returns the configured remote address.
func (m *MockConn) RemoteAddr() net.Addr {
	return m.remoteAddr
}

// SetDeadline is a stub that always succeeds.
func (m *MockConn) SetDeadline(t time.Time) error {
	return nil
}

// SetReadDeadline is a stub that always succeeds.
func (m *MockConn) SetReadDeadline(t time.Time) error {
	return nil
}

// SetWriteDeadline is a stub that always succeeds.
func (m *MockConn) SetWriteDeadline(t time.Time) error {
	return nil
}
func TestRaceDialEmptyDialers(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
rd := NewRaceDial(logger, serverURL)
conn, err := rd.Dial()
if err == nil {
t.Errorf("Expected an error with empty dialers, got nil")
}
if conn != nil {
t.Errorf("Expected nil connection with empty dialers, got %v", conn)
}
}
func TestRaceDialSingleSuccessfulDialer(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
proto := "test-protocol"
mockConn := &MockConn{
remoteAddr: &MockAddr{network: proto},
}
mockDialer := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
return mockConn, nil
},
protocolStr: proto,
}
rd := NewRaceDial(logger, serverURL, mockDialer)
conn, err := rd.Dial()
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if conn == nil {
t.Errorf("Expected non-nil connection")
}
}
// TestRaceDialMultipleDialersWithOneSuccess verifies that when one dialer
// fails and another succeeds, the successful dialer's connection is returned.
func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) {
	logger := logrus.NewEntry(logrus.New())
	serverURL := "test.server.com"
	proto2 := "protocol2"

	mockConn2 := &MockConn{
		remoteAddr: &MockAddr{network: proto2},
	}

	mockDialer1 := &MockDialer{
		dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
			return nil, errors.New("first dialer failed")
		},
		protocolStr: "proto1",
	}

	mockDialer2 := &MockDialer{
		dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
			return mockConn2, nil
		},
		protocolStr: "proto2",
	}

	rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
	conn, err := rd.Dial()
	if err != nil {
		// Stop here: conn is nil when Dial fails, and dereferencing it below
		// would panic instead of reporting the failure.
		t.Fatalf("Expected no error, got %v", err)
	}

	if conn.RemoteAddr().Network() != proto2 {
		t.Errorf("Expected connection with protocol %s, got %s", proto2, conn.RemoteAddr().Network())
	}
}
// TestRaceDialTimeout verifies that Dial fails when the only dialer blocks
// until its context expires.
func TestRaceDialTimeout(t *testing.T) {
	logger := logrus.NewEntry(logrus.New())
	serverURL := "test.server.com"

	// connectionTimeout is package-level state; restore it afterwards so the
	// shortened value does not leak into other tests.
	prevTimeout := connectionTimeout
	connectionTimeout = 3 * time.Second
	t.Cleanup(func() { connectionTimeout = prevTimeout })

	mockDialer := &MockDialer{
		dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
			<-ctx.Done()
			return nil, ctx.Err()
		},
		protocolStr: "proto1",
	}

	rd := NewRaceDial(logger, serverURL, mockDialer)
	conn, err := rd.Dial()
	if err == nil {
		t.Errorf("Expected an error, got nil")
	}
	if conn != nil {
		t.Errorf("Expected nil connection, got %v", conn)
	}
}
func TestRaceDialAllDialersFail(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
mockDialer1 := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
return nil, errors.New("first dialer failed")
},
protocolStr: "protocol1",
}
mockDialer2 := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
return nil, errors.New("second dialer failed")
},
protocolStr: "protocol2",
}
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
conn, err := rd.Dial()
if err == nil {
t.Errorf("Expected an error, got nil")
}
if conn != nil {
t.Errorf("Expected nil connection, got %v", conn)
}
}
// TestRaceDialFirstSuccessfulDialerWins verifies that the first dialer to
// succeed provides the returned connection and that the other, still-pending
// dialer has its context cancelled.
func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) {
	logger := logrus.NewEntry(logrus.New())
	serverURL := "test.server.com"
	proto1 := "protocol1"
	proto2 := "protocol2"

	mockConn1 := &MockConn{
		remoteAddr: &MockAddr{network: proto1},
	}

	mockConn2 := &MockConn{
		remoteAddr: &MockAddr{network: proto2},
	}

	// The first dialer succeeds after a one second delay.
	mockDialer1 := &MockDialer{
		dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
			time.Sleep(1 * time.Second)
			return mockConn1, nil
		},
		protocolStr: proto1,
	}

	// The second dialer blocks until its context is done and then hands the
	// context error to the test through mock2err. The channel is unbuffered,
	// so this send completes only when the select below receives it.
	mock2err := make(chan error)
	mockDialer2 := &MockDialer{
		dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
			<-ctx.Done()
			mock2err <- ctx.Err()
			return mockConn2, ctx.Err()
		},
		protocolStr: proto2,
	}

	rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
	conn, err := rd.Dial()
	if err != nil {
		t.Errorf("Expected no error, got %v", err)
	}

	if conn == nil {
		t.Errorf("Expected non-nil connection")
	}

	if conn != mockConn1 {
		t.Errorf("Expected first connection, got %v", conn)
	}

	// The losing dialer must observe cancellation, not a timeout.
	select {
	case <-time.After(3 * time.Second):
		t.Errorf("Timed out waiting for second dialer to finish")
	case err := <-mock2err:
		if !errors.Is(err, context.Canceled) {
			t.Errorf("Expected context.Canceled error, got %v", err)
		}
	}
}

View File

@@ -1,11 +1,15 @@
package ws
const (
Network = "ws"
)
type WebsocketAddr struct {
addr string
}
func (a WebsocketAddr) Network() string {
return "websocket"
return Network
}
func (a WebsocketAddr) String() string {

View File

@@ -26,6 +26,7 @@ func NewConn(wsConn *websocket.Conn, serverAddress string) net.Conn {
func (c *Conn) Read(b []byte) (n int, err error) {
t, ioReader, err := c.Conn.Reader(c.ctx)
if err != nil {
// todo use ErrClosedByServer
return 0, err
}

View File

@@ -2,6 +2,7 @@ package ws
import (
"context"
"errors"
"fmt"
"net"
"net/http"
@@ -15,7 +16,14 @@ import (
nbnet "github.com/netbirdio/netbird/util/net"
)
func Dial(address string) (net.Conn, error) {
type Dialer struct {
}
func (d Dialer) Protocol() string {
return "WS"
}
func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
wsURL, err := prepareURL(address)
if err != nil {
return nil, err
@@ -31,8 +39,11 @@ func Dial(address string) (net.Conn, error) {
}
parsedURL.Path = ws.URLPath
wsConn, resp, err := websocket.Dial(context.Background(), parsedURL.String(), opts)
wsConn, resp, err := websocket.Dial(ctx, parsedURL.String(), opts)
if err != nil {
if errors.Is(err, context.Canceled) {
return nil, err
}
log.Errorf("failed to dial to Relay server '%s': %s", wsURL, err)
return nil, err
}