From 1ffa5193871f9e4c7b14ff0b553d8432e7e280b5 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 15 Jan 2025 16:28:19 +0100 Subject: [PATCH] [client,relay] Add QUIC support (#2962) --- .github/workflows/golang-test-darwin.yml | 3 +- .github/workflows/golang-test-linux.yml | 6 +- .github/workflows/golang-test-windows.yml | 2 +- go.mod | 4 + go.sum | 6 + relay/client/client.go | 15 +- relay/client/dialer/net/err.go | 7 + relay/client/dialer/quic/conn.go | 97 +++++++++ relay/client/dialer/quic/quic.go | 71 ++++++ relay/client/dialer/race_dialer.go | 96 +++++++++ relay/client/dialer/race_dialer_test.go | 252 ++++++++++++++++++++++ relay/client/dialer/ws/addr.go | 6 +- relay/client/dialer/ws/conn.go | 1 + relay/client/dialer/ws/ws.go | 15 +- relay/server/listener/quic/conn.go | 101 +++++++++ relay/server/listener/quic/listener.go | 66 ++++++ relay/server/listener/ws/listener.go | 2 + relay/server/relay.go | 4 +- relay/server/server.go | 64 ++++-- relay/tls/alpn.go | 3 + relay/tls/client_dev.go | 12 ++ relay/tls/client_prod.go | 11 + relay/tls/doc.go | 36 ++++ relay/tls/server_dev.go | 79 +++++++ relay/tls/server_prod.go | 17 ++ 25 files changed, 943 insertions(+), 33 deletions(-) create mode 100644 relay/client/dialer/net/err.go create mode 100644 relay/client/dialer/quic/conn.go create mode 100644 relay/client/dialer/quic/quic.go create mode 100644 relay/client/dialer/race_dialer.go create mode 100644 relay/client/dialer/race_dialer_test.go create mode 100644 relay/server/listener/quic/conn.go create mode 100644 relay/server/listener/quic/listener.go create mode 100644 relay/tls/alpn.go create mode 100644 relay/tls/client_dev.go create mode 100644 relay/tls/client_prod.go create mode 100644 relay/tls/doc.go create mode 100644 relay/tls/server_dev.go create mode 100644 relay/tls/server_prod.go diff --git a/.github/workflows/golang-test-darwin.yml b/.github/workflows/golang-test-darwin.yml index 2dbeb106a..664e8be18 100644 --- 
a/.github/workflows/golang-test-darwin.yml +++ b/.github/workflows/golang-test-darwin.yml @@ -44,4 +44,5 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management) + run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management) + diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 5f7d7b4a3..ba5f66746 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -134,7 +134,7 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management) + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management) test_management: needs: [ build-cache ] @@ -194,7 +194,7 @@ jobs: run: docker pull mlsmaycon/warmed-mysql:8 - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management) + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management) benchmark: needs: [ build-cache ] @@ -254,7 +254,7 @@ jobs: run: docker pull mlsmaycon/warmed-mysql:8 - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m ./... 
+ run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags devcert -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m ./... api_benchmark: needs: [ build-cache ] diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml index 3a3c47052..782e4c30a 100644 --- a/.github/workflows/golang-test-windows.yml +++ b/.github/workflows/golang-test-windows.yml @@ -65,7 +65,7 @@ jobs: - run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV - name: test - run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1" + run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1" - name: test output if: ${{ always() }} run: Get-Content test-out.txt diff --git a/go.mod b/go.mod index 147577cc3..88bcada07 100644 --- a/go.mod +++ b/go.mod @@ -71,6 +71,7 @@ require ( github.com/pion/transport/v3 v3.0.1 github.com/pion/turn/v3 v3.0.1 github.com/prometheus/client_golang v1.19.1 + github.com/quic-go/quic-go v0.48.2 github.com/rs/xid v1.3.0 github.com/shirou/gopsutil/v3 v3.24.4 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 @@ -155,11 +156,13 @@ require ( github.com/go-ole/go-ole v1.3.0 // indirect github.com/go-redis/redis/v8 v8.11.5 // indirect github.com/go-sql-driver/mysql v1.8.1 // indirect + github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/go-text/render v0.2.0 // indirect github.com/go-text/typesetting v0.2.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect 
github.com/google/btree v1.1.2 // indirect + github.com/google/pprof v0.0.0-20211214055906-6f57359322fd // indirect github.com/google/s2a-go v0.1.7 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/googleapis/gax-go/v2 v2.12.3 // indirect @@ -221,6 +224,7 @@ require ( go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect go.opentelemetry.io/otel/sdk v1.26.0 // indirect go.opentelemetry.io/otel/trace v1.26.0 // indirect + go.uber.org/mock v0.4.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/mod v0.17.0 // indirect diff --git a/go.sum b/go.sum index 253429798..8ba94dd6a 100644 --- a/go.sum +++ b/go.sum @@ -405,6 +405,7 @@ github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/J github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= +github.com/ianlancetaylor/demangle v0.0.0-20210905161508-09a460cdf81d/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= @@ -610,6 +611,8 @@ github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+a github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U= github.com/prometheus/procfs v0.15.0 h1:A82kmvXJq2jTu5YUhSGNlYoxh85zLnKgPz4bMZgI5Ek= github.com/prometheus/procfs v0.15.0/go.mod h1:Y0RJ/Y5g5wJpkTisOtqwDSo4HwhGmLB4VQSw2sQJLHk= +github.com/quic-go/quic-go v0.48.2 
h1:wsKXZPeGWpMpCGSWqOcqpW2wZYic/8T3aqiOID0/KWE= +github.com/quic-go/quic-go v0.48.2/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= @@ -761,6 +764,8 @@ go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v8 go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= +go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= @@ -970,6 +975,7 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 
diff --git a/relay/client/client.go b/relay/client/client.go index bccd85c93..3c23b70d2 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -10,6 +10,8 @@ import ( log "github.com/sirupsen/logrus" auth "github.com/netbirdio/netbird/relay/auth/hmac" + "github.com/netbirdio/netbird/relay/client/dialer" + "github.com/netbirdio/netbird/relay/client/dialer/quic" "github.com/netbirdio/netbird/relay/client/dialer/ws" "github.com/netbirdio/netbird/relay/healthcheck" "github.com/netbirdio/netbird/relay/messages" @@ -95,8 +97,6 @@ func (cc *connContainer) writeMsg(msg Msg) { msg.Free() default: msg.Free() - cc.log.Infof("message queue is full") - // todo consider to close the connection } } @@ -179,8 +179,7 @@ func (c *Client) Connect() error { return nil } - err := c.connect() - if err != nil { + if err := c.connect(); err != nil { return err } @@ -264,14 +263,14 @@ func (c *Client) Close() error { } func (c *Client) connect() error { - conn, err := ws.Dial(c.connectionURL) + rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{}) + conn, err := rd.Dial() if err != nil { return err } c.relayConn = conn - err = c.handShake() - if err != nil { + if err = c.handShake(); err != nil { cErr := conn.Close() if cErr != nil { c.log.Errorf("failed to close connection: %s", cErr) @@ -345,7 +344,7 @@ func (c *Client) readLoop(relayConn net.Conn) { c.log.Infof("start to Relay read loop exit") c.mu.Lock() if c.serviceIsRunning && !internallyStoppedFlag.isSet() { - c.log.Debugf("failed to read message from relay server: %s", errExit) + c.log.Errorf("failed to read message from relay server: %s", errExit) } c.mu.Unlock() c.bufPool.Put(bufPtr) diff --git a/relay/client/dialer/net/err.go b/relay/client/dialer/net/err.go new file mode 100644 index 000000000..fee844963 --- /dev/null +++ b/relay/client/dialer/net/err.go @@ -0,0 +1,7 @@ +package net + +import "errors" + +var ( + ErrClosedByServer = errors.New("closed by server") +) diff --git 
a/relay/client/dialer/quic/conn.go b/relay/client/dialer/quic/conn.go new file mode 100644 index 000000000..d64633c8c --- /dev/null +++ b/relay/client/dialer/quic/conn.go @@ -0,0 +1,97 @@ +package quic + +import ( + "context" + "errors" + "fmt" + "net" + "time" + + "github.com/quic-go/quic-go" + log "github.com/sirupsen/logrus" + + netErr "github.com/netbirdio/netbird/relay/client/dialer/net" +) + +const ( + Network = "quic" +) + +type Addr struct { + addr string +} + +func (a Addr) Network() string { + return Network +} + +func (a Addr) String() string { + return a.addr +} + +type Conn struct { + session quic.Connection + ctx context.Context +} + +func NewConn(session quic.Connection) net.Conn { + return &Conn{ + session: session, + ctx: context.Background(), + } +} + +func (c *Conn) Read(b []byte) (n int, err error) { + dgram, err := c.session.ReceiveDatagram(c.ctx) + if err != nil { + return 0, c.remoteCloseErrHandling(err) + } + + n = copy(b, dgram) + return n, nil +} + +func (c *Conn) Write(b []byte) (int, error) { + err := c.session.SendDatagram(b) + if err != nil { + err = c.remoteCloseErrHandling(err) + log.Errorf("failed to write to QUIC stream: %v", err) + return 0, err + } + return len(b), nil +} + +func (c *Conn) RemoteAddr() net.Addr { + return c.session.RemoteAddr() +} + +func (c *Conn) LocalAddr() net.Addr { + if c.session != nil { + return c.session.LocalAddr() + } + return Addr{addr: "unknown"} +} + +func (c *Conn) SetReadDeadline(t time.Time) error { + return fmt.Errorf("SetReadDeadline is not implemented") +} + +func (c *Conn) SetWriteDeadline(t time.Time) error { + return fmt.Errorf("SetWriteDeadline is not implemented") +} + +func (c *Conn) SetDeadline(t time.Time) error { + return nil +} + +func (c *Conn) Close() error { + return c.session.CloseWithError(0, "normal closure") +} + +func (c *Conn) remoteCloseErrHandling(err error) error { + var appErr *quic.ApplicationError + if errors.As(err, &appErr) && appErr.ErrorCode == 0x0 { + return 
netErr.ErrClosedByServer + } + return err +} diff --git a/relay/client/dialer/quic/quic.go b/relay/client/dialer/quic/quic.go new file mode 100644 index 000000000..593d1334b --- /dev/null +++ b/relay/client/dialer/quic/quic.go @@ -0,0 +1,71 @@ +package quic + +import ( + "context" + "errors" + "fmt" + "net" + "strings" + "time" + + "github.com/quic-go/quic-go" + log "github.com/sirupsen/logrus" + + quictls "github.com/netbirdio/netbird/relay/tls" + nbnet "github.com/netbirdio/netbird/util/net" +) + +type Dialer struct { +} + +func (d Dialer) Protocol() string { + return Network +} + +func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) { + quicURL, err := prepareURL(address) + if err != nil { + return nil, err + } + + quicConfig := &quic.Config{ + KeepAlivePeriod: 30 * time.Second, + MaxIdleTimeout: 4 * time.Minute, + EnableDatagrams: true, + } + + udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + log.Errorf("failed to listen on UDP: %s", err) + return nil, err + } + + udpAddr, err := net.ResolveUDPAddr("udp", quicURL) + if err != nil { + log.Errorf("failed to resolve UDP address: %s", err) + return nil, err + } + + session, err := quic.Dial(ctx, udpConn, udpAddr, quictls.ClientQUICTLSConfig(), quicConfig) + if err != nil { + if errors.Is(err, context.Canceled) { + return nil, err + } + log.Errorf("failed to dial to Relay server via QUIC '%s': %s", quicURL, err) + return nil, err + } + + conn := NewConn(session) + return conn, nil +} + +func prepareURL(address string) (string, error) { + if !strings.HasPrefix(address, "rel://") && !strings.HasPrefix(address, "rels://") { + return "", fmt.Errorf("unsupported scheme: %s", address) + } + + if strings.HasPrefix(address, "rels://") { + return address[7:], nil + } + return address[6:], nil +} diff --git a/relay/client/dialer/race_dialer.go b/relay/client/dialer/race_dialer.go new file mode 100644 index 000000000..11dba5799 --- /dev/null +++ 
b/relay/client/dialer/race_dialer.go @@ -0,0 +1,96 @@ +package dialer + +import ( + "context" + "errors" + "net" + "time" + + log "github.com/sirupsen/logrus" +) + +var ( + connectionTimeout = 30 * time.Second +) + +type DialeFn interface { + Dial(ctx context.Context, address string) (net.Conn, error) + Protocol() string +} + +type dialResult struct { + Conn net.Conn + Protocol string + Err error +} + +type RaceDial struct { + log *log.Entry + serverURL string + dialerFns []DialeFn +} + +func NewRaceDial(log *log.Entry, serverURL string, dialerFns ...DialeFn) *RaceDial { + return &RaceDial{ + log: log, + serverURL: serverURL, + dialerFns: dialerFns, + } +} + +func (r *RaceDial) Dial() (net.Conn, error) { + connChan := make(chan dialResult, len(r.dialerFns)) + winnerConn := make(chan net.Conn, 1) + abortCtx, abort := context.WithCancel(context.Background()) + defer abort() + + for _, dfn := range r.dialerFns { + go r.dial(dfn, abortCtx, connChan) + } + + go r.processResults(connChan, winnerConn, abort) + + conn, ok := <-winnerConn + if !ok { + return nil, errors.New("failed to dial to Relay server on any protocol") + } + return conn, nil +} + +func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) { + ctx, cancel := context.WithTimeout(abortCtx, connectionTimeout) + defer cancel() + + r.log.Infof("dialing Relay server via %s", dfn.Protocol()) + conn, err := dfn.Dial(ctx, r.serverURL) + connChan <- dialResult{Conn: conn, Protocol: dfn.Protocol(), Err: err} +} + +func (r *RaceDial) processResults(connChan chan dialResult, winnerConn chan net.Conn, abort context.CancelFunc) { + var hasWinner bool + for i := 0; i < len(r.dialerFns); i++ { + dr := <-connChan + if dr.Err != nil { + if errors.Is(dr.Err, context.Canceled) { + r.log.Infof("connection attempt aborted via: %s", dr.Protocol) + } else { + r.log.Errorf("failed to dial via %s: %s", dr.Protocol, dr.Err) + } + continue + } + + if hasWinner { + if cerr := dr.Conn.Close(); cerr != nil 
{ + r.log.Warnf("failed to close connection via %s: %s", dr.Protocol, cerr) + } + continue + } + + r.log.Infof("successfully dialed via: %s", dr.Protocol) + + abort() + hasWinner = true + winnerConn <- dr.Conn + } + close(winnerConn) +} diff --git a/relay/client/dialer/race_dialer_test.go b/relay/client/dialer/race_dialer_test.go new file mode 100644 index 000000000..989abb0a6 --- /dev/null +++ b/relay/client/dialer/race_dialer_test.go @@ -0,0 +1,252 @@ +package dialer + +import ( + "context" + "errors" + "net" + "testing" + "time" + + "github.com/sirupsen/logrus" +) + +type MockAddr struct { + network string +} + +func (m *MockAddr) Network() string { + return m.network +} + +func (m *MockAddr) String() string { + return "1.2.3.4" +} + +// MockDialer is a mock implementation of DialeFn +type MockDialer struct { + dialFunc func(ctx context.Context, address string) (net.Conn, error) + protocolStr string +} + +func (m *MockDialer) Dial(ctx context.Context, address string) (net.Conn, error) { + return m.dialFunc(ctx, address) +} + +func (m *MockDialer) Protocol() string { + return m.protocolStr +} + +// MockConn implements net.Conn for testing +type MockConn struct { + remoteAddr net.Addr +} + +func (m *MockConn) Read(b []byte) (n int, err error) { + return 0, nil +} + +func (m *MockConn) Write(b []byte) (n int, err error) { + return 0, nil +} + +func (m *MockConn) Close() error { + return nil +} + +func (m *MockConn) LocalAddr() net.Addr { + return nil +} + +func (m *MockConn) RemoteAddr() net.Addr { + return m.remoteAddr +} + +func (m *MockConn) SetDeadline(t time.Time) error { + return nil +} + +func (m *MockConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (m *MockConn) SetWriteDeadline(t time.Time) error { + return nil +} + +func TestRaceDialEmptyDialers(t *testing.T) { + logger := logrus.NewEntry(logrus.New()) + serverURL := "test.server.com" + + rd := NewRaceDial(logger, serverURL) + conn, err := rd.Dial() + if err == nil { + t.Errorf("Expected 
an error with empty dialers, got nil") + } + if conn != nil { + t.Errorf("Expected nil connection with empty dialers, got %v", conn) + } +} + +func TestRaceDialSingleSuccessfulDialer(t *testing.T) { + logger := logrus.NewEntry(logrus.New()) + serverURL := "test.server.com" + proto := "test-protocol" + + mockConn := &MockConn{ + remoteAddr: &MockAddr{network: proto}, + } + + mockDialer := &MockDialer{ + dialFunc: func(ctx context.Context, address string) (net.Conn, error) { + return mockConn, nil + }, + protocolStr: proto, + } + + rd := NewRaceDial(logger, serverURL, mockDialer) + conn, err := rd.Dial() + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if conn == nil { + t.Errorf("Expected non-nil connection") + } +} + +func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) { + logger := logrus.NewEntry(logrus.New()) + serverURL := "test.server.com" + proto2 := "protocol2" + + mockConn2 := &MockConn{ + remoteAddr: &MockAddr{network: proto2}, + } + + mockDialer1 := &MockDialer{ + dialFunc: func(ctx context.Context, address string) (net.Conn, error) { + return nil, errors.New("first dialer failed") + }, + protocolStr: "proto1", + } + + mockDialer2 := &MockDialer{ + dialFunc: func(ctx context.Context, address string) (net.Conn, error) { + return mockConn2, nil + }, + protocolStr: "proto2", + } + + rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2) + conn, err := rd.Dial() + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if conn.RemoteAddr().Network() != proto2 { + t.Errorf("Expected connection with protocol %s, got %s", proto2, conn.RemoteAddr().Network()) + } +} + +func TestRaceDialTimeout(t *testing.T) { + logger := logrus.NewEntry(logrus.New()) + serverURL := "test.server.com" + + connectionTimeout = 3 * time.Second + mockDialer := &MockDialer{ + dialFunc: func(ctx context.Context, address string) (net.Conn, error) { + <-ctx.Done() + return nil, ctx.Err() + }, + protocolStr: "proto1", + } + + rd := 
NewRaceDial(logger, serverURL, mockDialer) + conn, err := rd.Dial() + if err == nil { + t.Errorf("Expected an error, got nil") + } + if conn != nil { + t.Errorf("Expected nil connection, got %v", conn) + } +} + +func TestRaceDialAllDialersFail(t *testing.T) { + logger := logrus.NewEntry(logrus.New()) + serverURL := "test.server.com" + + mockDialer1 := &MockDialer{ + dialFunc: func(ctx context.Context, address string) (net.Conn, error) { + return nil, errors.New("first dialer failed") + }, + protocolStr: "protocol1", + } + + mockDialer2 := &MockDialer{ + dialFunc: func(ctx context.Context, address string) (net.Conn, error) { + return nil, errors.New("second dialer failed") + }, + protocolStr: "protocol2", + } + + rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2) + conn, err := rd.Dial() + if err == nil { + t.Errorf("Expected an error, got nil") + } + if conn != nil { + t.Errorf("Expected nil connection, got %v", conn) + } +} + +func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) { + logger := logrus.NewEntry(logrus.New()) + serverURL := "test.server.com" + proto1 := "protocol1" + proto2 := "protocol2" + + mockConn1 := &MockConn{ + remoteAddr: &MockAddr{network: proto1}, + } + + mockConn2 := &MockConn{ + remoteAddr: &MockAddr{network: proto2}, + } + + mockDialer1 := &MockDialer{ + dialFunc: func(ctx context.Context, address string) (net.Conn, error) { + time.Sleep(1 * time.Second) + return mockConn1, nil + }, + protocolStr: proto1, + } + + mock2err := make(chan error) + mockDialer2 := &MockDialer{ + dialFunc: func(ctx context.Context, address string) (net.Conn, error) { + <-ctx.Done() + mock2err <- ctx.Err() + return mockConn2, ctx.Err() + }, + protocolStr: proto2, + } + + rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2) + conn, err := rd.Dial() + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if conn == nil { + t.Errorf("Expected non-nil connection") + } + if conn != mockConn1 { + t.Errorf("Expected first 
connection, got %v", conn) + } + + select { + case <-time.After(3 * time.Second): + t.Errorf("Timed out waiting for second dialer to finish") + case err := <-mock2err: + if !errors.Is(err, context.Canceled) { + t.Errorf("Expected context.Canceled error, got %v", err) + } + } +} diff --git a/relay/client/dialer/ws/addr.go b/relay/client/dialer/ws/addr.go index 43f5dd6af..11158cfbd 100644 --- a/relay/client/dialer/ws/addr.go +++ b/relay/client/dialer/ws/addr.go @@ -1,11 +1,15 @@ package ws +const ( + Network = "ws" +) + type WebsocketAddr struct { addr string } func (a WebsocketAddr) Network() string { - return "websocket" + return Network } func (a WebsocketAddr) String() string { diff --git a/relay/client/dialer/ws/conn.go b/relay/client/dialer/ws/conn.go index e7f771b8d..74bcafd82 100644 --- a/relay/client/dialer/ws/conn.go +++ b/relay/client/dialer/ws/conn.go @@ -26,6 +26,7 @@ func NewConn(wsConn *websocket.Conn, serverAddress string) net.Conn { func (c *Conn) Read(b []byte) (n int, err error) { t, ioReader, err := c.Conn.Reader(c.ctx) if err != nil { + // todo use ErrClosedByServer return 0, err } diff --git a/relay/client/dialer/ws/ws.go b/relay/client/dialer/ws/ws.go index d9388aafd..df91a66d4 100644 --- a/relay/client/dialer/ws/ws.go +++ b/relay/client/dialer/ws/ws.go @@ -2,6 +2,7 @@ package ws import ( "context" + "errors" "fmt" "net" "net/http" @@ -15,7 +16,14 @@ import ( nbnet "github.com/netbirdio/netbird/util/net" ) -func Dial(address string) (net.Conn, error) { +type Dialer struct { +} + +func (d Dialer) Protocol() string { + return "WS" +} + +func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) { wsURL, err := prepareURL(address) if err != nil { return nil, err @@ -31,8 +39,11 @@ func Dial(address string) (net.Conn, error) { } parsedURL.Path = ws.URLPath - wsConn, resp, err := websocket.Dial(context.Background(), parsedURL.String(), opts) + wsConn, resp, err := websocket.Dial(ctx, parsedURL.String(), opts) if err != nil { + if 
errors.Is(err, context.Canceled) { + return nil, err + } log.Errorf("failed to dial to Relay server '%s': %s", wsURL, err) return nil, err } diff --git a/relay/server/listener/quic/conn.go b/relay/server/listener/quic/conn.go new file mode 100644 index 000000000..909ec1cc6 --- /dev/null +++ b/relay/server/listener/quic/conn.go @@ -0,0 +1,101 @@ +package quic + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + "time" + + "github.com/quic-go/quic-go" +) + +type Conn struct { + session quic.Connection + closed bool + closedMu sync.Mutex + ctx context.Context + ctxCancel context.CancelFunc +} + +func NewConn(session quic.Connection) *Conn { + ctx, cancel := context.WithCancel(context.Background()) + return &Conn{ + session: session, + ctx: ctx, + ctxCancel: cancel, + } +} + +func (c *Conn) Read(b []byte) (n int, err error) { + dgram, err := c.session.ReceiveDatagram(c.ctx) + if err != nil { + return 0, c.remoteCloseErrHandling(err) + } + // Copy data to b, ensuring we don’t exceed the size of b + n = copy(b, dgram) + return n, nil +} + +func (c *Conn) Write(b []byte) (int, error) { + if err := c.session.SendDatagram(b); err != nil { + return 0, c.remoteCloseErrHandling(err) + } + return len(b), nil +} + +func (c *Conn) LocalAddr() net.Addr { + return c.session.LocalAddr() +} + +func (c *Conn) RemoteAddr() net.Addr { + return c.session.RemoteAddr() +} + +func (c *Conn) SetReadDeadline(t time.Time) error { + return nil +} + +func (c *Conn) SetWriteDeadline(t time.Time) error { + return fmt.Errorf("SetWriteDeadline is not implemented") +} + +func (c *Conn) SetDeadline(t time.Time) error { + return fmt.Errorf("SetDeadline is not implemented") +} + +func (c *Conn) Close() error { + c.closedMu.Lock() + if c.closed { + c.closedMu.Unlock() + return nil + } + c.closed = true + c.closedMu.Unlock() + + c.ctxCancel() // Cancel the context + + sessionErr := c.session.CloseWithError(0, "normal closure") + return sessionErr +} + +func (c *Conn) isClosed() bool { + 
c.closedMu.Lock() + defer c.closedMu.Unlock() + return c.closed +} + +func (c *Conn) remoteCloseErrHandling(err error) error { + if c.isClosed() { + return net.ErrClosed + } + + // Check if the connection was closed remotely + var appErr *quic.ApplicationError + if errors.As(err, &appErr) && appErr.ErrorCode == 0x0 { + return net.ErrClosed + } + + return err +} diff --git a/relay/server/listener/quic/listener.go b/relay/server/listener/quic/listener.go new file mode 100644 index 000000000..b6e01994f --- /dev/null +++ b/relay/server/listener/quic/listener.go @@ -0,0 +1,66 @@ +package quic + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net" + + "github.com/quic-go/quic-go" + log "github.com/sirupsen/logrus" +) + +type Listener struct { + // Address is the address to listen on + Address string + // TLSConfig is the TLS configuration for the server + TLSConfig *tls.Config + + listener *quic.Listener + acceptFn func(conn net.Conn) +} + +func (l *Listener) Listen(acceptFn func(conn net.Conn)) error { + l.acceptFn = acceptFn + + quicCfg := &quic.Config{ + EnableDatagrams: true, + } + listener, err := quic.ListenAddr(l.Address, l.TLSConfig, quicCfg) + if err != nil { + return fmt.Errorf("failed to create QUIC listener: %v", err) + } + + l.listener = listener + log.Infof("QUIC server listening on address: %s", l.Address) + + for { + session, err := listener.Accept(context.Background()) + if err != nil { + if errors.Is(err, quic.ErrServerClosed) { + return nil + } + + log.Errorf("Failed to accept QUIC session: %v", err) + continue + } + + log.Infof("QUIC client connected from: %s", session.RemoteAddr()) + conn := NewConn(session) + l.acceptFn(conn) + } +} + +func (l *Listener) Shutdown(ctx context.Context) error { + if l.listener == nil { + return nil + } + + log.Infof("stopping QUIC listener") + if err := l.listener.Close(); err != nil { + return fmt.Errorf("listener shutdown failed: %v", err) + } + log.Infof("QUIC listener stopped") + return nil +} diff --git 
a/relay/server/listener/ws/listener.go b/relay/server/listener/ws/listener.go index 5c62c0826..0eb244c77 100644 --- a/relay/server/listener/ws/listener.go +++ b/relay/server/listener/ws/listener.go @@ -88,6 +88,8 @@ func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) { return } + log.Infof("WS client connected from: %s", rAddr) + conn := NewConn(wsConn, lAddr, rAddr) l.acceptFn(conn) } diff --git a/relay/server/relay.go b/relay/server/relay.go index 6cd8506ae..a5e77bc61 100644 --- a/relay/server/relay.go +++ b/relay/server/relay.go @@ -150,6 +150,8 @@ func (r *Relay) Accept(conn net.Conn) { func (r *Relay) Shutdown(ctx context.Context) { log.Infof("close connection with all peers") r.closeMu.Lock() + defer r.closeMu.Unlock() + wg := sync.WaitGroup{} peers := r.store.Peers() for _, peer := range peers { @@ -161,7 +163,7 @@ func (r *Relay) Shutdown(ctx context.Context) { } wg.Wait() r.metricsCancel() - r.closeMu.Unlock() + r.closed = true } // InstanceURL returns the instance URL of the relay server diff --git a/relay/server/server.go b/relay/server/server.go index 0036e2390..cacc3dafb 100644 --- a/relay/server/server.go +++ b/relay/server/server.go @@ -3,13 +3,18 @@ package server import ( "context" "crypto/tls" + "sync" + "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/metric" + nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/relay/auth" "github.com/netbirdio/netbird/relay/server/listener" + "github.com/netbirdio/netbird/relay/server/listener/quic" "github.com/netbirdio/netbird/relay/server/listener/ws" + quictls "github.com/netbirdio/netbird/relay/tls" ) // ListenerConfig is the configuration for the listener. @@ -24,8 +29,8 @@ type ListenerConfig struct { // It is the gate between the WebSocket listener and the Relay server logic. // In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method. 
type Server struct { - relay *Relay - wSListener listener.Listener + relay *Relay + listeners []listener.Listener } // NewServer creates a new relay server instance. @@ -39,35 +44,64 @@ func NewServer(meter metric.Meter, exposedAddress string, tlsSupport bool, authV return nil, err } return &Server{ - relay: relay, + relay: relay, + listeners: make([]listener.Listener, 0, 2), }, nil } // Listen starts the relay server. func (r *Server) Listen(cfg ListenerConfig) error { - r.wSListener = &ws.Listener{ + wSListener := &ws.Listener{ Address: cfg.Address, TLSConfig: cfg.TLSConfig, } + r.listeners = append(r.listeners, wSListener) - wslErr := r.wSListener.Listen(r.relay.Accept) - if wslErr != nil { - log.Errorf("failed to bind ws server: %s", wslErr) + // QUIC cannot run without a TLS config; keep serving WebSocket clients instead of aborting the whole server. + tlsConfigQUIC, err := quictls.ServerQUICTLSConfig(cfg.TLSConfig) + if err != nil { + log.Warnf("Not starting QUIC listener: %v", err) + } else { + quicListener := &quic.Listener{ + Address: cfg.Address, + TLSConfig: tlsConfigQUIC, + } + + r.listeners = append(r.listeners, quicListener) } - return wslErr + errChan := make(chan error, len(r.listeners)) // buffered so every listener goroutine can report without blocking + wg := sync.WaitGroup{} + for _, l := range r.listeners { + wg.Add(1) + go func(listener listener.Listener) { + defer wg.Done() + errChan <- listener.Listen(r.relay.Accept) + }(l) + } + + wg.Wait() // returns only after every listener's Listen call has returned + close(errChan) + var multiErr *multierror.Error + for err := range errChan { + multiErr = multierror.Append(multiErr, err) + } + + return nberrors.FormatErrorOrNil(multiErr) } // Shutdown stops the relay server. If there are active connections, they will be closed gracefully. In case of a context, // the connections will be forcefully closed. 
-func (r *Server) Shutdown(ctx context.Context) (err error) { - // stop service new connections - if r.wSListener != nil { - err = r.wSListener.Shutdown(ctx) - } - +func (r *Server) Shutdown(ctx context.Context) error { r.relay.Shutdown(ctx) - return + + var multiErr *multierror.Error + for _, l := range r.listeners { + if err := l.Shutdown(ctx); err != nil { + multiErr = multierror.Append(multiErr, err) // collect the failure and continue with the remaining listeners + } + } + return nberrors.FormatErrorOrNil(multiErr) } // InstanceURL returns the instance URL of the relay server. diff --git a/relay/tls/alpn.go b/relay/tls/alpn.go new file mode 100644 index 000000000..29497d401 --- /dev/null +++ b/relay/tls/alpn.go @@ -0,0 +1,3 @@ +package tls + +const nbalpn = "nb-quic" // ALPN protocol name that the QUIC client and server TLS configs put in NextProtos diff --git a/relay/tls/client_dev.go b/relay/tls/client_dev.go new file mode 100644 index 000000000..f6b8290a0 --- /dev/null +++ b/relay/tls/client_dev.go @@ -0,0 +1,12 @@ +//go:build devcert + +package tls + +import "crypto/tls" + +func ClientQUICTLSConfig() *tls.Config { + return &tls.Config{ + InsecureSkipVerify: true, // Debug mode allows insecure connections + NextProtos: []string{nbalpn}, // Ensure this matches the server's ALPN + } +} diff --git a/relay/tls/client_prod.go b/relay/tls/client_prod.go new file mode 100644 index 000000000..686093a37 --- /dev/null +++ b/relay/tls/client_prod.go @@ -0,0 +1,11 @@ +//go:build !devcert + +package tls + +import "crypto/tls" + +func ClientQUICTLSConfig() *tls.Config { + return &tls.Config{ + NextProtos: []string{nbalpn}, // only the ALPN is set; InsecureSkipVerify is left unset, unlike the devcert build + } +} diff --git a/relay/tls/doc.go b/relay/tls/doc.go new file mode 100644 index 000000000..38b807f84 --- /dev/null +++ b/relay/tls/doc.go @@ -0,0 +1,36 @@ +// Package tls provides utilities for configuring and managing Transport Layer +// Security (TLS) in server and client environments, with a focus on QUIC +// protocol support and testing configurations. 
+// +// The package includes functions for cloning and customizing TLS +// configurations as well as generating self-signed certificates for +// development and testing purposes. +// +// Key Features: +// +// - `ServerQUICTLSConfig`: Creates a server-side TLS configuration tailored +// for QUIC protocol with specified or default settings. QUIC requires a +// specific TLS configuration with proper ALPN (Application-Layer Protocol +// Negotiation) support, making the TLS settings crucial for establishing +// secure connections. +// +// - `ClientQUICTLSConfig`: Provides a client-side TLS configuration suitable +// for QUIC protocol. The configuration differs between development +// (insecure testing) and production (strict verification). +// +// - `generateTestTLSConfig`: Generates a self-signed TLS configuration for +// use in local development and testing scenarios. +// +// Usage: +// +// This package provides separate implementations for development and production +// environments. The development implementation (guarded by `//go:build devcert`) +// supports testing configurations with self-signed certificates and insecure +// client connections. The production implementation (guarded by `//go:build +// !devcert`) ensures that valid and secure TLS configurations are supplied +// and used. +// +// The QUIC protocol is highly reliant on properly configured TLS settings, +// and this package ensures that configurations meet the requirements for +// secure and efficient QUIC communication. 
+package tls diff --git a/relay/tls/server_dev.go b/relay/tls/server_dev.go new file mode 100644 index 000000000..1a01658fc --- /dev/null +++ b/relay/tls/server_dev.go @@ -0,0 +1,79 @@ +//go:build devcert + +package tls + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "net" + "time" + + log "github.com/sirupsen/logrus" +) + +func ServerQUICTLSConfig(originTLSCfg *tls.Config) (*tls.Config, error) { + if originTLSCfg == nil { + log.Warnf("QUIC server will use self signed certificate for testing!") + return generateTestTLSConfig() + } + + cfg := originTLSCfg.Clone() // work on a copy so the caller's config is left unmodified + cfg.NextProtos = []string{nbalpn} + return cfg, nil +} + +// generateTestTLSConfig creates a TLS config with a self-signed certificate for testing +func generateTestTLSConfig() (*tls.Config, error) { + log.Infof("generating test TLS config") + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Test Organization"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour * 24 * 180), // Valid for 180 days + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageServerAuth, + }, + BasicConstraintsValid: true, + DNSNames: []string{"localhost"}, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + } + + // Create certificate + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) // template doubles as the parent, so the certificate is self-signed + if err != nil { + return nil, err + } + + certPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: certDER, + }) + + privateKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + }) + + tlsCert, err := tls.X509KeyPair(certPEM, privateKeyPEM) + if err != nil { + return nil, err + } + + return &tls.Config{ + 
Certificates: []tls.Certificate{tlsCert}, + NextProtos: []string{nbalpn}, + }, nil +} diff --git a/relay/tls/server_prod.go b/relay/tls/server_prod.go new file mode 100644 index 000000000..9d1c47d88 --- /dev/null +++ b/relay/tls/server_prod.go @@ -0,0 +1,17 @@ +//go:build !devcert + +package tls + +import ( + "crypto/tls" + "fmt" +) + +func ServerQUICTLSConfig(originTLSCfg *tls.Config) (*tls.Config, error) { + if originTLSCfg == nil { + return nil, fmt.Errorf("valid TLS config is required for QUIC listener") + } + cfg := originTLSCfg.Clone() // work on a copy so the caller's config is left unmodified + cfg.NextProtos = []string{nbalpn} + return cfg, nil +}