[client,relay] Add QUIC support (#2962)

This commit is contained in:
Zoltan Papp 2025-01-15 16:28:19 +01:00 committed by GitHub
parent e4a25b6a60
commit 1ffa519387
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 943 additions and 33 deletions

View File

@ -44,4 +44,5 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)

View File

@ -134,7 +134,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management)
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management)
test_management:
needs: [ build-cache ]
@ -194,7 +194,7 @@ jobs:
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management)
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management)
benchmark:
needs: [ build-cache ]
@ -254,7 +254,7 @@ jobs:
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m ./...
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags devcert -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m ./...
api_benchmark:
needs: [ build-cache ]

View File

@ -65,7 +65,7 @@ jobs:
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV
- name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
- name: test output
if: ${{ always() }}
run: Get-Content test-out.txt

4
go.mod
View File

@ -71,6 +71,7 @@ require (
github.com/pion/transport/v3 v3.0.1
github.com/pion/turn/v3 v3.0.1
github.com/prometheus/client_golang v1.19.1
github.com/quic-go/quic-go v0.48.2
github.com/rs/xid v1.3.0
github.com/shirou/gopsutil/v3 v3.24.4
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
@ -155,11 +156,13 @@ require (
github.com/go-ole/go-ole v1.3.0 // indirect
github.com/go-redis/redis/v8 v8.11.5 // indirect
github.com/go-sql-driver/mysql v1.8.1 // indirect
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/go-text/render v0.2.0 // indirect
github.com/go-text/typesetting v0.2.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/google/btree v1.1.2 // indirect
github.com/google/pprof v0.0.0-20211214055906-6f57359322fd // indirect
github.com/google/s2a-go v0.1.7 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.3 // indirect
@ -221,6 +224,7 @@ require (
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
go.opentelemetry.io/otel/sdk v1.26.0 // indirect
go.opentelemetry.io/otel/trace v1.26.0 // indirect
go.uber.org/mock v0.4.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/image v0.18.0 // indirect
golang.org/x/mod v0.17.0 // indirect

6
go.sum
View File

@ -405,6 +405,7 @@ github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/J
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/ianlancetaylor/demangle v0.0.0-20210905161508-09a460cdf81d/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w=
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
@ -610,6 +611,8 @@ github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+a
github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U=
github.com/prometheus/procfs v0.15.0 h1:A82kmvXJq2jTu5YUhSGNlYoxh85zLnKgPz4bMZgI5Ek=
github.com/prometheus/procfs v0.15.0/go.mod h1:Y0RJ/Y5g5wJpkTisOtqwDSo4HwhGmLB4VQSw2sQJLHk=
github.com/quic-go/quic-go v0.48.2 h1:wsKXZPeGWpMpCGSWqOcqpW2wZYic/8T3aqiOID0/KWE=
github.com/quic-go/quic-go v0.48.2/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs=
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
@ -761,6 +764,8 @@ go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v8
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
@ -970,6 +975,7 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@ -10,6 +10,8 @@ import (
log "github.com/sirupsen/logrus"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
"github.com/netbirdio/netbird/relay/client/dialer"
"github.com/netbirdio/netbird/relay/client/dialer/quic"
"github.com/netbirdio/netbird/relay/client/dialer/ws"
"github.com/netbirdio/netbird/relay/healthcheck"
"github.com/netbirdio/netbird/relay/messages"
@ -95,8 +97,6 @@ func (cc *connContainer) writeMsg(msg Msg) {
msg.Free()
default:
msg.Free()
cc.log.Infof("message queue is full")
// todo consider to close the connection
}
}
@ -179,8 +179,7 @@ func (c *Client) Connect() error {
return nil
}
err := c.connect()
if err != nil {
if err := c.connect(); err != nil {
return err
}
@ -264,14 +263,14 @@ func (c *Client) Close() error {
}
func (c *Client) connect() error {
conn, err := ws.Dial(c.connectionURL)
rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{})
conn, err := rd.Dial()
if err != nil {
return err
}
c.relayConn = conn
err = c.handShake()
if err != nil {
if err = c.handShake(); err != nil {
cErr := conn.Close()
if cErr != nil {
c.log.Errorf("failed to close connection: %s", cErr)
@ -345,7 +344,7 @@ func (c *Client) readLoop(relayConn net.Conn) {
c.log.Infof("start to Relay read loop exit")
c.mu.Lock()
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
c.log.Debugf("failed to read message from relay server: %s", errExit)
c.log.Errorf("failed to read message from relay server: %s", errExit)
}
c.mu.Unlock()
c.bufPool.Put(bufPtr)

View File

@ -0,0 +1,7 @@
package net
import "errors"
var (
ErrClosedByServer = errors.New("closed by server")
)

View File

@ -0,0 +1,97 @@
package quic
import (
"context"
"errors"
"fmt"
"net"
"time"
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
netErr "github.com/netbirdio/netbird/relay/client/dialer/net"
)
const (
Network = "quic"
)
type Addr struct {
addr string
}
func (a Addr) Network() string {
return Network
}
func (a Addr) String() string {
return a.addr
}
type Conn struct {
session quic.Connection
ctx context.Context
}
func NewConn(session quic.Connection) net.Conn {
return &Conn{
session: session,
ctx: context.Background(),
}
}
func (c *Conn) Read(b []byte) (n int, err error) {
dgram, err := c.session.ReceiveDatagram(c.ctx)
if err != nil {
return 0, c.remoteCloseErrHandling(err)
}
n = copy(b, dgram)
return n, nil
}
func (c *Conn) Write(b []byte) (int, error) {
err := c.session.SendDatagram(b)
if err != nil {
err = c.remoteCloseErrHandling(err)
log.Errorf("failed to write to QUIC stream: %v", err)
return 0, err
}
return len(b), nil
}
func (c *Conn) RemoteAddr() net.Addr {
return c.session.RemoteAddr()
}
func (c *Conn) LocalAddr() net.Addr {
if c.session != nil {
return c.session.LocalAddr()
}
return Addr{addr: "unknown"}
}
func (c *Conn) SetReadDeadline(t time.Time) error {
return fmt.Errorf("SetReadDeadline is not implemented")
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
return fmt.Errorf("SetWriteDeadline is not implemented")
}
func (c *Conn) SetDeadline(t time.Time) error {
return nil
}
func (c *Conn) Close() error {
return c.session.CloseWithError(0, "normal closure")
}
func (c *Conn) remoteCloseErrHandling(err error) error {
var appErr *quic.ApplicationError
if errors.As(err, &appErr) && appErr.ErrorCode == 0x0 {
return netErr.ErrClosedByServer
}
return err
}

View File

@ -0,0 +1,71 @@
package quic
import (
"context"
"errors"
"fmt"
"net"
"strings"
"time"
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
quictls "github.com/netbirdio/netbird/relay/tls"
nbnet "github.com/netbirdio/netbird/util/net"
)
type Dialer struct {
}
func (d Dialer) Protocol() string {
return Network
}
func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
quicURL, err := prepareURL(address)
if err != nil {
return nil, err
}
quicConfig := &quic.Config{
KeepAlivePeriod: 30 * time.Second,
MaxIdleTimeout: 4 * time.Minute,
EnableDatagrams: true,
}
udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
log.Errorf("failed to listen on UDP: %s", err)
return nil, err
}
udpAddr, err := net.ResolveUDPAddr("udp", quicURL)
if err != nil {
log.Errorf("failed to resolve UDP address: %s", err)
return nil, err
}
session, err := quic.Dial(ctx, udpConn, udpAddr, quictls.ClientQUICTLSConfig(), quicConfig)
if err != nil {
if errors.Is(err, context.Canceled) {
return nil, err
}
log.Errorf("failed to dial to Relay server via QUIC '%s': %s", quicURL, err)
return nil, err
}
conn := NewConn(session)
return conn, nil
}
func prepareURL(address string) (string, error) {
if !strings.HasPrefix(address, "rel://") && !strings.HasPrefix(address, "rels://") {
return "", fmt.Errorf("unsupported scheme: %s", address)
}
if strings.HasPrefix(address, "rels://") {
return address[7:], nil
}
return address[6:], nil
}

View File

@ -0,0 +1,96 @@
package dialer
import (
"context"
"errors"
"net"
"time"
log "github.com/sirupsen/logrus"
)
var (
connectionTimeout = 30 * time.Second
)
type DialeFn interface {
Dial(ctx context.Context, address string) (net.Conn, error)
Protocol() string
}
type dialResult struct {
Conn net.Conn
Protocol string
Err error
}
type RaceDial struct {
log *log.Entry
serverURL string
dialerFns []DialeFn
}
func NewRaceDial(log *log.Entry, serverURL string, dialerFns ...DialeFn) *RaceDial {
return &RaceDial{
log: log,
serverURL: serverURL,
dialerFns: dialerFns,
}
}
func (r *RaceDial) Dial() (net.Conn, error) {
connChan := make(chan dialResult, len(r.dialerFns))
winnerConn := make(chan net.Conn, 1)
abortCtx, abort := context.WithCancel(context.Background())
defer abort()
for _, dfn := range r.dialerFns {
go r.dial(dfn, abortCtx, connChan)
}
go r.processResults(connChan, winnerConn, abort)
conn, ok := <-winnerConn
if !ok {
return nil, errors.New("failed to dial to Relay server on any protocol")
}
return conn, nil
}
func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) {
ctx, cancel := context.WithTimeout(abortCtx, connectionTimeout)
defer cancel()
r.log.Infof("dialing Relay server via %s", dfn.Protocol())
conn, err := dfn.Dial(ctx, r.serverURL)
connChan <- dialResult{Conn: conn, Protocol: dfn.Protocol(), Err: err}
}
func (r *RaceDial) processResults(connChan chan dialResult, winnerConn chan net.Conn, abort context.CancelFunc) {
var hasWinner bool
for i := 0; i < len(r.dialerFns); i++ {
dr := <-connChan
if dr.Err != nil {
if errors.Is(dr.Err, context.Canceled) {
r.log.Infof("connection attempt aborted via: %s", dr.Protocol)
} else {
r.log.Errorf("failed to dial via %s: %s", dr.Protocol, dr.Err)
}
continue
}
if hasWinner {
if cerr := dr.Conn.Close(); cerr != nil {
r.log.Warnf("failed to close connection via %s: %s", dr.Protocol, cerr)
}
continue
}
r.log.Infof("successfully dialed via: %s", dr.Protocol)
abort()
hasWinner = true
winnerConn <- dr.Conn
}
close(winnerConn)
}

View File

@ -0,0 +1,252 @@
package dialer
import (
"context"
"errors"
"net"
"testing"
"time"
"github.com/sirupsen/logrus"
)
type MockAddr struct {
network string
}
func (m *MockAddr) Network() string {
return m.network
}
func (m *MockAddr) String() string {
return "1.2.3.4"
}
// MockDialer is a mock implementation of DialeFn
type MockDialer struct {
dialFunc func(ctx context.Context, address string) (net.Conn, error)
protocolStr string
}
func (m *MockDialer) Dial(ctx context.Context, address string) (net.Conn, error) {
return m.dialFunc(ctx, address)
}
func (m *MockDialer) Protocol() string {
return m.protocolStr
}
// MockConn implements net.Conn for testing
type MockConn struct {
remoteAddr net.Addr
}
func (m *MockConn) Read(b []byte) (n int, err error) {
return 0, nil
}
func (m *MockConn) Write(b []byte) (n int, err error) {
return 0, nil
}
func (m *MockConn) Close() error {
return nil
}
func (m *MockConn) LocalAddr() net.Addr {
return nil
}
func (m *MockConn) RemoteAddr() net.Addr {
return m.remoteAddr
}
func (m *MockConn) SetDeadline(t time.Time) error {
return nil
}
func (m *MockConn) SetReadDeadline(t time.Time) error {
return nil
}
func (m *MockConn) SetWriteDeadline(t time.Time) error {
return nil
}
func TestRaceDialEmptyDialers(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
rd := NewRaceDial(logger, serverURL)
conn, err := rd.Dial()
if err == nil {
t.Errorf("Expected an error with empty dialers, got nil")
}
if conn != nil {
t.Errorf("Expected nil connection with empty dialers, got %v", conn)
}
}
func TestRaceDialSingleSuccessfulDialer(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
proto := "test-protocol"
mockConn := &MockConn{
remoteAddr: &MockAddr{network: proto},
}
mockDialer := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
return mockConn, nil
},
protocolStr: proto,
}
rd := NewRaceDial(logger, serverURL, mockDialer)
conn, err := rd.Dial()
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if conn == nil {
t.Errorf("Expected non-nil connection")
}
}
func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
proto2 := "protocol2"
mockConn2 := &MockConn{
remoteAddr: &MockAddr{network: proto2},
}
mockDialer1 := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
return nil, errors.New("first dialer failed")
},
protocolStr: "proto1",
}
mockDialer2 := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
return mockConn2, nil
},
protocolStr: "proto2",
}
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
conn, err := rd.Dial()
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if conn.RemoteAddr().Network() != proto2 {
t.Errorf("Expected connection with protocol %s, got %s", proto2, conn.RemoteAddr().Network())
}
}
func TestRaceDialTimeout(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
connectionTimeout = 3 * time.Second
mockDialer := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
<-ctx.Done()
return nil, ctx.Err()
},
protocolStr: "proto1",
}
rd := NewRaceDial(logger, serverURL, mockDialer)
conn, err := rd.Dial()
if err == nil {
t.Errorf("Expected an error, got nil")
}
if conn != nil {
t.Errorf("Expected nil connection, got %v", conn)
}
}
func TestRaceDialAllDialersFail(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
mockDialer1 := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
return nil, errors.New("first dialer failed")
},
protocolStr: "protocol1",
}
mockDialer2 := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
return nil, errors.New("second dialer failed")
},
protocolStr: "protocol2",
}
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
conn, err := rd.Dial()
if err == nil {
t.Errorf("Expected an error, got nil")
}
if conn != nil {
t.Errorf("Expected nil connection, got %v", conn)
}
}
func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
proto1 := "protocol1"
proto2 := "protocol2"
mockConn1 := &MockConn{
remoteAddr: &MockAddr{network: proto1},
}
mockConn2 := &MockConn{
remoteAddr: &MockAddr{network: proto2},
}
mockDialer1 := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
time.Sleep(1 * time.Second)
return mockConn1, nil
},
protocolStr: proto1,
}
mock2err := make(chan error)
mockDialer2 := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
<-ctx.Done()
mock2err <- ctx.Err()
return mockConn2, ctx.Err()
},
protocolStr: proto2,
}
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
conn, err := rd.Dial()
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if conn == nil {
t.Errorf("Expected non-nil connection")
}
if conn != mockConn1 {
t.Errorf("Expected first connection, got %v", conn)
}
select {
case <-time.After(3 * time.Second):
t.Errorf("Timed out waiting for second dialer to finish")
case err := <-mock2err:
if !errors.Is(err, context.Canceled) {
t.Errorf("Expected context.Canceled error, got %v", err)
}
}
}

View File

@ -1,11 +1,15 @@
package ws
const (
Network = "ws"
)
type WebsocketAddr struct {
addr string
}
func (a WebsocketAddr) Network() string {
return "websocket"
return Network
}
func (a WebsocketAddr) String() string {

View File

@ -26,6 +26,7 @@ func NewConn(wsConn *websocket.Conn, serverAddress string) net.Conn {
func (c *Conn) Read(b []byte) (n int, err error) {
t, ioReader, err := c.Conn.Reader(c.ctx)
if err != nil {
// todo use ErrClosedByServer
return 0, err
}

View File

@ -2,6 +2,7 @@ package ws
import (
"context"
"errors"
"fmt"
"net"
"net/http"
@ -15,7 +16,14 @@ import (
nbnet "github.com/netbirdio/netbird/util/net"
)
func Dial(address string) (net.Conn, error) {
type Dialer struct {
}
func (d Dialer) Protocol() string {
return "WS"
}
func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
wsURL, err := prepareURL(address)
if err != nil {
return nil, err
@ -31,8 +39,11 @@ func Dial(address string) (net.Conn, error) {
}
parsedURL.Path = ws.URLPath
wsConn, resp, err := websocket.Dial(context.Background(), parsedURL.String(), opts)
wsConn, resp, err := websocket.Dial(ctx, parsedURL.String(), opts)
if err != nil {
if errors.Is(err, context.Canceled) {
return nil, err
}
log.Errorf("failed to dial to Relay server '%s': %s", wsURL, err)
return nil, err
}

View File

@ -0,0 +1,101 @@
package quic
import (
"context"
"errors"
"fmt"
"net"
"sync"
"time"
"github.com/quic-go/quic-go"
)
type Conn struct {
session quic.Connection
closed bool
closedMu sync.Mutex
ctx context.Context
ctxCancel context.CancelFunc
}
func NewConn(session quic.Connection) *Conn {
ctx, cancel := context.WithCancel(context.Background())
return &Conn{
session: session,
ctx: ctx,
ctxCancel: cancel,
}
}
func (c *Conn) Read(b []byte) (n int, err error) {
dgram, err := c.session.ReceiveDatagram(c.ctx)
if err != nil {
return 0, c.remoteCloseErrHandling(err)
}
// Copy data to b, ensuring we dont exceed the size of b
n = copy(b, dgram)
return n, nil
}
func (c *Conn) Write(b []byte) (int, error) {
if err := c.session.SendDatagram(b); err != nil {
return 0, c.remoteCloseErrHandling(err)
}
return len(b), nil
}
func (c *Conn) LocalAddr() net.Addr {
return c.session.LocalAddr()
}
func (c *Conn) RemoteAddr() net.Addr {
return c.session.RemoteAddr()
}
func (c *Conn) SetReadDeadline(t time.Time) error {
return nil
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
return fmt.Errorf("SetWriteDeadline is not implemented")
}
func (c *Conn) SetDeadline(t time.Time) error {
return fmt.Errorf("SetDeadline is not implemented")
}
func (c *Conn) Close() error {
c.closedMu.Lock()
if c.closed {
c.closedMu.Unlock()
return nil
}
c.closed = true
c.closedMu.Unlock()
c.ctxCancel() // Cancel the context
sessionErr := c.session.CloseWithError(0, "normal closure")
return sessionErr
}
func (c *Conn) isClosed() bool {
c.closedMu.Lock()
defer c.closedMu.Unlock()
return c.closed
}
func (c *Conn) remoteCloseErrHandling(err error) error {
if c.isClosed() {
return net.ErrClosed
}
// Check if the connection was closed remotely
var appErr *quic.ApplicationError
if errors.As(err, &appErr) && appErr.ErrorCode == 0x0 {
return net.ErrClosed
}
return err
}

View File

@ -0,0 +1,66 @@
package quic
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
)
type Listener struct {
// Address is the address to listen on
Address string
// TLSConfig is the TLS configuration for the server
TLSConfig *tls.Config
listener *quic.Listener
acceptFn func(conn net.Conn)
}
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
l.acceptFn = acceptFn
quicCfg := &quic.Config{
EnableDatagrams: true,
}
listener, err := quic.ListenAddr(l.Address, l.TLSConfig, quicCfg)
if err != nil {
return fmt.Errorf("failed to create QUIC listener: %v", err)
}
l.listener = listener
log.Infof("QUIC server listening on address: %s", l.Address)
for {
session, err := listener.Accept(context.Background())
if err != nil {
if errors.Is(err, quic.ErrServerClosed) {
return nil
}
log.Errorf("Failed to accept QUIC session: %v", err)
continue
}
log.Infof("QUIC client connected from: %s", session.RemoteAddr())
conn := NewConn(session)
l.acceptFn(conn)
}
}
func (l *Listener) Shutdown(ctx context.Context) error {
if l.listener == nil {
return nil
}
log.Infof("stopping QUIC listener")
if err := l.listener.Close(); err != nil {
return fmt.Errorf("listener shutdown failed: %v", err)
}
log.Infof("QUIC listener stopped")
return nil
}

View File

@ -88,6 +88,8 @@ func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
return
}
log.Infof("WS client connected from: %s", rAddr)
conn := NewConn(wsConn, lAddr, rAddr)
l.acceptFn(conn)
}

View File

@ -150,6 +150,8 @@ func (r *Relay) Accept(conn net.Conn) {
func (r *Relay) Shutdown(ctx context.Context) {
log.Infof("close connection with all peers")
r.closeMu.Lock()
defer r.closeMu.Unlock()
wg := sync.WaitGroup{}
peers := r.store.Peers()
for _, peer := range peers {
@ -161,7 +163,7 @@ func (r *Relay) Shutdown(ctx context.Context) {
}
wg.Wait()
r.metricsCancel()
r.closeMu.Unlock()
r.closed = true
}
// InstanceURL returns the instance URL of the relay server

View File

@ -3,13 +3,17 @@ package server
import (
"context"
"crypto/tls"
"sync"
log "github.com/sirupsen/logrus"
"github.com/hashicorp/go-multierror"
"go.opentelemetry.io/otel/metric"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/relay/server/listener/quic"
"github.com/netbirdio/netbird/relay/server/listener/ws"
quictls "github.com/netbirdio/netbird/relay/tls"
)
// ListenerConfig is the configuration for the listener.
@ -24,8 +28,8 @@ type ListenerConfig struct {
// It is the gate between the WebSocket listener and the Relay server logic.
// In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method.
type Server struct {
relay *Relay
wSListener listener.Listener
relay *Relay
listeners []listener.Listener
}
// NewServer creates a new relay server instance.
@ -39,35 +43,63 @@ func NewServer(meter metric.Meter, exposedAddress string, tlsSupport bool, authV
return nil, err
}
return &Server{
relay: relay,
relay: relay,
listeners: make([]listener.Listener, 0, 2),
}, nil
}
// Listen starts the relay server.
func (r *Server) Listen(cfg ListenerConfig) error {
r.wSListener = &ws.Listener{
wSListener := &ws.Listener{
Address: cfg.Address,
TLSConfig: cfg.TLSConfig,
}
r.listeners = append(r.listeners, wSListener)
wslErr := r.wSListener.Listen(r.relay.Accept)
if wslErr != nil {
log.Errorf("failed to bind ws server: %s", wslErr)
tlsConfigQUIC, err := quictls.ServerQUICTLSConfig(cfg.TLSConfig)
if err != nil {
return err
}
return wslErr
quicListener := &quic.Listener{
Address: cfg.Address,
TLSConfig: tlsConfigQUIC,
}
r.listeners = append(r.listeners, quicListener)
errChan := make(chan error, len(r.listeners))
wg := sync.WaitGroup{}
for _, l := range r.listeners {
wg.Add(1)
go func(listener listener.Listener) {
defer wg.Done()
errChan <- listener.Listen(r.relay.Accept)
}(l)
}
wg.Wait()
close(errChan)
var multiErr *multierror.Error
for err := range errChan {
multiErr = multierror.Append(multiErr, err)
}
return nberrors.FormatErrorOrNil(multiErr)
}
// Shutdown stops the relay server. If there are active connections, they will be closed gracefully. In case of a context,
// the connections will be forcefully closed.
func (r *Server) Shutdown(ctx context.Context) (err error) {
// stop service new connections
if r.wSListener != nil {
err = r.wSListener.Shutdown(ctx)
}
func (r *Server) Shutdown(ctx context.Context) error {
r.relay.Shutdown(ctx)
return
var multiErr *multierror.Error
for _, l := range r.listeners {
if err := l.Shutdown(ctx); err != nil {
multiErr = multierror.Append(multiErr, err)
}
}
return nberrors.FormatErrorOrNil(multiErr)
}
// InstanceURL returns the instance URL of the relay server.

3
relay/tls/alpn.go Normal file
View File

@ -0,0 +1,3 @@
package tls
const nbalpn = "nb-quic"

12
relay/tls/client_dev.go Normal file
View File

@ -0,0 +1,12 @@
//go:build devcert
package tls
import "crypto/tls"
func ClientQUICTLSConfig() *tls.Config {
return &tls.Config{
InsecureSkipVerify: true, // Debug mode allows insecure connections
NextProtos: []string{nbalpn}, // Ensure this matches the server's ALPN
}
}

11
relay/tls/client_prod.go Normal file
View File

@ -0,0 +1,11 @@
//go:build !devcert
package tls
import "crypto/tls"
func ClientQUICTLSConfig() *tls.Config {
return &tls.Config{
NextProtos: []string{nbalpn},
}
}

36
relay/tls/doc.go Normal file
View File

@ -0,0 +1,36 @@
// Package tls provides utilities for configuring and managing Transport Layer
// Security (TLS) in server and client environments, with a focus on QUIC
// protocol support and testing configurations.
//
// The package includes functions for cloning and customizing TLS
// configurations as well as generating self-signed certificates for
// development and testing purposes.
//
// Key Features:
//
// - `ServerQUICTLSConfig`: Creates a server-side TLS configuration tailored
// for QUIC protocol with specified or default settings. QUIC requires a
// specific TLS configuration with proper ALPN (Application-Layer Protocol
// Negotiation) support, making the TLS settings crucial for establishing
// secure connections.
//
// - `ClientQUICTLSConfig`: Provides a client-side TLS configuration suitable
// for QUIC protocol. The configuration differs between development
// (insecure testing) and production (strict verification).
//
// - `generateTestTLSConfig`: Generates a self-signed TLS configuration for
// use in local development and testing scenarios.
//
// Usage:
//
// This package provides separate implementations for development and production
// environments. The development implementation (guarded by `//go:build devcert`)
// supports testing configurations with self-signed certificates and insecure
// client connections. The production implementation (guarded by `//go:build
// !devcert`) ensures that valid and secure TLS configurations are supplied
// and used.
//
// The QUIC protocol is highly reliant on properly configured TLS settings,
// and this package ensures that configurations meet the requirements for
// secure and efficient QUIC communication.
package tls

79
relay/tls/server_dev.go Normal file
View File

@ -0,0 +1,79 @@
//go:build devcert
package tls
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"net"
"time"
log "github.com/sirupsen/logrus"
)
func ServerQUICTLSConfig(originTLSCfg *tls.Config) (*tls.Config, error) {
if originTLSCfg == nil {
log.Warnf("QUIC server will use self signed certificate for testing!")
return generateTestTLSConfig()
}
cfg := originTLSCfg.Clone()
cfg.NextProtos = []string{nbalpn}
return cfg, nil
}
// GenerateTestTLSConfig creates a self-signed certificate for testing
func generateTestTLSConfig() (*tls.Config, error) {
log.Infof("generating test TLS config")
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, err
}
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test Organization"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour * 24 * 180), // Valid for 180 days
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageServerAuth,
},
BasicConstraintsValid: true,
DNSNames: []string{"localhost"},
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
}
// Create certificate
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
if err != nil {
return nil, err
}
certPEM := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: certDER,
})
privateKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
})
tlsCert, err := tls.X509KeyPair(certPEM, privateKeyPEM)
if err != nil {
return nil, err
}
return &tls.Config{
Certificates: []tls.Certificate{tlsCert},
NextProtos: []string{nbalpn},
}, nil
}

17
relay/tls/server_prod.go Normal file
View File

@ -0,0 +1,17 @@
//go:build !devcert
package tls
import (
"crypto/tls"
"fmt"
)
func ServerQUICTLSConfig(originTLSCfg *tls.Config) (*tls.Config, error) {
if originTLSCfg == nil {
return nil, fmt.Errorf("valid TLS config is required for QUIC listener")
}
cfg := originTLSCfg.Clone()
cfg.NextProtos = []string{nbalpn}
return cfg, nil
}