diff --git a/relay/cmd/root.go b/relay/cmd/root.go index c662dfbb7..eb2cdebf8 100644 --- a/relay/cmd/root.go +++ b/relay/cmd/root.go @@ -9,6 +9,7 @@ import ( "net/http" "os" "os/signal" + "sync" "syscall" "time" @@ -17,8 +18,9 @@ import ( "github.com/spf13/cobra" "github.com/netbirdio/netbird/encryption" - "github.com/netbirdio/netbird/shared/relay/auth" + "github.com/netbirdio/netbird/relay/healthcheck" "github.com/netbirdio/netbird/relay/server" + "github.com/netbirdio/netbird/shared/relay/auth" "github.com/netbirdio/netbird/signal/metrics" "github.com/netbirdio/netbird/util" ) @@ -34,12 +36,13 @@ type Config struct { LetsencryptDomains []string // in case of using Route 53 for DNS challenge the credentials should be provided in the environment variables or // in the AWS credentials file - LetsencryptAWSRoute53 bool - TlsCertFile string - TlsKeyFile string - AuthSecret string - LogLevel string - LogFile string + LetsencryptAWSRoute53 bool + TlsCertFile string + TlsKeyFile string + AuthSecret string + LogLevel string + LogFile string + HealthcheckListenAddress string } func (c Config) Validate() error { @@ -87,6 +90,7 @@ func init() { rootCmd.PersistentFlags().StringVarP(&cobraConfig.AuthSecret, "auth-secret", "s", "", "auth secret") rootCmd.PersistentFlags().StringVar(&cobraConfig.LogLevel, "log-level", "info", "log level") rootCmd.PersistentFlags().StringVar(&cobraConfig.LogFile, "log-file", "console", "log file") + rootCmd.PersistentFlags().StringVarP(&cobraConfig.HealthcheckListenAddress, "health-listen-address", "H", ":9000", "listen address of healthcheck server") setFlagsFromEnvVars(rootCmd) } @@ -102,6 +106,7 @@ func waitForExitSignal() { } func execute(cmd *cobra.Command, args []string) error { + wg := sync.WaitGroup{} err := cobraConfig.Validate() if err != nil { log.Debugf("invalid config: %s", err) @@ -120,7 +125,9 @@ func execute(cmd *cobra.Command, args []string) error { return fmt.Errorf("setup metrics: %v", err) } + wg.Add(1) go func() { + defer 
wg.Done() log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint) if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { log.Fatalf("Failed to start metrics server: %v", err) @@ -154,12 +161,31 @@ func execute(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to create relay server: %v", err) } log.Infof("server will be available on: %s", srv.InstanceURL()) + wg.Add(1) go func() { + defer wg.Done() if err := srv.Listen(srvListenerCfg); err != nil { log.Fatalf("failed to bind server: %s", err) } }() + hCfg := healthcheck.Config{ + ListenAddress: cobraConfig.HealthcheckListenAddress, + ServiceChecker: srv, + } + httpHealthcheck, err := healthcheck.NewServer(hCfg) + if err != nil { + log.Debugf("failed to create healthcheck server: %v", err) + return fmt.Errorf("failed to create healthcheck server: %v", err) + } + wg.Add(1) + go func() { + defer wg.Done() + if err := httpHealthcheck.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { + log.Fatalf("Failed to start healthcheck server: %v", err) + } + }() + // it will block until exit signal waitForExitSignal() @@ -167,6 +193,10 @@ func execute(cmd *cobra.Command, args []string) error { defer cancel() var shutDownErrors error + if err := httpHealthcheck.Shutdown(ctx); err != nil { + shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close healthcheck server: %v", err)) + } + if err := srv.Shutdown(ctx); err != nil { shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close server: %s", err)) } @@ -175,6 +205,8 @@ func execute(cmd *cobra.Command, args []string) error { if err := metricsServer.Shutdown(ctx); err != nil { shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close metrics server: %v", err)) } + + wg.Wait() return shutDownErrors } diff --git a/relay/healthcheck/healthcheck.go b/relay/healthcheck/healthcheck.go new file mode 100644 index 000000000..eedd62394 --- 
/dev/null +++ b/relay/healthcheck/healthcheck.go @@ -0,0 +1,195 @@ +package healthcheck + +import ( + "context" + "encoding/json" + "errors" + "net" + "net/http" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/relay/protocol" + "github.com/netbirdio/netbird/relay/server/listener/quic" + "github.com/netbirdio/netbird/relay/server/listener/ws" +) + +const ( + statusHealthy = "healthy" + statusUnhealthy = "unhealthy" + + path = "/health" + + cacheTTL = 3 * time.Second // Cache TTL for health status +) + +type ServiceChecker interface { + ListenerProtocols() []protocol.Protocol + ListenAddress() string +} + +type HealthStatus struct { + Status string `json:"status"` + Timestamp time.Time `json:"timestamp"` + Listeners []protocol.Protocol `json:"listeners"` + CertificateValid bool `json:"certificate_valid"` +} + +type Config struct { + ListenAddress string + ServiceChecker ServiceChecker +} + +type Server struct { + config Config + httpServer *http.Server + + cacheMu sync.Mutex + cacheStatus *HealthStatus +} + +func NewServer(config Config) (*Server, error) { + mux := http.NewServeMux() + + if config.ServiceChecker == nil { + return nil, errors.New("service checker is required") + } + + server := &Server{ + config: config, + httpServer: &http.Server{ + Addr: config.ListenAddress, + Handler: mux, + ReadTimeout: 5 * time.Second, + WriteTimeout: 10 * time.Second, + IdleTimeout: 15 * time.Second, + }, + } + + mux.HandleFunc(path, server.handleHealthcheck) + return server, nil +} + +func (s *Server) ListenAndServe() error { + log.Infof("starting healthcheck server on: http://%s%s", dialAddress(s.config.ListenAddress), path) + return s.httpServer.ListenAndServe() +} + +// Shutdown gracefully shuts down the healthcheck server +func (s *Server) Shutdown(ctx context.Context) error { + log.Info("Shutting down healthcheck server") + return s.httpServer.Shutdown(ctx) +} + +func (s *Server) handleHealthcheck(w http.ResponseWriter, _ 
*http.Request) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + var ( + status *HealthStatus + ok bool + ) + // Cache check + s.cacheMu.Lock() + status = s.cacheStatus + s.cacheMu.Unlock() + + if status != nil && time.Since(status.Timestamp) <= cacheTTL { + ok = status.Status == statusHealthy + } else { + status, ok = s.getHealthStatus(ctx) + // Update cache + s.cacheMu.Lock() + s.cacheStatus = status + s.cacheMu.Unlock() + } + + w.Header().Set("Content-Type", "application/json") + + if ok { + w.WriteHeader(http.StatusOK) + } else { + w.WriteHeader(http.StatusServiceUnavailable) + } + + encoder := json.NewEncoder(w) + if err := encoder.Encode(status); err != nil { + log.Errorf("Failed to encode healthcheck response: %v", err) + } +} + +func (s *Server) getHealthStatus(ctx context.Context) (*HealthStatus, bool) { + healthy := true + status := &HealthStatus{ + Timestamp: time.Now(), + Status: statusHealthy, + CertificateValid: true, + } + + listeners, ok := s.validateListeners() + if !ok { + status.Status = statusUnhealthy + healthy = false + } + status.Listeners = listeners + + if ok := s.validateCertificate(ctx); !ok { + status.Status = statusUnhealthy + status.CertificateValid = false + healthy = false + } + + return status, healthy +} + +func (s *Server) validateListeners() ([]protocol.Protocol, bool) { + listeners := s.config.ServiceChecker.ListenerProtocols() + if len(listeners) == 0 { + return nil, false + } + return listeners, true +} + +func (s *Server) validateCertificate(ctx context.Context) bool { + listenAddress := s.config.ServiceChecker.ListenAddress() + if listenAddress == "" { + log.Warn("listen address is empty") + return false + } + + dAddr := dialAddress(listenAddress) + + for _, proto := range s.config.ServiceChecker.ListenerProtocols() { + switch proto { + case ws.Proto: + if err := dialWS(ctx, dAddr); err != nil { + log.Errorf("failed to dial WebSocket listener: %v", err) + return false + } + case 
// dialAddress converts a listen address into an address that a client on the
// same machine can dial.
//
// A missing host (":443") or an unspecified IP in any textual form ("0.0.0.0",
// "::", "0:0:0:0:0:0:0:0") only makes sense for binding, so it is replaced
// with the IPv4 loopback address. Dialing the unspecified address itself is
// not portable across platforms and dialers, and it does not produce a usable
// host when the result is printed as a URL. Any other host is kept as is.
//
// If listenAddress cannot be split into host and port it is returned
// unchanged; the result may then not be valid for dialing.
func dialAddress(listenAddress string) string {
	const loopback = "127.0.0.1"

	host, port, err := net.SplitHostPort(listenAddress)
	if err != nil {
		return listenAddress // fallback, might be invalid for dialing
	}

	if host == "" {
		return net.JoinHostPort(loopback, port)
	}

	// ParseIP returns nil for host names, which must be left untouched.
	if ip := net.ParseIP(host); ip != nil && ip.IsUnspecified() {
		host = loopback
	}

	return net.JoinHostPort(host, port)
}
+ + } + if err != nil { + return fmt.Errorf("failed to connect to websocket: %w", err) + } + + _ = conn.Close(websocket.StatusNormalClosure, "availability check complete") + return nil +} diff --git a/relay/protocol/protocol.go b/relay/protocol/protocol.go new file mode 100644 index 000000000..0d43b92e1 --- /dev/null +++ b/relay/protocol/protocol.go @@ -0,0 +1,3 @@ +package protocol + +type Protocol string diff --git a/relay/server/listener/listener.go b/relay/server/listener/listener.go index 535c8bcd9..0a79182f4 100644 --- a/relay/server/listener/listener.go +++ b/relay/server/listener/listener.go @@ -3,9 +3,12 @@ package listener import ( "context" "net" + + "github.com/netbirdio/netbird/relay/protocol" ) type Listener interface { Listen(func(conn net.Conn)) error Shutdown(ctx context.Context) error + Protocol() protocol.Protocol } diff --git a/relay/server/listener/quic/listener.go b/relay/server/listener/quic/listener.go index 2a4a668f0..d3160a44e 100644 --- a/relay/server/listener/quic/listener.go +++ b/relay/server/listener/quic/listener.go @@ -9,8 +9,12 @@ import ( "github.com/quic-go/quic-go" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/relay/protocol" ) +const Proto protocol.Protocol = "quic" + type Listener struct { // Address is the address to listen on Address string @@ -50,6 +54,10 @@ func (l *Listener) Listen(acceptFn func(conn net.Conn)) error { } } +func (l *Listener) Protocol() protocol.Protocol { + return Proto +} + func (l *Listener) Shutdown(ctx context.Context) error { if l.listener == nil { return nil diff --git a/relay/server/listener/ws/listener.go b/relay/server/listener/ws/listener.go index 8579fb137..332127660 100644 --- a/relay/server/listener/ws/listener.go +++ b/relay/server/listener/ws/listener.go @@ -11,11 +11,14 @@ import ( "github.com/coder/websocket" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/relay/protocol" "github.com/netbirdio/netbird/shared/relay" ) -// URLPath is the path for the 
websocket connection. -const URLPath = relay.WebSocketURLPath +const ( + Proto protocol.Protocol = "ws" + URLPath = relay.WebSocketURLPath +) type Listener struct { // Address is the address to listen on. @@ -51,6 +54,10 @@ func (l *Listener) Listen(acceptFn func(conn net.Conn)) error { return err } +func (l *Listener) Protocol() protocol.Protocol { + return Proto +} + func (l *Listener) Shutdown(ctx context.Context) error { if l.server == nil { return nil diff --git a/relay/server/server.go b/relay/server/server.go index 59695e8a9..4c30e7fdc 100644 --- a/relay/server/server.go +++ b/relay/server/server.go @@ -6,12 +6,14 @@ import ( "sync" "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/relay/protocol" "github.com/netbirdio/netbird/relay/server/listener" "github.com/netbirdio/netbird/relay/server/listener/quic" "github.com/netbirdio/netbird/relay/server/listener/ws" quictls "github.com/netbirdio/netbird/shared/relay/tls" - log "github.com/sirupsen/logrus" ) // ListenerConfig is the configuration for the listener. @@ -26,8 +28,11 @@ type ListenerConfig struct { // It is the gate between the WebSocket listener and the Relay server logic. // In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method. type Server struct { - relay *Relay - listeners []listener.Listener + listenAddr string + + relay *Relay + listeners []listener.Listener + listenerMux sync.Mutex } // NewServer creates and returns a new relay server instance. @@ -57,10 +62,14 @@ func NewServer(config Config) (*Server, error) { // Listen starts the relay server. 
func (r *Server) Listen(cfg ListenerConfig) error { + r.listenAddr = cfg.Address + wSListener := &ws.Listener{ Address: cfg.Address, TLSConfig: cfg.TLSConfig, } + + r.listenerMux.Lock() r.listeners = append(r.listeners, wSListener) tlsConfigQUIC, err := quictls.ServerQUICTLSConfig(cfg.TLSConfig) @@ -85,6 +94,8 @@ func (r *Server) Listen(cfg ListenerConfig) error { }(l) } + r.listenerMux.Unlock() + wg.Wait() close(errChan) var multiErr *multierror.Error @@ -100,12 +111,15 @@ func (r *Server) Listen(cfg ListenerConfig) error { func (r *Server) Shutdown(ctx context.Context) error { r.relay.Shutdown(ctx) + r.listenerMux.Lock() var multiErr *multierror.Error for _, l := range r.listeners { if err := l.Shutdown(ctx); err != nil { multiErr = multierror.Append(multiErr, err) } } + r.listeners = r.listeners[:0] + r.listenerMux.Unlock() return nberrors.FormatErrorOrNil(multiErr) } @@ -113,3 +127,18 @@ func (r *Server) Shutdown(ctx context.Context) error { func (r *Server) InstanceURL() string { return r.relay.instanceURL } + +func (r *Server) ListenerProtocols() []protocol.Protocol { + result := make([]protocol.Protocol, 0) + + r.listenerMux.Lock() + for _, l := range r.listeners { + result = append(result, l.Protocol()) + } + r.listenerMux.Unlock() + return result +} + +func (r *Server) ListenAddress() string { + return r.listenAddr +} diff --git a/shared/relay/tls/alpn.go b/shared/relay/tls/alpn.go index 29497d401..484897ad3 100644 --- a/shared/relay/tls/alpn.go +++ b/shared/relay/tls/alpn.go @@ -1,3 +1,3 @@ package tls -const nbalpn = "nb-quic" +const NBalpn = "nb-quic" diff --git a/shared/relay/tls/client_dev.go b/shared/relay/tls/client_dev.go index 52e5535c5..033802ac7 100644 --- a/shared/relay/tls/client_dev.go +++ b/shared/relay/tls/client_dev.go @@ -20,7 +20,7 @@ func ClientQUICTLSConfig() *tls.Config { return &tls.Config{ InsecureSkipVerify: true, // Debug mode allows insecure connections - NextProtos: []string{nbalpn}, // Ensure this matches the server's ALPN + 
NextProtos: []string{NBalpn}, // Ensure this matches the server's ALPN RootCAs: certPool, } } diff --git a/shared/relay/tls/client_prod.go b/shared/relay/tls/client_prod.go index 62e218bc3..d1f1842d2 100644 --- a/shared/relay/tls/client_prod.go +++ b/shared/relay/tls/client_prod.go @@ -19,7 +19,7 @@ func ClientQUICTLSConfig() *tls.Config { } return &tls.Config{ - NextProtos: []string{nbalpn}, + NextProtos: []string{NBalpn}, RootCAs: certPool, } } diff --git a/shared/relay/tls/server_dev.go b/shared/relay/tls/server_dev.go index 1a01658fc..6837cfb9a 100644 --- a/shared/relay/tls/server_dev.go +++ b/shared/relay/tls/server_dev.go @@ -23,7 +23,7 @@ func ServerQUICTLSConfig(originTLSCfg *tls.Config) (*tls.Config, error) { } cfg := originTLSCfg.Clone() - cfg.NextProtos = []string{nbalpn} + cfg.NextProtos = []string{NBalpn} return cfg, nil } @@ -74,6 +74,6 @@ func generateTestTLSConfig() (*tls.Config, error) { return &tls.Config{ Certificates: []tls.Certificate{tlsCert}, - NextProtos: []string{nbalpn}, + NextProtos: []string{NBalpn}, }, nil } diff --git a/shared/relay/tls/server_prod.go b/shared/relay/tls/server_prod.go index 9d1c47d88..b29918fb9 100644 --- a/shared/relay/tls/server_prod.go +++ b/shared/relay/tls/server_prod.go @@ -12,6 +12,6 @@ func ServerQUICTLSConfig(originTLSCfg *tls.Config) (*tls.Config, error) { return nil, fmt.Errorf("valid TLS config is required for QUIC listener") } cfg := originTLSCfg.Clone() - cfg.NextProtos = []string{nbalpn} + cfg.NextProtos = []string{NBalpn} return cfg, nil }