mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-18 11:00:06 +02:00
[server] Add health check HTTP endpoint for Relay server (#4297)
The health check endpoint listens on a dedicated HTTP server. By default, it is available at 0.0.0.0:9000/health. This can be configured using the --health-listen-address flag. The results are cached for 3 seconds to avoid excessive calls. The health check performs the following: Checks the number of active listeners. Validates each listener via WebSocket and QUIC dials, including TLS certificate verification.
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -17,8 +18,9 @@ import (
|
|||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/encryption"
|
"github.com/netbirdio/netbird/encryption"
|
||||||
"github.com/netbirdio/netbird/shared/relay/auth"
|
"github.com/netbirdio/netbird/relay/healthcheck"
|
||||||
"github.com/netbirdio/netbird/relay/server"
|
"github.com/netbirdio/netbird/relay/server"
|
||||||
|
"github.com/netbirdio/netbird/shared/relay/auth"
|
||||||
"github.com/netbirdio/netbird/signal/metrics"
|
"github.com/netbirdio/netbird/signal/metrics"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
@@ -40,6 +42,7 @@ type Config struct {
|
|||||||
AuthSecret string
|
AuthSecret string
|
||||||
LogLevel string
|
LogLevel string
|
||||||
LogFile string
|
LogFile string
|
||||||
|
HealthcheckListenAddress string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c Config) Validate() error {
|
func (c Config) Validate() error {
|
||||||
@@ -87,6 +90,7 @@ func init() {
|
|||||||
rootCmd.PersistentFlags().StringVarP(&cobraConfig.AuthSecret, "auth-secret", "s", "", "auth secret")
|
rootCmd.PersistentFlags().StringVarP(&cobraConfig.AuthSecret, "auth-secret", "s", "", "auth secret")
|
||||||
rootCmd.PersistentFlags().StringVar(&cobraConfig.LogLevel, "log-level", "info", "log level")
|
rootCmd.PersistentFlags().StringVar(&cobraConfig.LogLevel, "log-level", "info", "log level")
|
||||||
rootCmd.PersistentFlags().StringVar(&cobraConfig.LogFile, "log-file", "console", "log file")
|
rootCmd.PersistentFlags().StringVar(&cobraConfig.LogFile, "log-file", "console", "log file")
|
||||||
|
rootCmd.PersistentFlags().StringVarP(&cobraConfig.HealthcheckListenAddress, "health-listen-address", "H", ":9000", "listen address of healthcheck server")
|
||||||
|
|
||||||
setFlagsFromEnvVars(rootCmd)
|
setFlagsFromEnvVars(rootCmd)
|
||||||
}
|
}
|
||||||
@@ -102,6 +106,7 @@ func waitForExitSignal() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func execute(cmd *cobra.Command, args []string) error {
|
func execute(cmd *cobra.Command, args []string) error {
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
err := cobraConfig.Validate()
|
err := cobraConfig.Validate()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("invalid config: %s", err)
|
log.Debugf("invalid config: %s", err)
|
||||||
@@ -120,7 +125,9 @@ func execute(cmd *cobra.Command, args []string) error {
|
|||||||
return fmt.Errorf("setup metrics: %v", err)
|
return fmt.Errorf("setup metrics: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
|
log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
|
||||||
if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
|
if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
|
||||||
log.Fatalf("Failed to start metrics server: %v", err)
|
log.Fatalf("Failed to start metrics server: %v", err)
|
||||||
@@ -154,12 +161,31 @@ func execute(cmd *cobra.Command, args []string) error {
|
|||||||
return fmt.Errorf("failed to create relay server: %v", err)
|
return fmt.Errorf("failed to create relay server: %v", err)
|
||||||
}
|
}
|
||||||
log.Infof("server will be available on: %s", srv.InstanceURL())
|
log.Infof("server will be available on: %s", srv.InstanceURL())
|
||||||
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
if err := srv.Listen(srvListenerCfg); err != nil {
|
if err := srv.Listen(srvListenerCfg); err != nil {
|
||||||
log.Fatalf("failed to bind server: %s", err)
|
log.Fatalf("failed to bind server: %s", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
hCfg := healthcheck.Config{
|
||||||
|
ListenAddress: cobraConfig.HealthcheckListenAddress,
|
||||||
|
ServiceChecker: srv,
|
||||||
|
}
|
||||||
|
httpHealthcheck, err := healthcheck.NewServer(hCfg)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to create healthcheck server: %v", err)
|
||||||
|
return fmt.Errorf("failed to create healthcheck server: %v", err)
|
||||||
|
}
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
if err := httpHealthcheck.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
log.Fatalf("Failed to start healthcheck server: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// it will block until exit signal
|
// it will block until exit signal
|
||||||
waitForExitSignal()
|
waitForExitSignal()
|
||||||
|
|
||||||
@@ -167,6 +193,10 @@ func execute(cmd *cobra.Command, args []string) error {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
var shutDownErrors error
|
var shutDownErrors error
|
||||||
|
if err := httpHealthcheck.Shutdown(ctx); err != nil {
|
||||||
|
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close healthcheck server: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
if err := srv.Shutdown(ctx); err != nil {
|
if err := srv.Shutdown(ctx); err != nil {
|
||||||
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close server: %s", err))
|
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close server: %s", err))
|
||||||
}
|
}
|
||||||
@@ -175,6 +205,8 @@ func execute(cmd *cobra.Command, args []string) error {
|
|||||||
if err := metricsServer.Shutdown(ctx); err != nil {
|
if err := metricsServer.Shutdown(ctx); err != nil {
|
||||||
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close metrics server: %v", err))
|
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close metrics server: %v", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
return shutDownErrors
|
return shutDownErrors
|
||||||
}
|
}
|
||||||
|
|
||||||
|
195
relay/healthcheck/healthcheck.go
Normal file
195
relay/healthcheck/healthcheck.go
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
package healthcheck
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/relay/protocol"
|
||||||
|
"github.com/netbirdio/netbird/relay/server/listener/quic"
|
||||||
|
"github.com/netbirdio/netbird/relay/server/listener/ws"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
statusHealthy = "healthy"
|
||||||
|
statusUnhealthy = "unhealthy"
|
||||||
|
|
||||||
|
path = "/health"
|
||||||
|
|
||||||
|
cacheTTL = 3 * time.Second // Cache TTL for health status
|
||||||
|
)
|
||||||
|
|
||||||
|
type ServiceChecker interface {
|
||||||
|
ListenerProtocols() []protocol.Protocol
|
||||||
|
ListenAddress() string
|
||||||
|
}
|
||||||
|
|
||||||
|
type HealthStatus struct {
|
||||||
|
Status string `json:"status"`
|
||||||
|
Timestamp time.Time `json:"timestamp"`
|
||||||
|
Listeners []protocol.Protocol `json:"listeners"`
|
||||||
|
CertificateValid bool `json:"certificate_valid"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
ListenAddress string
|
||||||
|
ServiceChecker ServiceChecker
|
||||||
|
}
|
||||||
|
|
||||||
|
type Server struct {
|
||||||
|
config Config
|
||||||
|
httpServer *http.Server
|
||||||
|
|
||||||
|
cacheMu sync.Mutex
|
||||||
|
cacheStatus *HealthStatus
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewServer(config Config) (*Server, error) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
|
||||||
|
if config.ServiceChecker == nil {
|
||||||
|
return nil, errors.New("service checker is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
server := &Server{
|
||||||
|
config: config,
|
||||||
|
httpServer: &http.Server{
|
||||||
|
Addr: config.ListenAddress,
|
||||||
|
Handler: mux,
|
||||||
|
ReadTimeout: 5 * time.Second,
|
||||||
|
WriteTimeout: 10 * time.Second,
|
||||||
|
IdleTimeout: 15 * time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mux.HandleFunc(path, server.handleHealthcheck)
|
||||||
|
return server, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) ListenAndServe() error {
|
||||||
|
log.Infof("starting healthcheck server on: http://%s%s", dialAddress(s.config.ListenAddress), path)
|
||||||
|
return s.httpServer.ListenAndServe()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown gracefully shuts down the healthcheck server
|
||||||
|
func (s *Server) Shutdown(ctx context.Context) error {
|
||||||
|
log.Info("Shutting down healthcheck server")
|
||||||
|
return s.httpServer.Shutdown(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleHealthcheck(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
var (
|
||||||
|
status *HealthStatus
|
||||||
|
ok bool
|
||||||
|
)
|
||||||
|
// Cache check
|
||||||
|
s.cacheMu.Lock()
|
||||||
|
status = s.cacheStatus
|
||||||
|
s.cacheMu.Unlock()
|
||||||
|
|
||||||
|
if status != nil && time.Since(status.Timestamp) <= cacheTTL {
|
||||||
|
ok = status.Status == statusHealthy
|
||||||
|
} else {
|
||||||
|
status, ok = s.getHealthStatus(ctx)
|
||||||
|
// Update cache
|
||||||
|
s.cacheMu.Lock()
|
||||||
|
s.cacheStatus = status
|
||||||
|
s.cacheMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
} else {
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
}
|
||||||
|
|
||||||
|
encoder := json.NewEncoder(w)
|
||||||
|
if err := encoder.Encode(status); err != nil {
|
||||||
|
log.Errorf("Failed to encode healthcheck response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) getHealthStatus(ctx context.Context) (*HealthStatus, bool) {
|
||||||
|
healthy := true
|
||||||
|
status := &HealthStatus{
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
Status: statusHealthy,
|
||||||
|
CertificateValid: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
listeners, ok := s.validateListeners()
|
||||||
|
if !ok {
|
||||||
|
status.Status = statusUnhealthy
|
||||||
|
healthy = false
|
||||||
|
}
|
||||||
|
status.Listeners = listeners
|
||||||
|
|
||||||
|
if ok := s.validateCertificate(ctx); !ok {
|
||||||
|
status.Status = statusUnhealthy
|
||||||
|
status.CertificateValid = false
|
||||||
|
healthy = false
|
||||||
|
}
|
||||||
|
|
||||||
|
return status, healthy
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) validateListeners() ([]protocol.Protocol, bool) {
|
||||||
|
listeners := s.config.ServiceChecker.ListenerProtocols()
|
||||||
|
if len(listeners) == 0 {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return listeners, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) validateCertificate(ctx context.Context) bool {
|
||||||
|
listenAddress := s.config.ServiceChecker.ListenAddress()
|
||||||
|
if listenAddress == "" {
|
||||||
|
log.Warn("listen address is empty")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
dAddr := dialAddress(listenAddress)
|
||||||
|
|
||||||
|
for _, proto := range s.config.ServiceChecker.ListenerProtocols() {
|
||||||
|
switch proto {
|
||||||
|
case ws.Proto:
|
||||||
|
if err := dialWS(ctx, dAddr); err != nil {
|
||||||
|
log.Errorf("failed to dial WebSocket listener: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
case quic.Proto:
|
||||||
|
if err := dialQUIC(ctx, dAddr); err != nil {
|
||||||
|
log.Errorf("failed to dial QUIC listener: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
log.Warnf("unknown protocol for healthcheck: %s", proto)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func dialAddress(listenAddress string) string {
|
||||||
|
host, port, err := net.SplitHostPort(listenAddress)
|
||||||
|
if err != nil {
|
||||||
|
return listenAddress // fallback, might be invalid for dialing
|
||||||
|
}
|
||||||
|
|
||||||
|
if host == "" || host == "::" || host == "0.0.0.0" {
|
||||||
|
host = "0.0.0.0"
|
||||||
|
}
|
||||||
|
|
||||||
|
return net.JoinHostPort(host, port)
|
||||||
|
}
|
31
relay/healthcheck/quic.go
Normal file
31
relay/healthcheck/quic.go
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
package healthcheck
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/quic-go/quic-go"
|
||||||
|
|
||||||
|
tlsnb "github.com/netbirdio/netbird/shared/relay/tls"
|
||||||
|
)
|
||||||
|
|
||||||
|
func dialQUIC(ctx context.Context, address string) error {
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
InsecureSkipVerify: false, // Keep certificate validation enabled
|
||||||
|
NextProtos: []string{tlsnb.NBalpn},
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := quic.DialAddr(ctx, address, tlsConfig, &quic.Config{
|
||||||
|
MaxIdleTimeout: 30 * time.Second,
|
||||||
|
KeepAlivePeriod: 10 * time.Second,
|
||||||
|
EnableDatagrams: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to connect to QUIC server: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = conn.CloseWithError(0, "availability check complete")
|
||||||
|
return nil
|
||||||
|
}
|
28
relay/healthcheck/ws.go
Normal file
28
relay/healthcheck/ws.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package healthcheck
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/coder/websocket"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/shared/relay"
|
||||||
|
)
|
||||||
|
|
||||||
|
func dialWS(ctx context.Context, address string) error {
|
||||||
|
url := fmt.Sprintf("wss://%s%s", address, relay.WebSocketURLPath)
|
||||||
|
|
||||||
|
conn, resp, err := websocket.Dial(ctx, url, nil)
|
||||||
|
if resp != nil {
|
||||||
|
defer func() {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to connect to websocket: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = conn.Close(websocket.StatusNormalClosure, "availability check complete")
|
||||||
|
return nil
|
||||||
|
}
|
3
relay/protocol/protocol.go
Normal file
3
relay/protocol/protocol.go
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
package protocol
|
||||||
|
|
||||||
|
type Protocol string
|
@@ -3,9 +3,12 @@ package listener
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/relay/protocol"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Listener interface {
|
type Listener interface {
|
||||||
Listen(func(conn net.Conn)) error
|
Listen(func(conn net.Conn)) error
|
||||||
Shutdown(ctx context.Context) error
|
Shutdown(ctx context.Context) error
|
||||||
|
Protocol() protocol.Protocol
|
||||||
}
|
}
|
||||||
|
@@ -9,8 +9,12 @@ import (
|
|||||||
|
|
||||||
"github.com/quic-go/quic-go"
|
"github.com/quic-go/quic-go"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/relay/protocol"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const Proto protocol.Protocol = "quic"
|
||||||
|
|
||||||
type Listener struct {
|
type Listener struct {
|
||||||
// Address is the address to listen on
|
// Address is the address to listen on
|
||||||
Address string
|
Address string
|
||||||
@@ -50,6 +54,10 @@ func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (l *Listener) Protocol() protocol.Protocol {
|
||||||
|
return Proto
|
||||||
|
}
|
||||||
|
|
||||||
func (l *Listener) Shutdown(ctx context.Context) error {
|
func (l *Listener) Shutdown(ctx context.Context) error {
|
||||||
if l.listener == nil {
|
if l.listener == nil {
|
||||||
return nil
|
return nil
|
||||||
|
@@ -11,11 +11,14 @@ import (
|
|||||||
"github.com/coder/websocket"
|
"github.com/coder/websocket"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/relay/protocol"
|
||||||
"github.com/netbirdio/netbird/shared/relay"
|
"github.com/netbirdio/netbird/shared/relay"
|
||||||
)
|
)
|
||||||
|
|
||||||
// URLPath is the path for the websocket connection.
|
const (
|
||||||
const URLPath = relay.WebSocketURLPath
|
Proto protocol.Protocol = "ws"
|
||||||
|
URLPath = relay.WebSocketURLPath
|
||||||
|
)
|
||||||
|
|
||||||
type Listener struct {
|
type Listener struct {
|
||||||
// Address is the address to listen on.
|
// Address is the address to listen on.
|
||||||
@@ -51,6 +54,10 @@ func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (l *Listener) Protocol() protocol.Protocol {
|
||||||
|
return Proto
|
||||||
|
}
|
||||||
|
|
||||||
func (l *Listener) Shutdown(ctx context.Context) error {
|
func (l *Listener) Shutdown(ctx context.Context) error {
|
||||||
if l.server == nil {
|
if l.server == nil {
|
||||||
return nil
|
return nil
|
||||||
|
@@ -6,12 +6,14 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
"github.com/netbirdio/netbird/relay/protocol"
|
||||||
"github.com/netbirdio/netbird/relay/server/listener"
|
"github.com/netbirdio/netbird/relay/server/listener"
|
||||||
"github.com/netbirdio/netbird/relay/server/listener/quic"
|
"github.com/netbirdio/netbird/relay/server/listener/quic"
|
||||||
"github.com/netbirdio/netbird/relay/server/listener/ws"
|
"github.com/netbirdio/netbird/relay/server/listener/ws"
|
||||||
quictls "github.com/netbirdio/netbird/shared/relay/tls"
|
quictls "github.com/netbirdio/netbird/shared/relay/tls"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ListenerConfig is the configuration for the listener.
|
// ListenerConfig is the configuration for the listener.
|
||||||
@@ -26,8 +28,11 @@ type ListenerConfig struct {
|
|||||||
// It is the gate between the WebSocket listener and the Relay server logic.
|
// It is the gate between the WebSocket listener and the Relay server logic.
|
||||||
// In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method.
|
// In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method.
|
||||||
type Server struct {
|
type Server struct {
|
||||||
|
listenAddr string
|
||||||
|
|
||||||
relay *Relay
|
relay *Relay
|
||||||
listeners []listener.Listener
|
listeners []listener.Listener
|
||||||
|
listenerMux sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer creates and returns a new relay server instance.
|
// NewServer creates and returns a new relay server instance.
|
||||||
@@ -57,10 +62,14 @@ func NewServer(config Config) (*Server, error) {
|
|||||||
|
|
||||||
// Listen starts the relay server.
|
// Listen starts the relay server.
|
||||||
func (r *Server) Listen(cfg ListenerConfig) error {
|
func (r *Server) Listen(cfg ListenerConfig) error {
|
||||||
|
r.listenAddr = cfg.Address
|
||||||
|
|
||||||
wSListener := &ws.Listener{
|
wSListener := &ws.Listener{
|
||||||
Address: cfg.Address,
|
Address: cfg.Address,
|
||||||
TLSConfig: cfg.TLSConfig,
|
TLSConfig: cfg.TLSConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.listenerMux.Lock()
|
||||||
r.listeners = append(r.listeners, wSListener)
|
r.listeners = append(r.listeners, wSListener)
|
||||||
|
|
||||||
tlsConfigQUIC, err := quictls.ServerQUICTLSConfig(cfg.TLSConfig)
|
tlsConfigQUIC, err := quictls.ServerQUICTLSConfig(cfg.TLSConfig)
|
||||||
@@ -85,6 +94,8 @@ func (r *Server) Listen(cfg ListenerConfig) error {
|
|||||||
}(l)
|
}(l)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.listenerMux.Unlock()
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
close(errChan)
|
close(errChan)
|
||||||
var multiErr *multierror.Error
|
var multiErr *multierror.Error
|
||||||
@@ -100,12 +111,15 @@ func (r *Server) Listen(cfg ListenerConfig) error {
|
|||||||
func (r *Server) Shutdown(ctx context.Context) error {
|
func (r *Server) Shutdown(ctx context.Context) error {
|
||||||
r.relay.Shutdown(ctx)
|
r.relay.Shutdown(ctx)
|
||||||
|
|
||||||
|
r.listenerMux.Lock()
|
||||||
var multiErr *multierror.Error
|
var multiErr *multierror.Error
|
||||||
for _, l := range r.listeners {
|
for _, l := range r.listeners {
|
||||||
if err := l.Shutdown(ctx); err != nil {
|
if err := l.Shutdown(ctx); err != nil {
|
||||||
multiErr = multierror.Append(multiErr, err)
|
multiErr = multierror.Append(multiErr, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
r.listeners = r.listeners[:0]
|
||||||
|
r.listenerMux.Unlock()
|
||||||
return nberrors.FormatErrorOrNil(multiErr)
|
return nberrors.FormatErrorOrNil(multiErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -113,3 +127,18 @@ func (r *Server) Shutdown(ctx context.Context) error {
|
|||||||
func (r *Server) InstanceURL() string {
|
func (r *Server) InstanceURL() string {
|
||||||
return r.relay.instanceURL
|
return r.relay.instanceURL
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *Server) ListenerProtocols() []protocol.Protocol {
|
||||||
|
result := make([]protocol.Protocol, 0)
|
||||||
|
|
||||||
|
r.listenerMux.Lock()
|
||||||
|
for _, l := range r.listeners {
|
||||||
|
result = append(result, l.Protocol())
|
||||||
|
}
|
||||||
|
r.listenerMux.Unlock()
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Server) ListenAddress() string {
|
||||||
|
return r.listenAddr
|
||||||
|
}
|
||||||
|
@@ -1,3 +1,3 @@
|
|||||||
package tls
|
package tls
|
||||||
|
|
||||||
const nbalpn = "nb-quic"
|
const NBalpn = "nb-quic"
|
||||||
|
@@ -20,7 +20,7 @@ func ClientQUICTLSConfig() *tls.Config {
|
|||||||
|
|
||||||
return &tls.Config{
|
return &tls.Config{
|
||||||
InsecureSkipVerify: true, // Debug mode allows insecure connections
|
InsecureSkipVerify: true, // Debug mode allows insecure connections
|
||||||
NextProtos: []string{nbalpn}, // Ensure this matches the server's ALPN
|
NextProtos: []string{NBalpn}, // Ensure this matches the server's ALPN
|
||||||
RootCAs: certPool,
|
RootCAs: certPool,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -19,7 +19,7 @@ func ClientQUICTLSConfig() *tls.Config {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &tls.Config{
|
return &tls.Config{
|
||||||
NextProtos: []string{nbalpn},
|
NextProtos: []string{NBalpn},
|
||||||
RootCAs: certPool,
|
RootCAs: certPool,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -23,7 +23,7 @@ func ServerQUICTLSConfig(originTLSCfg *tls.Config) (*tls.Config, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
cfg := originTLSCfg.Clone()
|
cfg := originTLSCfg.Clone()
|
||||||
cfg.NextProtos = []string{nbalpn}
|
cfg.NextProtos = []string{NBalpn}
|
||||||
return cfg, nil
|
return cfg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -74,6 +74,6 @@ func generateTestTLSConfig() (*tls.Config, error) {
|
|||||||
|
|
||||||
return &tls.Config{
|
return &tls.Config{
|
||||||
Certificates: []tls.Certificate{tlsCert},
|
Certificates: []tls.Certificate{tlsCert},
|
||||||
NextProtos: []string{nbalpn},
|
NextProtos: []string{NBalpn},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
@@ -12,6 +12,6 @@ func ServerQUICTLSConfig(originTLSCfg *tls.Config) (*tls.Config, error) {
|
|||||||
return nil, fmt.Errorf("valid TLS config is required for QUIC listener")
|
return nil, fmt.Errorf("valid TLS config is required for QUIC listener")
|
||||||
}
|
}
|
||||||
cfg := originTLSCfg.Clone()
|
cfg := originTLSCfg.Clone()
|
||||||
cfg.NextProtos = []string{nbalpn}
|
cfg.NextProtos = []string{NBalpn}
|
||||||
return cfg, nil
|
return cfg, nil
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user