From b82b4a07fc948b6e9c77ddc12104163957996c5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zolt=C3=A1n=20Papp?= Date: Sat, 2 Nov 2024 22:55:41 +0100 Subject: [PATCH] Quic support --- go.mod | 4 + go.sum | 6 ++ relay/client/client.go | 8 +- relay/client/dialer/quic/conn.go | 88 +++++++++++++++++++ relay/client/dialer/quic/quic.go | 56 +++++++++++++ relay/client/dialer/ws/ws.go | 7 ++ relay/cmd/root.go | 68 +++++++++++++++ relay/server/listener/quic/conn.go | 112 +++++++++++++++++++++++++ relay/server/listener/quic/listener.go | 70 ++++++++++++++++ relay/server/server.go | 4 +- 10 files changed, 415 insertions(+), 8 deletions(-) create mode 100644 relay/client/dialer/quic/conn.go create mode 100644 relay/client/dialer/quic/quic.go create mode 100644 relay/server/listener/quic/conn.go create mode 100644 relay/server/listener/quic/listener.go diff --git a/go.mod b/go.mod index 0a16753ea..75afcbb35 100644 --- a/go.mod +++ b/go.mod @@ -71,6 +71,7 @@ require ( github.com/pion/transport/v3 v3.0.1 github.com/pion/turn/v3 v3.0.1 github.com/prometheus/client_golang v1.19.1 + github.com/quic-go/quic-go v0.48.1 github.com/rs/xid v1.3.0 github.com/shirou/gopsutil/v3 v3.24.4 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 @@ -151,11 +152,13 @@ require ( github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.3.0 // indirect github.com/go-redis/redis/v8 v8.11.5 // indirect + github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/go-text/render v0.1.0 // indirect github.com/go-text/typesetting v0.1.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/google/btree v1.1.2 // indirect + github.com/google/pprof v0.0.0-20211214055906-6f57359322fd // indirect github.com/google/s2a-go v0.1.7 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/googleapis/gax-go/v2 v2.12.3 // indirect @@ -216,6 +219,7 @@ require ( go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect go.opentelemetry.io/otel/sdk v1.26.0 // indirect go.opentelemetry.io/otel/trace v1.26.0 // indirect + go.uber.org/mock v0.4.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/mod v0.17.0 // indirect diff --git a/go.sum b/go.sum index a4d7ea7f9..4a67b5587 100644 --- a/go.sum +++ b/go.sum @@ -400,6 +400,7 @@ github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/J github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= +github.com/ianlancetaylor/demangle v0.0.0-20210905161508-09a460cdf81d/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= @@ -605,6 +606,8 @@ github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+a github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U= github.com/prometheus/procfs v0.15.0 h1:A82kmvXJq2jTu5YUhSGNlYoxh85zLnKgPz4bMZgI5Ek= github.com/prometheus/procfs v0.15.0/go.mod h1:Y0RJ/Y5g5wJpkTisOtqwDSo4HwhGmLB4VQSw2sQJLHk= +github.com/quic-go/quic-go v0.48.1 h1:y/8xmfWI9qmGTc+lBr4jKRUWLGSlSigv847ULJ4hYXA= +github.com/quic-go/quic-go v0.48.1/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= @@ -753,6 +756,8 @@ go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v8 go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= +go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= @@ -963,6 +968,7 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/relay/client/client.go b/relay/client/client.go index 154c1787f..c8dfef617 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -10,7 +10,7 @@ import ( log "github.com/sirupsen/logrus" auth "github.com/netbirdio/netbird/relay/auth/hmac" - "github.com/netbirdio/netbird/relay/client/dialer/ws" + "github.com/netbirdio/netbird/relay/client/dialer/quic" "github.com/netbirdio/netbird/relay/healthcheck" "github.com/netbirdio/netbird/relay/messages" ) @@ -93,10 +93,6 @@ func (cc *connContainer) writeMsg(msg Msg) { case cc.messages <- msg: case <-cc.ctx.Done(): msg.Free() - default: - msg.Free() - cc.log.Infof("message queue is full") - // todo consider to close the connection } } @@ -264,7 +260,7 @@ func (c *Client) Close() error { } func (c *Client) connect() error { - conn, err := ws.Dial(c.connectionURL) + conn, err := quic.Dial(c.connectionURL) if err != nil { return err } diff --git a/relay/client/dialer/quic/conn.go b/relay/client/dialer/quic/conn.go new file mode 100644 index 000000000..39e043b1e --- /dev/null +++ b/relay/client/dialer/quic/conn.go @@ -0,0 +1,88 @@ +package quic + +import ( + "context" + "fmt" + "net" + "time" + + "github.com/quic-go/quic-go" +) + +type QuicAddr struct { + addr string +} + +func (a QuicAddr) Network() string { + return "quic" +} + +func (a QuicAddr) String() string { + return a.addr +} + +type Conn struct { + session quic.Connection + remoteAddr QuicAddr + ctx context.Context +} + +func NewConn(session quic.Connection, serverAddress string) net.Conn { + return &Conn{ + session: session, + remoteAddr: QuicAddr{addr: serverAddress}, + ctx: context.Background(), + } +} + +func (c *Conn) Read(b []byte) (n int, err error) { + // Use the QUIC stream's Read method directly + dgram, err := c.session.ReceiveDatagram(c.ctx) + if err != nil { + return 0, fmt.Errorf("failed to read from QUIC stream: %v", err) + } + + // Copy data to b, ensuring we don’t exceed the size of b + n = copy(b, dgram) + return n, nil +} + +func (c *Conn) Write(b []byte) (int, error) { + // Use the QUIC stream's Write method directly + err := c.session.SendDatagram(b) + if err != nil { + return 0, fmt.Errorf("failed to write to QUIC stream: %v", err) + } + return len(b), nil +} + +func (c *Conn) RemoteAddr() net.Addr { + if c.session != nil { + return c.session.RemoteAddr() + } + return c.remoteAddr +} + +func (c *Conn) LocalAddr() net.Addr { + if c.session != nil { + return c.session.LocalAddr() + } + return QuicAddr{addr: "unknown"} +} + +func (c *Conn) SetReadDeadline(t time.Time) error { + return nil +} + +func (c *Conn) SetWriteDeadline(t time.Time) error { + return nil +} + +func (c *Conn) SetDeadline(t time.Time) error { + + return nil +} + +func (c *Conn) Close() error { + return c.session.CloseWithError(0, "normal closure") +} diff --git a/relay/client/dialer/quic/quic.go b/relay/client/dialer/quic/quic.go new file mode 100644 index 000000000..201aa7ea6 --- /dev/null +++ b/relay/client/dialer/quic/quic.go @@ -0,0 +1,56 @@ +package quic + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "strings" + "time" + + "github.com/quic-go/quic-go" +) + +const ( + dialTimeout = 30 * time.Second +) + +func Dial(address string) (net.Conn, error) { + quicURL, err := prepareURL(address) + if err != nil { + return nil, err + } + + ctx, cancel := context.WithTimeout(context.Background(), dialTimeout) + defer cancel() + + tlsConf := &tls.Config{ + InsecureSkipVerify: true, // Set to true only for testing + NextProtos: []string{"netbird-relay"}, // Ensure this matches the server's ALPN + } + + quicConfig := &quic.Config{ + KeepAlivePeriod: 15 * time.Second, + MaxIdleTimeout: 60 * time.Second, + EnableDatagrams: true, + } + + session, err := quic.DialAddr(ctx, quicURL, tlsConf, quicConfig) + if err != nil { + return nil, fmt.Errorf("failed to dial QUIC server '%s': %v", quicURL, err) + } + + conn := NewConn(session, address) + return conn, nil +} + +func prepareURL(address string) (string, error) { + if !strings.HasPrefix(address, "rel://") && !strings.HasPrefix(address, "rels://") { + return "", fmt.Errorf("unsupported scheme: %s", address) + } + + if strings.HasPrefix(address, "rels://") { + return address[7:], nil + } + return address[6:], nil +} diff --git a/relay/client/dialer/ws/ws.go b/relay/client/dialer/ws/ws.go index d9388aafd..227d6953d 100644 --- a/relay/client/dialer/ws/ws.go +++ b/relay/client/dialer/ws/ws.go @@ -2,6 +2,7 @@ package ws import ( "context" + "crypto/tls" "fmt" "net" "net/http" @@ -31,6 +32,8 @@ func Dial(address string) (net.Conn, error) { } parsedURL.Path = ws.URLPath + log.Infof("------ Dialing to Relay server: %s", wsURL) + wsConn, resp, err := websocket.Dial(context.Background(), parsedURL.String(), opts) if err != nil { log.Errorf("failed to dial to Relay server '%s': %s", wsURL, err) @@ -59,6 +62,10 @@ func httpClientNbDialer() *http.Client { DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { return customDialer.DialContext(ctx, network, addr) }, + // Set up a TLS configuration that skips certificate verification + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, // This accepts invalid TLS certificates + }, } return &http.Client{ diff --git a/relay/cmd/root.go b/relay/cmd/root.go index d603ff73b..7af536a61 100644 --- a/relay/cmd/root.go +++ b/relay/cmd/root.go @@ -2,10 +2,17 @@ package cmd import ( "context" + "crypto/rand" + "crypto/rsa" "crypto/sha256" "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" "errors" "fmt" + "math/big" + "net" "net/http" "os" "os/signal" @@ -141,6 +148,13 @@ func execute(cmd *cobra.Command, args []string) error { hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret)) authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour) + tlsSupport = true + srvListenerCfg.TLSConfig, err = generateTestTLSConfig() + if err != nil { + log.Debugf("failed to generate test TLS config: %s", err) + return fmt.Errorf("failed to generate test TLS config: %s", err) + } + srv, err := server.NewServer(metricsServer.Meter, cobraConfig.ExposedAddress, tlsSupport, authenticator) if err != nil { log.Debugf("failed to create relay server: %v", err) @@ -213,3 +227,57 @@ func setupTLSCertManager(letsencryptDataDir string, letsencryptDomains ...string } return certManager.TLSConfig(), nil } + +// GenerateTestTLSConfig creates a self-signed certificate for testing +func generateTestTLSConfig() (*tls.Config, error) { + // Generate private key + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + + // Create certificate template + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Test Organization"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour * 24 * 180), // Valid for 180 days + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageServerAuth, + }, + BasicConstraintsValid: true, + DNSNames: []string{"localhost"}, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + } + + // Create certificate + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + if err != nil { + return nil, err + } + + // Encode certificate and private key to PEM format + certPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: certDER, + }) + + privateKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + }) + + // Create TLS certificate + tlsCert, err := tls.X509KeyPair(certPEM, privateKeyPEM) + if err != nil { + return nil, err + } + + return &tls.Config{ + Certificates: []tls.Certificate{tlsCert}, + NextProtos: []string{"netbird-relay"}, // Your application protocol + }, nil +} diff --git a/relay/server/listener/quic/conn.go b/relay/server/listener/quic/conn.go new file mode 100644 index 000000000..82353da9b --- /dev/null +++ b/relay/server/listener/quic/conn.go @@ -0,0 +1,112 @@ +package quic + +import ( + "context" + "errors" + "io" + "net" + "sync" + "time" + + "github.com/quic-go/quic-go" +) + +const ( + writeTimeout = 10 * time.Second +) + +type Conn struct { + session quic.Connection + closed bool + closedMu sync.Mutex + ctx context.Context + ctxCancel context.CancelFunc +} + +func NewConn(session quic.Connection) *Conn { + ctx, cancel := context.WithCancel(context.Background()) + return &Conn{ + session: session, + ctx: ctx, + ctxCancel: cancel, + } +} + +func (c *Conn) Read(b []byte) (n int, err error) { + if c.isClosed() { + return 0, io.EOF + } + + dgram, err := c.session.ReceiveDatagram(c.ctx) + if err != nil { + return 0, c.ioErrHandling(err) + } + // Copy data to b, ensuring we don’t exceed the size of b + n = copy(b, dgram) + return n, nil +} + +func (c *Conn) Write(b []byte) (int, error) { + err := c.session.SendDatagram(b) + return len(b), err +} + +func (c *Conn) LocalAddr() net.Addr { + return c.session.LocalAddr() +} + +func (c *Conn) RemoteAddr() net.Addr { + return c.session.RemoteAddr() +} + +func (c *Conn) SetReadDeadline(t time.Time) error { + return nil +} + +func (c *Conn) SetWriteDeadline(t time.Time) error { + return nil +} + +func (c *Conn) SetDeadline(t time.Time) error { + return nil +} + +func (c *Conn) Close() error { + c.closedMu.Lock() + if c.closed { + c.closedMu.Unlock() + return nil + } + c.closed = true + c.closedMu.Unlock() + + c.ctxCancel() // Cancel the context + + sessionErr := c.session.CloseWithError(0, "normal closure") + return sessionErr +} + +func (c *Conn) isClosed() bool { + c.closedMu.Lock() + defer c.closedMu.Unlock() + return c.closed +} + +func (c *Conn) ioErrHandling(err error) error { + if c.isClosed() { + return io.EOF + } + + // Handle QUIC-specific errors + if err == nil { + return nil + } + + // Check if the connection was closed remotely + var appErr *quic.ApplicationError + if errors.As(err, &appErr) && appErr.ErrorCode == 0 { // 0 is normal closure + return io.EOF + } + + return err +} diff --git a/relay/server/listener/quic/listener.go b/relay/server/listener/quic/listener.go new file mode 100644 index 000000000..3b55409f4 --- /dev/null +++ b/relay/server/listener/quic/listener.go @@ -0,0 +1,70 @@ +package quic + +import ( + "context" + "crypto/tls" + "fmt" + "net" + + "github.com/quic-go/quic-go" + log "github.com/sirupsen/logrus" +) + +type Listener struct { + // Address is the address to listen on + Address string + // TLSConfig is the TLS configuration for the server + TLSConfig *tls.Config + + listener *quic.Listener + acceptFn func(conn net.Conn) +} + +func (l *Listener) Listen(acceptFn func(conn net.Conn)) error { + l.acceptFn = acceptFn + + quicCfg := &quic.Config{ + EnableDatagrams: true, + } + listener, err := quic.ListenAddr(l.Address, l.TLSConfig, quicCfg) + if err != nil { + return fmt.Errorf("failed to create QUIC listener: %v", err) + } + + l.listener = listener + log.Infof("QUIC server listening on address: %s", l.Address) + + for { + session, err := listener.Accept(context.Background()) + if err != nil { + // Check if the listener was closed intentionally + if err.Error() == "server closed" { + return nil + } + log.Errorf("Failed to accept QUIC session: %v", err) + continue + } + + // Handle each session in a separate goroutine + go l.handleSession(session) + } +} + +func (l *Listener) handleSession(session quic.Connection) { + conn := NewConn(session) + l.acceptFn(conn) +} + +func (l *Listener) Shutdown(ctx context.Context) error { + if l.listener == nil { + return nil + } + + log.Infof("stopping QUIC listener") + err := l.listener.Close() + if err != nil { + return fmt.Errorf("listener shutdown failed: %v", err) + } + log.Infof("QUIC listener stopped") + return nil +} diff --git a/relay/server/server.go b/relay/server/server.go index 0036e2390..456dc1ea6 100644 --- a/relay/server/server.go +++ b/relay/server/server.go @@ -9,7 +9,7 @@ import ( "github.com/netbirdio/netbird/relay/auth" "github.com/netbirdio/netbird/relay/server/listener" - "github.com/netbirdio/netbird/relay/server/listener/ws" + "github.com/netbirdio/netbird/relay/server/listener/quic" ) // ListenerConfig is the configuration for the listener. @@ -45,7 +45,7 @@ func NewServer(meter metric.Meter, exposedAddress string, tlsSupport bool, authV // Listen starts the relay server. func (r *Server) Listen(cfg ListenerConfig) error { - r.wSListener = &ws.Listener{ + r.wSListener = &quic.Listener{ Address: cfg.Address, TLSConfig: cfg.TLSConfig, }