diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 67f3ef07b..58004dd4a 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -15,7 +15,6 @@ import ( "strconv" "strings" "sync" - "syscall" "time" "unicode" @@ -34,6 +33,7 @@ import ( "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/system" + "github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/version" ) @@ -62,8 +62,25 @@ func main() { var errorMSG string flag.StringVar(&errorMSG, "error-msg", "", "displays a error message window") + tmpDir := "/tmp" + if runtime.GOOS == "windows" { + tmpDir = os.TempDir() + } + + var saveLogsInFile bool + flag.BoolVar(&saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", tmpDir)) + flag.Parse() + if saveLogsInFile { + logFile := path.Join(tmpDir, fmt.Sprintf("netbird-ui-%d.log", os.Getpid())) + err := util.InitLog("trace", logFile) + if err != nil { + log.Errorf("error while initializing log: %v", err) + return + } + } + a := app.NewWithID("NetBird") a.SetIcon(fyne.NewStaticResource("netbird", iconDisconnectedPNG)) @@ -76,8 +93,12 @@ func main() { if showSettings || showRoutes { a.Run() } else { - if err := checkPIDFile(); err != nil { - log.Errorf("check PID file: %v", err) + running, err := isAnotherProcessRunning() + if err != nil { + log.Errorf("error while checking process: %v", err) + } + if running { + log.Warn("another process is running") return } client.setDefaultFonts() @@ -861,19 +882,3 @@ func openURL(url string) error { } return err } - -// checkPIDFile exists and return error, or write new. -func checkPIDFile() error { - pidFile := path.Join(os.TempDir(), "wiretrustee-ui.pid") - if piddata, err := os.ReadFile(pidFile); err == nil { - if pid, err := strconv.Atoi(string(piddata)); err == nil { - if process, err := os.FindProcess(pid); err == nil { - if err := process.Signal(syscall.Signal(0)); err == nil { - return fmt.Errorf("process already exists: %d", pid) - } - } - } - } - - return os.WriteFile(pidFile, []byte(fmt.Sprintf("%d", os.Getpid())), 0o664) //nolint:gosec -} diff --git a/client/ui/process.go b/client/ui/process.go new file mode 100644 index 000000000..bcb3dd879 --- /dev/null +++ b/client/ui/process.go @@ -0,0 +1,37 @@ +package main + +import ( + "os" + "path/filepath" + "strings" + + "github.com/shirou/gopsutil/v3/process" +) + +func isAnotherProcessRunning() (bool, error) { + processes, err := process.Processes() + if err != nil { + return false, err + } + + pid := os.Getpid() + processName := strings.ToLower(filepath.Base(os.Args[0])) + + for _, p := range processes { + if int(p.Pid) == pid { + continue + } + + runningProcessPath, err := p.Exe() + // most errors are related to short-lived processes + if err != nil { + continue + } + + if strings.Contains(strings.ToLower(runningProcessPath), processName) && isProcessOwnedByCurrentUser(p) { + return true, nil + } + } + + return false, nil +} diff --git a/client/ui/process_nonwindows.go b/client/ui/process_nonwindows.go new file mode 100644 index 000000000..0d17be2be --- /dev/null +++ b/client/ui/process_nonwindows.go @@ -0,0 +1,26 @@ +//go:build !windows + +package main + +import ( + "os" + + "github.com/shirou/gopsutil/v3/process" + log "github.com/sirupsen/logrus" +) + +func isProcessOwnedByCurrentUser(p *process.Process) bool { + currentUserID := os.Getuid() + uids, err := p.Uids() + if err != nil { + log.Errorf("get process uids: %v", err) + return false + } + for _, id := range uids { + log.Debugf("checking process uid: %d", id) + if int(id) == currentUserID { + return true + } + } + return false +} diff --git a/client/ui/process_windows.go b/client/ui/process_windows.go new file mode 100644 index 000000000..b15b0ed24 --- /dev/null +++ b/client/ui/process_windows.go @@ -0,0 +1,24 @@ +package main + +import ( + "os/user" + + "github.com/shirou/gopsutil/v3/process" + log "github.com/sirupsen/logrus" +) + +func isProcessOwnedByCurrentUser(p *process.Process) bool { + processUsername, err := p.Username() + if err != nil { + log.Errorf("get process username error: %v", err) + return false + } + + currUser, err := user.Current() + if err != nil { + log.Errorf("get current user error: %v", err) + return false + } + + return processUsername == currUser.Username +}