Fix stop not cleaning up WireGuard interface (#286)

This commit is contained in:
Givi Khojanashvili 2022-03-25 16:21:04 +04:00 committed by GitHub
parent a15d52b263
commit 2aaeeac7f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 101 additions and 125 deletions

View File

@ -35,10 +35,6 @@ var (
Short: "", Short: "",
Long: "", Long: "",
} }
// Execution control channel for stopCh signal
stopCh chan int
cleanupCh chan struct{}
) )
// Execute executes the root command. // Execute executes the root command.
@ -47,9 +43,6 @@ func Execute() error {
} }
func init() { func init() {
stopCh = make(chan int)
cleanupCh = make(chan struct{})
defaultConfigPath = "/etc/wiretrustee/config.json" defaultConfigPath = "/etc/wiretrustee/config.json"
defaultLogFile = "/var/log/wiretrustee/client.log" defaultLogFile = "/var/log/wiretrustee/client.log"
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
@ -79,14 +72,18 @@ func init() {
} }
// SetupCloseHandler handles SIGTERM signal and exits with success // SetupCloseHandler handles SIGTERM signal and exits with success
func SetupCloseHandler() { func SetupCloseHandler(ctx context.Context, cancel context.CancelFunc) {
c := make(chan os.Signal, 1) termCh := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) signal.Notify(termCh, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
go func() { go func() {
for range c { done := ctx.Done()
log.Info("shutdown signal received") select {
stopCh <- 0 case <-done:
case <-termCh:
} }
log.Info("shutdown signal received")
cancel()
}() }()
} }

View File

@ -13,18 +13,13 @@ import (
type program struct { type program struct {
ctx context.Context ctx context.Context
cmd *cobra.Command cancel context.CancelFunc
args []string
serv *grpc.Server serv *grpc.Server
} }
func newProgram(cmd *cobra.Command, args []string) *program { func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
ctx := internal.CtxInitState(cmd.Context()) ctx = internal.CtxInitState(ctx)
return &program{ return &program{ctx: ctx, cancel: cancel}
ctx: ctx,
cmd: cmd,
args: args,
}
} }
func newSVCConfig() *service.Config { func newSVCConfig() *service.Config {
@ -48,4 +43,3 @@ var serviceCmd = &cobra.Command{
Use: "service", Use: "service",
Short: "manages wiretrustee service", Short: "manages wiretrustee service",
} }

View File

@ -1,6 +1,8 @@
package cmd package cmd
import ( import (
"context"
"fmt"
"net" "net"
"os" "os"
"strings" "strings"
@ -19,41 +21,40 @@ import (
func (p *program) Start(svc service.Service) error { func (p *program) Start(svc service.Service) error {
// Start should not block. Do the actual work async. // Start should not block. Do the actual work async.
log.Info("starting service") //nolint log.Info("starting service") //nolint
go func() { // in any case, even if configuration does not exists we run daemon to serve CLI gRPC API.
// in any case, even if configuration does not exists we run daemon to serve CLI gRPC API. p.serv = grpc.NewServer()
p.serv = grpc.NewServer()
split := strings.Split(daemonAddr, "://") split := strings.Split(daemonAddr, "://")
switch split[0] { switch split[0] {
case "unix": case "unix":
// cleanup failed close // cleanup failed close
stat, err := os.Stat(split[1]) stat, err := os.Stat(split[1])
if err == nil && !stat.IsDir() { if err == nil && !stat.IsDir() {
if err := os.Remove(split[1]); err != nil { if err := os.Remove(split[1]); err != nil {
log.Debugf("remove socket file: %v", err) log.Debugf("remove socket file: %v", err)
}
} }
case "tcp":
default:
log.Errorf("unsupported daemon address protocol: %v", split[0])
return
} }
case "tcp":
default:
return fmt.Errorf("unsupported daemon address protocol: %v", split[0])
}
listen, err := net.Listen(split[0], split[1]) listen, err := net.Listen(split[0], split[1])
if err != nil { if err != nil {
log.Fatalf("failed to listen daemon interface: %v", err) return fmt.Errorf("failed to listen daemon interface: %w", err)
} }
go func() {
defer listen.Close() defer listen.Close()
if split[0] == "unix" { if split[0] == "unix" {
err = os.Chmod(split[1], 0666) err = os.Chmod(split[1], 0o666)
if err != nil { if err != nil {
log.Errorf("failed setting daemon permissions: %v", split[1]) log.Errorf("failed setting daemon permissions: %v", split[1])
return return
} }
} }
serverInstance := server.New(p.ctx, managementURL, configPath, stopCh, cleanupCh) serverInstance := server.New(p.ctx, managementURL, configPath)
if err := serverInstance.Start(); err != nil { if err := serverInstance.Start(); err != nil {
log.Fatalf("failed start daemon: %v", err) log.Fatalf("failed start daemon: %v", err)
} }
@ -67,21 +68,14 @@ func (p *program) Start(svc service.Service) error {
return nil return nil
} }
func (p *program) Stop(service.Service) error { func (p *program) Stop(srv service.Service) error {
go func() { p.cancel()
stopCh <- 1
}()
// stop CLI daemon service
if p.serv != nil { if p.serv != nil {
p.serv.GracefulStop() p.serv.Stop()
} }
select { time.Sleep(time.Second * 2)
case <-cleanupCh:
case <-time.After(time.Second * 10):
log.Warnf("failed waiting for service cleanup, terminating")
}
log.Info("stopped Wiretrustee service") //nolint log.Info("stopped Wiretrustee service") //nolint
return nil return nil
} }
@ -98,9 +92,10 @@ var runCmd = &cobra.Command{
return return
} }
SetupCloseHandler() ctx, cancel := context.WithCancel(cmd.Context())
SetupCloseHandler(ctx, cancel)
s, err := newSVC(newProgram(cmd, args), newSVCConfig()) s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
if err != nil { if err != nil {
cmd.PrintErrln(err) cmd.PrintErrln(err)
return return
@ -125,7 +120,10 @@ var startCmd = &cobra.Command{
log.Errorf("failed initializing log %v", err) log.Errorf("failed initializing log %v", err)
return err return err
} }
s, err := newSVC(newProgram(cmd, args), newSVCConfig())
ctx, cancel := context.WithCancel(cmd.Context())
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
if err != nil { if err != nil {
cmd.PrintErrln(err) cmd.PrintErrln(err)
return err return err
@ -150,7 +148,10 @@ var stopCmd = &cobra.Command{
if err != nil { if err != nil {
log.Errorf("failed initializing log %v", err) log.Errorf("failed initializing log %v", err)
} }
s, err := newSVC(newProgram(cmd, args), newSVCConfig())
ctx, cancel := context.WithCancel(cmd.Context())
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
if err != nil { if err != nil {
cmd.PrintErrln(err) cmd.PrintErrln(err)
return return
@ -174,7 +175,10 @@ var restartCmd = &cobra.Command{
if err != nil { if err != nil {
log.Errorf("failed initializing log %v", err) log.Errorf("failed initializing log %v", err)
} }
s, err := newSVC(newProgram(cmd, args), newSVCConfig())
ctx, cancel := context.WithCancel(cmd.Context())
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
if err != nil { if err != nil {
cmd.PrintErrln(err) cmd.PrintErrln(err)
return return

View File

@ -1,6 +1,7 @@
package cmd package cmd
import ( import (
"context"
"runtime" "runtime"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -28,7 +29,9 @@ var installCmd = &cobra.Command{
svcConfig.Dependencies = []string{"After=network.target syslog.target"} svcConfig.Dependencies = []string{"After=network.target syslog.target"}
} }
s, err := newSVC(newProgram(cmd, args), svcConfig) ctx, cancel := context.WithCancel(cmd.Context())
s, err := newSVC(newProgram(ctx, cancel), svcConfig)
if err != nil { if err != nil {
cmd.PrintErrln(err) cmd.PrintErrln(err)
return err return err
@ -50,7 +53,9 @@ var uninstallCmd = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
SetFlagsFromEnvVars() SetFlagsFromEnvVars()
s, err := newSVC(newProgram(cmd, args), newSVCConfig()) ctx, cancel := context.WithCancel(cmd.Context())
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
if err != nil { if err != nil {
cmd.PrintErrln(err) cmd.PrintErrln(err)
return return
@ -64,4 +69,3 @@ var uninstallCmd = &cobra.Command{
cmd.Println("Wiretrustee has been uninstalled") cmd.Println("Wiretrustee has been uninstalled")
}, },
} }

View File

@ -2,12 +2,13 @@ package cmd
import ( import (
"context" "context"
"github.com/wiretrustee/wiretrustee/util"
"net" "net"
"path/filepath" "path/filepath"
"testing" "testing"
"time" "time"
"github.com/wiretrustee/wiretrustee/util"
clientProto "github.com/wiretrustee/wiretrustee/client/proto" clientProto "github.com/wiretrustee/wiretrustee/client/proto"
client "github.com/wiretrustee/wiretrustee/client/server" client "github.com/wiretrustee/wiretrustee/client/server"
mgmtProto "github.com/wiretrustee/wiretrustee/management/proto" mgmtProto "github.com/wiretrustee/wiretrustee/management/proto"
@ -85,7 +86,6 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
func startClientDaemon( func startClientDaemon(
t *testing.T, ctx context.Context, managementURL, configPath string, t *testing.T, ctx context.Context, managementURL, configPath string,
stopCh chan int, cleanupCh chan<- struct{},
) (*grpc.Server, net.Listener) { ) (*grpc.Server, net.Listener) {
lis, err := net.Listen("tcp", "127.0.0.1:0") lis, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
@ -93,13 +93,7 @@ func startClientDaemon(
} }
s := grpc.NewServer() s := grpc.NewServer()
server := client.New( server := client.New(ctx, managementURL, configPath)
ctx,
managementURL,
configPath,
stopCh,
cleanupCh,
)
if err := server.Start(); err != nil { if err := server.Start(); err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -1,6 +1,8 @@
package cmd package cmd
import ( import (
"context"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/wiretrustee/wiretrustee/util" "github.com/wiretrustee/wiretrustee/util"
@ -38,8 +40,10 @@ var upCmd = &cobra.Command{
return err return err
} }
SetupCloseHandler() var cancel context.CancelFunc
return internal.RunClient(ctx, config, stopCh, cleanupCh) ctx, cancel = context.WithCancel(ctx)
SetupCloseHandler(ctx, cancel)
return internal.RunClient(ctx, config)
} }
conn, err := DialClientGRPCServer(ctx, daemonAddr) conn, err := DialClientGRPCServer(ctx, daemonAddr)

View File

@ -8,9 +8,7 @@ import (
"github.com/wiretrustee/wiretrustee/client/internal" "github.com/wiretrustee/wiretrustee/client/internal"
) )
var ( var cliAddr string
cliAddr string
)
func TestUpDaemon(t *testing.T) { func TestUpDaemon(t *testing.T) {
mgmAddr := startTestingServices(t) mgmAddr := startTestingServices(t)
@ -18,13 +16,10 @@ func TestUpDaemon(t *testing.T) {
tempDir := t.TempDir() tempDir := t.TempDir()
confPath := tempDir + "/config.json" confPath := tempDir + "/config.json"
stopCh = make(chan int, 1)
cleanupCh = make(chan struct{}, 1)
ctx := internal.CtxInitState(context.Background()) ctx := internal.CtxInitState(context.Background())
state := internal.CtxGetState(ctx) state := internal.CtxGetState(ctx)
_, cliLis := startClientDaemon(t, ctx, "http://"+mgmAddr, confPath, stopCh, cleanupCh) _, cliLis := startClientDaemon(t, ctx, "http://"+mgmAddr, confPath)
cliAddr = cliLis.Addr().String() cliAddr = cliLis.Addr().String()

View File

@ -17,9 +17,7 @@ import (
) )
// RunClient with main logic. // RunClient with main logic.
func RunClient( func RunClient(ctx context.Context, config *Config) error {
ctx context.Context, config *Config, stopCh <-chan int, cleanupCh chan<- struct{},
) error {
backOff := &backoff.ExponentialBackOff{ backOff := &backoff.ExponentialBackOff{
InitialInterval: time.Second, InitialInterval: time.Second,
RandomizationFactor: backoff.DefaultRandomizationFactor, RandomizationFactor: backoff.DefaultRandomizationFactor,
@ -90,10 +88,7 @@ func RunClient(
log.Print("Wiretrustee engine started, my IP is: ", peerConfig.Address) log.Print("Wiretrustee engine started, my IP is: ", peerConfig.Address)
state.Set(StatusConnected) state.Set(StatusConnected)
select { <-ctx.Done()
case <-stopCh:
case <-ctx.Done():
}
backOff.Reset() backOff.Reset()
@ -114,10 +109,6 @@ func RunClient(
return wrapErr(err) return wrapErr(err)
} }
go func() {
cleanupCh <- struct{}{}
}()
log.Info("stopped Wiretrustee client") log.Info("stopped Wiretrustee client")
if _, err := state.Status(); err == ErrResetConnection { if _, err := state.Status(); err == ErrResetConnection {
@ -207,4 +198,3 @@ func connectToManagement(ctx context.Context, managementAddr string, ourPrivateK
return client, loginResp, nil return client, loginResp, nil
} }

View File

@ -18,8 +18,6 @@ type Server struct {
managementURL string managementURL string
configPath string configPath string
stopCh chan int
cleanupCh chan<- struct{}
mutex sync.Mutex mutex sync.Mutex
config *internal.Config config *internal.Config
@ -27,16 +25,11 @@ type Server struct {
} }
// New server instance constructor. // New server instance constructor.
func New( func New(ctx context.Context, managementURL, configPath string) *Server {
ctx context.Context, managementURL, configPath string,
stopCh chan int, cleanupCh chan<- struct{},
) *Server {
return &Server{ return &Server{
rootCtx: ctx, rootCtx: ctx,
managementURL: managementURL, managementURL: managementURL,
configPath: configPath, configPath: configPath,
stopCh: stopCh,
cleanupCh: cleanupCh,
} }
} }
@ -67,7 +60,7 @@ func (s *Server) Start() error {
s.config = config s.config = config
go func() { go func() {
if err := internal.RunClient(ctx, config, s.stopCh, s.cleanupCh); err != nil { if err := internal.RunClient(ctx, config); err != nil {
log.Errorf("init connections: %v", err) log.Errorf("init connections: %v", err)
} }
}() }()
@ -131,7 +124,7 @@ func (s *Server) Up(_ context.Context, msg *proto.UpRequest) (*proto.UpResponse,
} }
go func() { go func() {
if err := internal.RunClient(ctx, s.config, s.stopCh, s.cleanupCh); err != nil { if err := internal.RunClient(ctx, s.config); err != nil {
log.Errorf("run client connection: %v", state.Wrap(err)) log.Errorf("run client connection: %v", state.Wrap(err))
return return
} }

View File

@ -151,28 +151,31 @@ func (s *serviceClient) updateStatus() {
func (s *serviceClient) onTrayReady() { func (s *serviceClient) onTrayReady() {
systray.SetTemplateIcon(iconDisconnected, iconDisconnected) systray.SetTemplateIcon(iconDisconnected, iconDisconnected)
s.mStatus = systray.AddMenuItem("Disconnected", "Disconnected")
s.mStatus.Disable()
systray.AddSeparator()
s.mUp = systray.AddMenuItem("Up", "Up")
s.mDown = systray.AddMenuItem("Down", "Down")
s.mDown.Disable()
mURL := systray.AddMenuItem("Open UI", "wiretrustee website")
systray.AddSeparator()
mQuit := systray.AddMenuItem("Quit", "Quit the whole app")
go func() { go func() {
s.mStatus = systray.AddMenuItem("Disconnected", "Disconnected") for {
s.mStatus.Disable()
systray.AddSeparator()
s.mUp = systray.AddMenuItem("Up", "Up")
s.mDown = systray.AddMenuItem("Down", "Down")
s.mDown.Disable()
mURL := systray.AddMenuItem("Open UI", "wiretrustee website")
systray.AddSeparator()
mQuit := systray.AddMenuItem("Quit", "Quit the whole app")
s.updateStatus() s.updateStatus()
time.Sleep(time.Second * 3)
}
}()
ticker := time.NewTicker(time.Second * 3) go func() {
defer ticker.Stop()
var err error var err error
for { for {
select { select {
@ -191,8 +194,6 @@ func (s *serviceClient) onTrayReady() {
case <-mQuit.ClickedCh: case <-mQuit.ClickedCh:
systray.Quit() systray.Quit()
return return
case <-ticker.C:
s.updateStatus()
} }
if err != nil { if err != nil {
log.Errorf("process connection: %v", err) log.Errorf("process connection: %v", err)