package cmd

import (
	"context"
	"errors"
	"flag"
	"fmt"
	"io"
	"io/fs"
	"net"
	"net/http"
	"os"
	"path"
	"strings"
	"time"

	log "github.com/sirupsen/logrus"
	"github.com/spf13/cobra"
	"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
	"golang.org/x/crypto/acme/autocert"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/keepalive"

	"github.com/netbirdio/netbird/encryption"
	"github.com/netbirdio/netbird/signal/metrics"
	"github.com/netbirdio/netbird/signal/proto"
	"github.com/netbirdio/netbird/signal/server"
	"github.com/netbirdio/netbird/util"
	"github.com/netbirdio/netbird/version"
)

const (
	// metricsPort is the port passed to metrics.NewServer.
	metricsPort = 9090
)

var (
	// signalPort is bound to --port; PreRun overrides it when the flag was not set.
	signalPort int
	// signalLetsencryptDomain is bound to --letsencrypt-domain; a non-empty value enables TLS.
	signalLetsencryptDomain string
	// signalSSLDir is bound to --ssl-dir.
	signalSSLDir string
	// defaultSignalSSLDir is the --ssl-dir default and the migration target in RunE.
	// NOTE(review): it is not assigned in this file — presumably set elsewhere in the package.
	defaultSignalSSLDir string
	// tlsEnabled is set by PreRun when a Let's Encrypt domain is configured.
	tlsEnabled bool

	// signalKaep is the keepalive enforcement policy passed to the gRPC server.
	signalKaep = grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{
		MinTime:             5 * time.Second,
		PermitWithoutStream: true,
	})

	// signalKasp holds the server-side keepalive parameters passed to the gRPC server.
	signalKasp = grpc.KeepaliveParams(keepalive.ServerParameters{
		MaxConnectionIdle:     15 * time.Second,
		MaxConnectionAgeGrace: 5 * time.Second,
		Time:                  5 * time.Second,
		Timeout:               2 * time.Second,
	})

	// runCmd starts the Signal gRPC server, the metrics server and, when TLS is
	// enabled, the Let's Encrypt HTTP challenge handler. It then blocks until a
	// value is received on stopCh and closes the listeners it opened.
	runCmd = &cobra.Command{
		Use:   "run",
		Short: "start NetBird Signal Server daemon",
		PreRun: func(cmd *cobra.Command, args []string) {
			// detect whether user specified a port
			userPort := cmd.Flag("port").Changed

			if signalLetsencryptDomain != "" {
				tlsEnabled = true
			}

			if !userPort {
				// different defaults for signalPort
				if tlsEnabled {
					signalPort = 443
				} else {
					signalPort = 80
				}
			}
		},
		RunE: func(cmd *cobra.Command, args []string) error {
			flag.Parse()

			err := util.InitLog(logLevel, logFile)
			if err != nil {
				log.Fatalf("failed initializing log %v", err)
			}

			// Only migrate the legacy data directory when no explicit ssl-dir is set.
			if signalSSLDir == "" {
				oldPath := "/var/lib/wiretrustee"
				if migrateToNetbird(oldPath, defaultSignalSSLDir) {
					if err := cpDir(oldPath, defaultSignalSSLDir); err != nil {
						log.Fatal(err)
					}
				}
			}

			var opts []grpc.ServerOption
			var certManager *autocert.Manager
			if tlsEnabled {
				// Let's encrypt enabled -> generate certificate automatically
				certManager, err = encryption.CreateCertManager(signalSSLDir, signalLetsencryptDomain)
				if err != nil {
					return err
				}
				transportCredentials := credentials.NewTLS(certManager.TLSConfig())
				opts = append(opts, grpc.Creds(transportCredentials))
			}

			// metrics.NewServer returns a single value here, so there is no
			// error to check at this point.
			metricsServer := metrics.NewServer(metricsPort, "")

			opts = append(opts, signalKaep, signalKasp, grpc.StatsHandler(otelgrpc.NewServerHandler()))
			grpcServer := grpc.NewServer(opts...)

			go func() {
				log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
				if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
					log.Fatalf("Failed to start metrics server: %v", err)
				}
			}()

			srv, err := server.NewServer(metricsServer.Meter)
			if err != nil {
				return fmt.Errorf("creating signal server: %w", err)
			}
			proto.RegisterSignalExchangeServer(grpcServer, srv)

			var compatListener net.Listener
			if signalPort != 10000 {
				// The Signal gRPC server was running on port 10000 previously. Old agents that are already connected to Signal
				// are using port 10000. For compatibility purposes we keep running a 2nd gRPC server on port 10000.
				compatListener, err = serveGRPC(grpcServer, 10000)
				if err != nil {
					return err
				}
				log.Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
			}

			var grpcListener net.Listener
			var httpListener net.Listener
			if tlsEnabled {
				httpListener = certManager.Listener()
				if signalPort == 443 {
					// running gRPC and HTTP cert manager on the same port
					serveHTTP(httpListener, certManager.HTTPHandler(grpcHandlerFunc(grpcServer)))
					log.Infof("running HTTP server (LetsEncrypt challenge handler) and gRPC server on the same port: %s", httpListener.Addr().String())
				} else {
					serveHTTP(httpListener, certManager.HTTPHandler(nil))
					log.Infof("running HTTP server (LetsEncrypt challenge handler): %s", httpListener.Addr().String())
				}
			}

			// A dedicated gRPC listener is needed unless gRPC is already
			// multiplexed onto the TLS listener on port 443 above.
			if signalPort != 443 || !tlsEnabled {
				grpcListener, err = serveGRPC(grpcServer, signalPort)
				if err != nil {
					return err
				}
				log.Infof("running gRPC server: %s", grpcListener.Addr().String())
			}

			log.Infof("signal server version %s", version.NetbirdVersion())
			log.Infof("started Signal Service")

			SetupCloseHandler()

			// Block until a stop is signalled (see notifyStop).
			<-stopCh

			// Close errors are ignored on purpose: the process is shutting down.
			if grpcListener != nil {
				_ = grpcListener.Close()
				log.Infof("stopped gRPC server")
			}
			if httpListener != nil {
				_ = httpListener.Close()
				log.Infof("stopped HTTP server")
			}
			if compatListener != nil {
				_ = compatListener.Close()
				log.Infof("stopped gRPC backward compatibility server")
			}

			ctx, cancel := context.WithTimeout(cmd.Context(), 5*time.Second)
			defer cancel()
			if err := metricsServer.Shutdown(ctx); err != nil {
				log.Errorf("Failed to stop metrics server: %v", err)
			}
			log.Infof("stopped metrics server")

			log.Infof("stopped Signal Service")
			return nil
		},
	}
)

// grpcHandlerFunc returns an http.Handler that forwards HTTP/2 requests whose
// Content-Type starts with "application/grpc" (this also covers
// "application/grpc+proto") to grpcServer. Any other request is not handled:
// nothing is written to the response.
func grpcHandlerFunc(grpcServer *grpc.Server) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		isGRPC := strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc")
		if r.ProtoMajor == 2 && isGRPC {
			grpcServer.ServeHTTP(w, r)
		}
	})
}

// notifyStop tries a non-blocking send on stopCh and logs msg as an error if
// the send succeeds. If the send would block, it does nothing.
func notifyStop(msg string) {
	select {
	case stopCh <- 1:
		log.Error(msg)
	default:
		// stop has been already called, nothing to report
	}
}

// serveHTTP serves handler on httpListener in a new goroutine. When http.Serve
// returns an error, the failure is reported through notifyStop.
func serveHTTP(httpListener net.Listener, handler http.Handler) {
	go func() {
		if err := http.Serve(httpListener, handler); err != nil {
			notifyStop(fmt.Sprintf("failed running HTTP server %v", err))
		}
	}()
}

// serveGRPC opens a TCP listener on the given port (all interfaces) and serves
// grpcServer on it in a new goroutine. It returns the listener so the caller
// can close it, or the error from net.Listen. Errors from Serve are reported
// through notifyStop.
func serveGRPC(grpcServer *grpc.Server, port int) (net.Listener, error) {
	listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
	if err != nil {
		return nil, err
	}

	go func() {
		if err := grpcServer.Serve(listener); err != nil {
			notifyStop(fmt.Sprintf("failed running gRPC server on port %d: %v", port, err))
		}
	}()

	return listener, nil
}

// cpFile copies the contents of src to dst, creating or truncating dst, and
// then sets dst's mode to that of src. It returns the first error encountered,
// including an error from closing dst.
func cpFile(src, dst string) (err error) {
	srcfd, err := os.Open(src)
	if err != nil {
		return err
	}
	defer srcfd.Close()

	// Stat the open handle rather than the path so the mode belongs to the
	// file that is actually being copied.
	srcinfo, err := srcfd.Stat()
	if err != nil {
		return err
	}

	dstfd, err := os.Create(dst)
	if err != nil {
		return err
	}
	defer func() {
		// A failed close can mean the data never reached disk; surface it
		// unless an earlier error is already being returned.
		if cerr := dstfd.Close(); err == nil {
			err = cerr
		}
	}()

	if _, err = io.Copy(dstfd, srcfd); err != nil {
		return err
	}
	return os.Chmod(dst, srcinfo.Mode())
}

// copySymLink creates a symlink at dest that points to the same target as the
// symlink at source.
func copySymLink(source, dest string) error {
	link, err := os.Readlink(source)
	if err != nil {
		return err
	}
	return os.Symlink(link, dest)
}

// cpDir recursively copies the directory src to dst. dst is created with the
// mode of src. Symlinks are recreated as symlinks, directories are copied
// recursively and every other entry is copied with cpFile. It returns the
// first error encountered, wrapped with the paths involved.
func cpDir(src string, dst string) error {
	srcinfo, err := os.Stat(src)
	if err != nil {
		return err
	}

	if err = os.MkdirAll(dst, srcinfo.Mode()); err != nil {
		return err
	}

	entries, err := os.ReadDir(src)
	if err != nil {
		return err
	}

	for _, entry := range entries {
		srcfp := path.Join(src, entry.Name())
		dstfp := path.Join(dst, entry.Name())

		// Use the directory entry's own type bits: os.Stat would follow
		// symlinks and never report os.ModeSymlink.
		switch entry.Type() {
		case os.ModeSymlink:
			err = copySymLink(srcfp, dstfp)
		case os.ModeDir:
			err = cpDir(srcfp, dstfp)
		default:
			err = cpFile(srcfp, dstfp)
		}
		if err != nil {
			return fmt.Errorf("copy %s to %s: %w", srcfp, dstfp, err)
		}
	}
	return nil
}

// migrateToNetbird reports whether data should be copied from oldPath to
// newPath. It returns false when stat of oldPath reports that it does not
// exist, or when stat of newPath succeeds; otherwise it returns true.
func migrateToNetbird(oldPath, newPath string) bool {
	_, errOld := os.Stat(oldPath)
	_, errNew := os.Stat(newPath)

	return !errors.Is(errOld, fs.ErrNotExist) && errNew != nil
}

// init registers the command-line flags of runCmd.
func init() {
	runCmd.PersistentFlags().IntVar(&signalPort, "port", 80, "Server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise)")
	runCmd.Flags().StringVar(&signalSSLDir, "ssl-dir", defaultSignalSSLDir, "server ssl directory location. *Required only for Let's Encrypt certificates.")
	runCmd.Flags().StringVar(&signalLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
}