mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-13 10:21:10 +01:00
327 lines
8.6 KiB
Go
327 lines
8.6 KiB
Go
package cmd
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"io/fs"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"path"
|
|
"strings"
|
|
"time"
|
|
|
|
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
|
"golang.org/x/crypto/acme/autocert"
|
|
|
|
"github.com/netbirdio/netbird/signal/metrics"
|
|
|
|
"github.com/netbirdio/netbird/encryption"
|
|
"github.com/netbirdio/netbird/signal/proto"
|
|
"github.com/netbirdio/netbird/signal/server"
|
|
"github.com/netbirdio/netbird/util"
|
|
"github.com/netbirdio/netbird/version"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
"github.com/spf13/cobra"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/credentials"
|
|
"google.golang.org/grpc/keepalive"
|
|
)
|
|
|
|
// metricsPort is the TCP port handed to metrics.NewServer when RunE starts
// the metrics HTTP server.
const metricsPort = 9090
|
|
|
|
var (
|
|
signalPort int
|
|
signalLetsencryptDomain string
|
|
signalSSLDir string
|
|
defaultSignalSSLDir string
|
|
tlsEnabled bool
|
|
|
|
signalKaep = grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{
|
|
MinTime: 5 * time.Second,
|
|
PermitWithoutStream: true,
|
|
})
|
|
|
|
signalKasp = grpc.KeepaliveParams(keepalive.ServerParameters{
|
|
MaxConnectionIdle: 15 * time.Second,
|
|
MaxConnectionAgeGrace: 5 * time.Second,
|
|
Time: 5 * time.Second,
|
|
Timeout: 2 * time.Second,
|
|
})
|
|
|
|
runCmd = &cobra.Command{
|
|
Use: "run",
|
|
Short: "start NetBird Signal Server daemon",
|
|
PreRun: func(cmd *cobra.Command, args []string) {
|
|
// detect whether user specified a port
|
|
userPort := cmd.Flag("port").Changed
|
|
if signalLetsencryptDomain != "" {
|
|
tlsEnabled = true
|
|
}
|
|
|
|
if !userPort {
|
|
// different defaults for signalPort
|
|
if tlsEnabled {
|
|
signalPort = 443
|
|
} else {
|
|
signalPort = 80
|
|
}
|
|
}
|
|
},
|
|
RunE: func(cmd *cobra.Command, args []string) error {
|
|
flag.Parse()
|
|
|
|
err := util.InitLog(logLevel, logFile)
|
|
if err != nil {
|
|
log.Fatalf("failed initializing log %v", err)
|
|
}
|
|
|
|
if signalSSLDir == "" {
|
|
oldPath := "/var/lib/wiretrustee"
|
|
if migrateToNetbird(oldPath, defaultSignalSSLDir) {
|
|
if err := cpDir(oldPath, defaultSignalSSLDir); err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
var opts []grpc.ServerOption
|
|
var certManager *autocert.Manager
|
|
if tlsEnabled {
|
|
// Let's encrypt enabled -> generate certificate automatically
|
|
certManager, err = encryption.CreateCertManager(signalSSLDir, signalLetsencryptDomain)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
transportCredentials := credentials.NewTLS(certManager.TLSConfig())
|
|
opts = append(opts, grpc.Creds(transportCredentials))
|
|
}
|
|
|
|
metricsServer := metrics.NewServer(metricsPort, "")
|
|
if err != nil {
|
|
return fmt.Errorf("setup metrics: %v", err)
|
|
}
|
|
|
|
opts = append(opts, signalKaep, signalKasp, grpc.StatsHandler(otelgrpc.NewServerHandler()))
|
|
grpcServer := grpc.NewServer(opts...)
|
|
|
|
go func() {
|
|
log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
|
|
if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
|
|
log.Fatalf("Failed to start metrics server: %v", err)
|
|
}
|
|
}()
|
|
|
|
srv, err := server.NewServer(metricsServer.Meter)
|
|
if err != nil {
|
|
return fmt.Errorf("creating signal server: %v", err)
|
|
}
|
|
proto.RegisterSignalExchangeServer(grpcServer, srv)
|
|
|
|
var compatListener net.Listener
|
|
if signalPort != 10000 {
|
|
// The Signal gRPC server was running on port 10000 previously. Old agents that are already connected to Signal
|
|
// are using port 10000. For compatibility purposes we keep running a 2nd gRPC server on port 10000.
|
|
compatListener, err = serveGRPC(grpcServer, 10000)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
log.Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
|
|
}
|
|
|
|
var grpcListener net.Listener
|
|
var httpListener net.Listener
|
|
if tlsEnabled {
|
|
httpListener = certManager.Listener()
|
|
if signalPort == 443 {
|
|
// running gRPC and HTTP cert manager on the same port
|
|
serveHTTP(httpListener, certManager.HTTPHandler(grpcHandlerFunc(grpcServer)))
|
|
log.Infof("running HTTP server (LetsEncrypt challenge handler) and gRPC server on the same port: %s", httpListener.Addr().String())
|
|
} else {
|
|
serveHTTP(httpListener, certManager.HTTPHandler(nil))
|
|
log.Infof("running HTTP server (LetsEncrypt challenge handler): %s", httpListener.Addr().String())
|
|
}
|
|
}
|
|
|
|
if signalPort != 443 || !tlsEnabled {
|
|
grpcListener, err = serveGRPC(grpcServer, signalPort)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
log.Infof("running gRPC server: %s", grpcListener.Addr().String())
|
|
}
|
|
|
|
log.Infof("signal server version %s", version.NetbirdVersion())
|
|
log.Infof("started Signal Service")
|
|
|
|
SetupCloseHandler()
|
|
|
|
<-stopCh
|
|
if grpcListener != nil {
|
|
_ = grpcListener.Close()
|
|
log.Infof("stopped gRPC server")
|
|
}
|
|
if httpListener != nil {
|
|
_ = httpListener.Close()
|
|
log.Infof("stopped HTTP server")
|
|
}
|
|
if compatListener != nil {
|
|
_ = compatListener.Close()
|
|
log.Infof("stopped gRPC backward compatibility server")
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(cmd.Context(), 5*time.Second)
|
|
defer cancel()
|
|
if err := metricsServer.Shutdown(ctx); err != nil {
|
|
log.Errorf("Failed to stop metrics server: %v", err)
|
|
}
|
|
log.Infof("stopped metrics server")
|
|
|
|
log.Infof("stopped Signal Service")
|
|
|
|
return nil
|
|
},
|
|
}
|
|
)
|
|
|
|
func grpcHandlerFunc(grpcServer *grpc.Server) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
grpcHeader := strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc") ||
|
|
strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc+proto")
|
|
if r.ProtoMajor == 2 && grpcHeader {
|
|
grpcServer.ServeHTTP(w, r)
|
|
}
|
|
})
|
|
}
|
|
|
|
func notifyStop(msg string) {
|
|
select {
|
|
case stopCh <- 1:
|
|
log.Error(msg)
|
|
default:
|
|
// stop has been already called, nothing to report
|
|
}
|
|
}
|
|
|
|
func serveHTTP(httpListener net.Listener, handler http.Handler) {
|
|
go func() {
|
|
err := http.Serve(httpListener, handler)
|
|
if err != nil {
|
|
notifyStop(fmt.Sprintf("failed running HTTP server %v", err))
|
|
}
|
|
}()
|
|
}
|
|
|
|
func serveGRPC(grpcServer *grpc.Server, port int) (net.Listener, error) {
|
|
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
go func() {
|
|
err := grpcServer.Serve(listener)
|
|
if err != nil {
|
|
notifyStop(fmt.Sprintf("failed running gRPC server on port %d: %v", port, err))
|
|
}
|
|
}()
|
|
return listener, nil
|
|
}
|
|
|
|
// cpFile copies the contents of the regular file src to dst, creating or
// truncating dst, and then applies src's mode bits to dst.
//
// It returns the first error encountered, including an error from closing
// dst. Close can report a deferred write failure (e.g. on network
// filesystems), which would otherwise leave a silently incomplete copy.
func cpFile(src, dst string) (err error) {
	srcfd, err := os.Open(src)
	if err != nil {
		return err
	}
	// src is only read, so its Close error carries no data-loss risk.
	defer srcfd.Close()

	srcinfo, err := srcfd.Stat()
	if err != nil {
		return err
	}

	dstfd, err := os.Create(dst)
	if err != nil {
		return err
	}
	defer func() {
		if cerr := dstfd.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	if _, err = io.Copy(dstfd, srcfd); err != nil {
		return err
	}
	return os.Chmod(dst, srcinfo.Mode())
}
|
|
|
|
// copySymLink creates dest as a symbolic link with the same target string
// as the existing link at source. The target is copied verbatim, not
// resolved. It fails if source is not a symlink or dest already exists.
func copySymLink(source, dest string) error {
	target, err := os.Readlink(source)
	if err == nil {
		err = os.Symlink(target, dest)
	}
	return err
}
|
|
|
|
func cpDir(src string, dst string) error {
|
|
var err error
|
|
var fds []os.DirEntry
|
|
var srcinfo os.FileInfo
|
|
|
|
if srcinfo, err = os.Stat(src); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err = os.MkdirAll(dst, srcinfo.Mode()); err != nil {
|
|
return err
|
|
}
|
|
|
|
if fds, err = os.ReadDir(src); err != nil {
|
|
return err
|
|
}
|
|
for _, fd := range fds {
|
|
srcfp := path.Join(src, fd.Name())
|
|
dstfp := path.Join(dst, fd.Name())
|
|
|
|
fileInfo, err := os.Stat(srcfp)
|
|
if err != nil {
|
|
log.Fatalf("Couldn't get fileInfo; %v", err)
|
|
}
|
|
|
|
switch fileInfo.Mode() & os.ModeType {
|
|
case os.ModeSymlink:
|
|
if err = copySymLink(srcfp, dstfp); err != nil {
|
|
log.Fatalf("Failed to copy from %s to %s; %v", srcfp, dstfp, err)
|
|
}
|
|
case os.ModeDir:
|
|
if err = cpDir(srcfp, dstfp); err != nil {
|
|
log.Fatalf("Failed to copy from %s to %s; %v", srcfp, dstfp, err)
|
|
}
|
|
default:
|
|
if err = cpFile(srcfp, dstfp); err != nil {
|
|
log.Fatalf("Failed to copy from %s to %s; %v", srcfp, dstfp, err)
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// migrateToNetbird reports whether the legacy directory at oldPath should
// be copied to newPath. It returns false when oldPath does not exist. It
// also returns false when newPath can be stat'ed (already present).
// Otherwise it returns true. This includes the case where oldPath's stat
// fails for a reason other than non-existence.
func migrateToNetbird(oldPath, newPath string) bool {
	if _, oldErr := os.Stat(oldPath); errors.Is(oldErr, fs.ErrNotExist) {
		return false
	}
	_, newErr := os.Stat(newPath)
	return newErr != nil
}
|
|
|
|
func init() {
|
|
runCmd.PersistentFlags().IntVar(&signalPort, "port", 80, "Server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
|
|
runCmd.Flags().StringVar(&signalSSLDir, "ssl-dir", defaultSignalSSLDir, "server ssl directory location. *Required only for Let's Encrypt certificates.")
|
|
runCmd.Flags().StringVar(&signalLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
|
|
}
|