diff --git a/client/cmd/ssh.go b/client/cmd/ssh.go index 999ec6536..feb59e7b9 100644 --- a/client/cmd/ssh.go +++ b/client/cmd/ssh.go @@ -17,7 +17,7 @@ import ( var ( port int - user = "netbird" + user = "root" host string ) @@ -90,7 +90,8 @@ func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command) cmd.Printf("Couldn't connect. " + "You might be disconnected from the NetBird network, or the NetBird agent isn't running.\n" + "Run the status command: \n\n" + - " netbird status\n\n") + " netbird status\n\n" + + "It might also be that the SSH server is disabled on the agent you are trying to connect to.\n") return nil } go func() { diff --git a/client/internal/engine.go b/client/internal/engine.go index cd969c0cf..7e115286d 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -7,6 +7,7 @@ import ( nbstatus "github.com/netbirdio/netbird/client/status" "math/rand" "net" + "reflect" "runtime" "strings" "sync" @@ -174,6 +175,13 @@ func (e *Engine) Stop() error { } } + if !isNil(e.sshServer) { + err := e.sshServer.Stop() + if err != nil { + log.Warnf("failed stopping the SSH server: %v", err) + } + } + log.Infof("stopped Netbird Engine") return nil @@ -301,7 +309,7 @@ func (e *Engine) removeAllPeers() error { func (e *Engine) removePeer(peerKey string) error { log.Debugf("removing peer from engine %s", peerKey) - if e.sshServer != nil { + if !isNil(e.sshServer) { e.sshServer.RemoveAuthorizedKey(peerKey) } @@ -434,6 +442,10 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { return nil } +func isNil(server nbssh.Server) bool { + return server == nil || reflect.ValueOf(server).IsNil() +} + func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error { if sshConf.GetSshEnabled() { if runtime.GOOS == "windows" { @@ -441,7 +453,7 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error { return nil } // start SSH server if it wasn't running - if e.sshServer == nil { + if isNil(e.sshServer) { //nil sshServer means it has not yet been started var err error e.sshServer, err = e.sshServerFunc(e.config.SSHKey, @@ -466,7 +478,7 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error { } } else { // Disable SSH server request, so stop it if it was running - if e.sshServer != nil { + if !isNil(e.sshServer) { err := e.sshServer.Stop() if err != nil { log.Warnf("failed to stop SSH server %v", err) @@ -597,7 +609,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { } // update SSHServer by adding remote peer SSH keys - if e.sshServer != nil { + if !isNil(e.sshServer) { for _, config := range networkMap.GetRemotePeers() { if config.GetSshConfig() != nil && config.GetSshConfig().GetSshPubKey() != nil { err := e.sshServer.AddAuthorizedKey(config.WgPubKey, string(config.GetSshConfig().GetSshPubKey())) diff --git a/client/ssh/login.go b/client/ssh/login.go new file mode 100644 index 000000000..e6019578d --- /dev/null +++ b/client/ssh/login.go @@ -0,0 +1,36 @@ +package ssh + +import ( + "fmt" + "github.com/netbirdio/netbird/util" + "net" + "net/netip" + "os/exec" + "runtime" +) + +func getLoginCmd(user string, remoteAddr net.Addr) (loginPath string, args []string, err error) { + loginPath, err = exec.LookPath("login") + if err != nil { + return "", nil, err + } + + addrPort, err := netip.ParseAddrPort(remoteAddr.String()) + if err != nil { + return "", nil, err + } + + if runtime.GOOS == "linux" { + + if util.FileExists("/etc/arch-release") && !util.FileExists("/etc/pam.d/remote") { + // detect if Arch Linux + return loginPath, []string{"-f", user, "-p"}, nil + } + + return loginPath, []string{"-f", user, "-h", addrPort.Addr().String(), "-p"}, nil + } else if runtime.GOOS == "darwin" { + return loginPath, []string{"-fp", "-h", addrPort.Addr().String(), user}, nil + } + + return "", nil, fmt.Errorf("unsupported platform") +} diff --git a/client/ssh/server.go b/client/ssh/server.go index 9180f0bf9..70f13e7c4 100644 --- a/client/ssh/server.go +++ b/client/ssh/server.go @@ -9,6 +9,9 @@ import ( "net" "os" "os/exec" + "os/user" + "runtime" + "strings" "sync" ) @@ -105,12 +108,20 @@ func (srv *DefaultServer) publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) b return false } -func getShellType() string { - shell := os.Getenv("SHELL") - if shell == "" { - shell = "sh" +func prepareUserEnv(user *user.User, shell string) []string { + return []string{ + fmt.Sprintf("SHELL=" + shell), + fmt.Sprintf("USER=" + user.Username), + fmt.Sprintf("HOME=" + user.HomeDir), } - return shell +} + +func acceptEnv(s string) bool { + split := strings.Split(s, "=") + if len(split) != 2 { + return false + } + return split[0] == "TERM" || split[0] == "LANG" || strings.HasPrefix(split[0], "LC_") } // sessionHandler handles SSH session post auth @@ -118,14 +129,54 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) { srv.mu.Lock() srv.sessions = append(srv.sessions, session) srv.mu.Unlock() + + defer func() { + err := session.Close() + if err != nil { + return + } + }() + + localUser, err := user.Lookup(session.User()) + if err != nil { + _, err = fmt.Fprintf(session, "remote SSH server couldn't find local user %s\n", session.User()) //nolint + err = session.Exit(1) + if err != nil { + return + } + log.Warnf("failed SSH session from %v, user %s", session.RemoteAddr(), session.User()) + return + } + ptyReq, winCh, isPty := session.Pty() if isPty { - cmd := exec.Command(getShellType()) - cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%session", ptyReq.Term)) + loginCmd, loginArgs, err := getLoginCmd(localUser.Username, session.RemoteAddr()) + if err != nil { + log.Warnf("failed logging-in user %s from remote IP %s", localUser.Username, session.RemoteAddr().String()) + return + } + cmd := exec.Command(loginCmd, loginArgs...) + go func() { + <-session.Context().Done() + err := cmd.Process.Kill() + if err != nil { + return + } + }() + cmd.Dir = localUser.HomeDir + cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term)) + cmd.Env = append(cmd.Env, prepareUserEnv(localUser, getUserShell(localUser.Uid))...) + for _, v := range session.Environ() { + if acceptEnv(v) { + cmd.Env = append(cmd.Env, v) + } + } + file, err := pty.Start(cmd) if err != nil { log.Errorf("failed starting SSH server %v", err) } + go func() { for win := range winCh { setWinSize(file, win.Width, win.Height) @@ -181,3 +232,19 @@ func (srv *DefaultServer) Start() error { return nil } + +func getUserShell(userID string) string { + if runtime.GOOS == "linux" { + output, _ := exec.Command("getent", "passwd", userID).Output() + line := strings.SplitN(string(output), ":", 10) + if len(line) > 6 { + return strings.TrimSpace(line[6]) + } + } + + shell := os.Getenv("SHELL") + if shell == "" { + shell = "/bin/sh" + } + return shell +} diff --git a/util/common.go b/util/common.go index 4d0059ed9..22c975636 100644 --- a/util/common.go +++ b/util/common.go @@ -1,5 +1,7 @@ package util +import "os" + // SliceDiff returns the elements in slice `x` that are not in slice `y` func SliceDiff(x, y []string) []string { mapY := make(map[string]struct{}, len(y)) @@ -14,3 +16,9 @@ func SliceDiff(x, y []string) []string { } return diff } + +// FileExists returns true if specified file exists +func FileExists(path string) bool { + _, err := os.Stat(path) + return err == nil +}