Load user profile when SSH (#380)

This PR fixes issues with the terminal when
running netbird ssh to a remote agent.
Every session looks up a user and loads its
profile. If no user is found, the connection is rejected.
The default user is root.
This commit is contained in:
Misha Bragin 2022-07-07 11:24:38 +02:00 committed by GitHub
parent 49e9113e0f
commit d4a3ee9d87
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 137 additions and 13 deletions

View File

@ -17,7 +17,7 @@ import (
var ( var (
port int port int
user = "netbird" user = "root"
host string host string
) )
@ -90,7 +90,8 @@ func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command)
cmd.Printf("Couldn't connect. " + cmd.Printf("Couldn't connect. " +
"You might be disconnected from the NetBird network, or the NetBird agent isn't running.\n" + "You might be disconnected from the NetBird network, or the NetBird agent isn't running.\n" +
"Run the status command: \n\n" + "Run the status command: \n\n" +
" netbird status\n\n") " netbird status\n\n" +
"It might also be that the SSH server is disabled on the agent you are trying to connect to.\n")
return nil return nil
} }
go func() { go func() {

View File

@ -7,6 +7,7 @@ import (
nbstatus "github.com/netbirdio/netbird/client/status" nbstatus "github.com/netbirdio/netbird/client/status"
"math/rand" "math/rand"
"net" "net"
"reflect"
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
@ -174,6 +175,13 @@ func (e *Engine) Stop() error {
} }
} }
if !isNil(e.sshServer) {
err := e.sshServer.Stop()
if err != nil {
log.Warnf("failed stopping the SSH server: %v", err)
}
}
log.Infof("stopped Netbird Engine") log.Infof("stopped Netbird Engine")
return nil return nil
@ -301,7 +309,7 @@ func (e *Engine) removeAllPeers() error {
func (e *Engine) removePeer(peerKey string) error { func (e *Engine) removePeer(peerKey string) error {
log.Debugf("removing peer from engine %s", peerKey) log.Debugf("removing peer from engine %s", peerKey)
if e.sshServer != nil { if !isNil(e.sshServer) {
e.sshServer.RemoveAuthorizedKey(peerKey) e.sshServer.RemoveAuthorizedKey(peerKey)
} }
@ -434,6 +442,10 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return nil return nil
} }
func isNil(server nbssh.Server) bool {
return server == nil || reflect.ValueOf(server).IsNil()
}
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error { func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
if sshConf.GetSshEnabled() { if sshConf.GetSshEnabled() {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
@ -441,7 +453,7 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
return nil return nil
} }
// start SSH server if it wasn't running // start SSH server if it wasn't running
if e.sshServer == nil { if isNil(e.sshServer) {
//nil sshServer means it has not yet been started //nil sshServer means it has not yet been started
var err error var err error
e.sshServer, err = e.sshServerFunc(e.config.SSHKey, e.sshServer, err = e.sshServerFunc(e.config.SSHKey,
@ -466,7 +478,7 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
} }
} else { } else {
// Disable SSH server request, so stop it if it was running // Disable SSH server request, so stop it if it was running
if e.sshServer != nil { if !isNil(e.sshServer) {
err := e.sshServer.Stop() err := e.sshServer.Stop()
if err != nil { if err != nil {
log.Warnf("failed to stop SSH server %v", err) log.Warnf("failed to stop SSH server %v", err)
@ -597,7 +609,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
} }
// update SSHServer by adding remote peer SSH keys // update SSHServer by adding remote peer SSH keys
if e.sshServer != nil { if !isNil(e.sshServer) {
for _, config := range networkMap.GetRemotePeers() { for _, config := range networkMap.GetRemotePeers() {
if config.GetSshConfig() != nil && config.GetSshConfig().GetSshPubKey() != nil { if config.GetSshConfig() != nil && config.GetSshConfig().GetSshPubKey() != nil {
err := e.sshServer.AddAuthorizedKey(config.WgPubKey, string(config.GetSshConfig().GetSshPubKey())) err := e.sshServer.AddAuthorizedKey(config.WgPubKey, string(config.GetSshConfig().GetSshPubKey()))

36
client/ssh/login.go Normal file
View File

@ -0,0 +1,36 @@
package ssh
import (
"fmt"
"github.com/netbirdio/netbird/util"
"net"
"net/netip"
"os/exec"
"runtime"
)
func getLoginCmd(user string, remoteAddr net.Addr) (loginPath string, args []string, err error) {
loginPath, err = exec.LookPath("login")
if err != nil {
return "", nil, err
}
addrPort, err := netip.ParseAddrPort(remoteAddr.String())
if err != nil {
return "", nil, err
}
if runtime.GOOS == "linux" {
if util.FileExists("/etc/arch-release") && !util.FileExists("/etc/pam.d/remote") {
// detect if Arch Linux
return loginPath, []string{"-f", user, "-p"}, nil
}
return loginPath, []string{"-f", user, "-h", addrPort.Addr().String(), "-p"}, nil
} else if runtime.GOOS == "darwin" {
return loginPath, []string{"-fp", "-h", addrPort.Addr().String(), user}, nil
}
return "", nil, fmt.Errorf("unsupported platform")
}

View File

@ -9,6 +9,9 @@ import (
"net" "net"
"os" "os"
"os/exec" "os/exec"
"os/user"
"runtime"
"strings"
"sync" "sync"
) )
@ -105,12 +108,20 @@ func (srv *DefaultServer) publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) b
return false return false
} }
func getShellType() string { func prepareUserEnv(user *user.User, shell string) []string {
shell := os.Getenv("SHELL") return []string{
if shell == "" { fmt.Sprintf("SHELL=" + shell),
shell = "sh" fmt.Sprintf("USER=" + user.Username),
fmt.Sprintf("HOME=" + user.HomeDir),
} }
return shell }
func acceptEnv(s string) bool {
split := strings.Split(s, "=")
if len(split) != 2 {
return false
}
return split[0] == "TERM" || split[0] == "LANG" || strings.HasPrefix(split[0], "LC_")
} }
// sessionHandler handles SSH session post auth // sessionHandler handles SSH session post auth
@ -118,14 +129,54 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
srv.mu.Lock() srv.mu.Lock()
srv.sessions = append(srv.sessions, session) srv.sessions = append(srv.sessions, session)
srv.mu.Unlock() srv.mu.Unlock()
defer func() {
err := session.Close()
if err != nil {
return
}
}()
localUser, err := user.Lookup(session.User())
if err != nil {
_, err = fmt.Fprintf(session, "remote SSH server couldn't find local user %s\n", session.User()) //nolint
err = session.Exit(1)
if err != nil {
return
}
log.Warnf("failed SSH session from %v, user %s", session.RemoteAddr(), session.User())
return
}
ptyReq, winCh, isPty := session.Pty() ptyReq, winCh, isPty := session.Pty()
if isPty { if isPty {
cmd := exec.Command(getShellType()) loginCmd, loginArgs, err := getLoginCmd(localUser.Username, session.RemoteAddr())
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%session", ptyReq.Term)) if err != nil {
log.Warnf("failed logging-in user %s from remote IP %s", localUser.Username, session.RemoteAddr().String())
return
}
cmd := exec.Command(loginCmd, loginArgs...)
go func() {
<-session.Context().Done()
err := cmd.Process.Kill()
if err != nil {
return
}
}()
cmd.Dir = localUser.HomeDir
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
cmd.Env = append(cmd.Env, prepareUserEnv(localUser, getUserShell(localUser.Uid))...)
for _, v := range session.Environ() {
if acceptEnv(v) {
cmd.Env = append(cmd.Env, v)
}
}
file, err := pty.Start(cmd) file, err := pty.Start(cmd)
if err != nil { if err != nil {
log.Errorf("failed starting SSH server %v", err) log.Errorf("failed starting SSH server %v", err)
} }
go func() { go func() {
for win := range winCh { for win := range winCh {
setWinSize(file, win.Width, win.Height) setWinSize(file, win.Width, win.Height)
@ -181,3 +232,19 @@ func (srv *DefaultServer) Start() error {
return nil return nil
} }
func getUserShell(userID string) string {
if runtime.GOOS == "linux" {
output, _ := exec.Command("getent", "passwd", userID).Output()
line := strings.SplitN(string(output), ":", 10)
if len(line) > 6 {
return strings.TrimSpace(line[6])
}
}
shell := os.Getenv("SHELL")
if shell == "" {
shell = "/bin/sh"
}
return shell
}

View File

@ -1,5 +1,7 @@
package util package util
import "os"
// SliceDiff returns the elements in slice `x` that are not in slice `y` // SliceDiff returns the elements in slice `x` that are not in slice `y`
func SliceDiff(x, y []string) []string { func SliceDiff(x, y []string) []string {
mapY := make(map[string]struct{}, len(y)) mapY := make(map[string]struct{}, len(y))
@ -14,3 +16,9 @@ func SliceDiff(x, y []string) []string {
} }
return diff return diff
} }
// FileExists returns true if specified file exists
func FileExists(path string) bool {
_, err := os.Stat(path)
return err == nil
}