mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-28 11:03:30 +01:00
Load user profile when SSH (#380)
This PR fixes issues with the terminal when running netbird ssh to a remote agent. Every session looks up a user and loads its profile. If no user is found, the connection is rejected. The default user is root.
This commit is contained in:
parent
49e9113e0f
commit
d4a3ee9d87
@ -17,7 +17,7 @@ import (
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
port int
|
port int
|
||||||
user = "netbird"
|
user = "root"
|
||||||
host string
|
host string
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -90,7 +90,8 @@ func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command)
|
|||||||
cmd.Printf("Couldn't connect. " +
|
cmd.Printf("Couldn't connect. " +
|
||||||
"You might be disconnected from the NetBird network, or the NetBird agent isn't running.\n" +
|
"You might be disconnected from the NetBird network, or the NetBird agent isn't running.\n" +
|
||||||
"Run the status command: \n\n" +
|
"Run the status command: \n\n" +
|
||||||
" netbird status\n\n")
|
" netbird status\n\n" +
|
||||||
|
"It might also be that the SSH server is disabled on the agent you are trying to connect to.\n")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@ -174,6 +175,13 @@ func (e *Engine) Stop() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !isNil(e.sshServer) {
|
||||||
|
err := e.sshServer.Stop()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed stopping the SSH server: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
log.Infof("stopped Netbird Engine")
|
log.Infof("stopped Netbird Engine")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -301,7 +309,7 @@ func (e *Engine) removeAllPeers() error {
|
|||||||
func (e *Engine) removePeer(peerKey string) error {
|
func (e *Engine) removePeer(peerKey string) error {
|
||||||
log.Debugf("removing peer from engine %s", peerKey)
|
log.Debugf("removing peer from engine %s", peerKey)
|
||||||
|
|
||||||
if e.sshServer != nil {
|
if !isNil(e.sshServer) {
|
||||||
e.sshServer.RemoveAuthorizedKey(peerKey)
|
e.sshServer.RemoveAuthorizedKey(peerKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -434,6 +442,10 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isNil(server nbssh.Server) bool {
|
||||||
|
return server == nil || reflect.ValueOf(server).IsNil()
|
||||||
|
}
|
||||||
|
|
||||||
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||||
if sshConf.GetSshEnabled() {
|
if sshConf.GetSshEnabled() {
|
||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
@ -441,7 +453,7 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// start SSH server if it wasn't running
|
// start SSH server if it wasn't running
|
||||||
if e.sshServer == nil {
|
if isNil(e.sshServer) {
|
||||||
//nil sshServer means it has not yet been started
|
//nil sshServer means it has not yet been started
|
||||||
var err error
|
var err error
|
||||||
e.sshServer, err = e.sshServerFunc(e.config.SSHKey,
|
e.sshServer, err = e.sshServerFunc(e.config.SSHKey,
|
||||||
@ -466,7 +478,7 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Disable SSH server request, so stop it if it was running
|
// Disable SSH server request, so stop it if it was running
|
||||||
if e.sshServer != nil {
|
if !isNil(e.sshServer) {
|
||||||
err := e.sshServer.Stop()
|
err := e.sshServer.Stop()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to stop SSH server %v", err)
|
log.Warnf("failed to stop SSH server %v", err)
|
||||||
@ -597,7 +609,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// update SSHServer by adding remote peer SSH keys
|
// update SSHServer by adding remote peer SSH keys
|
||||||
if e.sshServer != nil {
|
if !isNil(e.sshServer) {
|
||||||
for _, config := range networkMap.GetRemotePeers() {
|
for _, config := range networkMap.GetRemotePeers() {
|
||||||
if config.GetSshConfig() != nil && config.GetSshConfig().GetSshPubKey() != nil {
|
if config.GetSshConfig() != nil && config.GetSshConfig().GetSshPubKey() != nil {
|
||||||
err := e.sshServer.AddAuthorizedKey(config.WgPubKey, string(config.GetSshConfig().GetSshPubKey()))
|
err := e.sshServer.AddAuthorizedKey(config.WgPubKey, string(config.GetSshConfig().GetSshPubKey()))
|
||||||
|
36
client/ssh/login.go
Normal file
36
client/ssh/login.go
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
package ssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"os/exec"
|
||||||
|
"runtime"
|
||||||
|
)
|
||||||
|
|
||||||
|
func getLoginCmd(user string, remoteAddr net.Addr) (loginPath string, args []string, err error) {
|
||||||
|
loginPath, err = exec.LookPath("login")
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
addrPort, err := netip.ParseAddrPort(remoteAddr.String())
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if runtime.GOOS == "linux" {
|
||||||
|
|
||||||
|
if util.FileExists("/etc/arch-release") && !util.FileExists("/etc/pam.d/remote") {
|
||||||
|
// detect if Arch Linux
|
||||||
|
return loginPath, []string{"-f", user, "-p"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return loginPath, []string{"-f", user, "-h", addrPort.Addr().String(), "-p"}, nil
|
||||||
|
} else if runtime.GOOS == "darwin" {
|
||||||
|
return loginPath, []string{"-fp", "-h", addrPort.Addr().String(), user}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", nil, fmt.Errorf("unsupported platform")
|
||||||
|
}
|
@ -9,6 +9,9 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
"os/user"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -105,12 +108,20 @@ func (srv *DefaultServer) publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) b
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func getShellType() string {
|
func prepareUserEnv(user *user.User, shell string) []string {
|
||||||
shell := os.Getenv("SHELL")
|
return []string{
|
||||||
if shell == "" {
|
fmt.Sprintf("SHELL=" + shell),
|
||||||
shell = "sh"
|
fmt.Sprintf("USER=" + user.Username),
|
||||||
|
fmt.Sprintf("HOME=" + user.HomeDir),
|
||||||
}
|
}
|
||||||
return shell
|
}
|
||||||
|
|
||||||
|
func acceptEnv(s string) bool {
|
||||||
|
split := strings.Split(s, "=")
|
||||||
|
if len(split) != 2 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return split[0] == "TERM" || split[0] == "LANG" || strings.HasPrefix(split[0], "LC_")
|
||||||
}
|
}
|
||||||
|
|
||||||
// sessionHandler handles SSH session post auth
|
// sessionHandler handles SSH session post auth
|
||||||
@ -118,14 +129,54 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
|
|||||||
srv.mu.Lock()
|
srv.mu.Lock()
|
||||||
srv.sessions = append(srv.sessions, session)
|
srv.sessions = append(srv.sessions, session)
|
||||||
srv.mu.Unlock()
|
srv.mu.Unlock()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err := session.Close()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
localUser, err := user.Lookup(session.User())
|
||||||
|
if err != nil {
|
||||||
|
_, err = fmt.Fprintf(session, "remote SSH server couldn't find local user %s\n", session.User()) //nolint
|
||||||
|
err = session.Exit(1)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Warnf("failed SSH session from %v, user %s", session.RemoteAddr(), session.User())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
ptyReq, winCh, isPty := session.Pty()
|
ptyReq, winCh, isPty := session.Pty()
|
||||||
if isPty {
|
if isPty {
|
||||||
cmd := exec.Command(getShellType())
|
loginCmd, loginArgs, err := getLoginCmd(localUser.Username, session.RemoteAddr())
|
||||||
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%session", ptyReq.Term))
|
if err != nil {
|
||||||
|
log.Warnf("failed logging-in user %s from remote IP %s", localUser.Username, session.RemoteAddr().String())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cmd := exec.Command(loginCmd, loginArgs...)
|
||||||
|
go func() {
|
||||||
|
<-session.Context().Done()
|
||||||
|
err := cmd.Process.Kill()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
cmd.Dir = localUser.HomeDir
|
||||||
|
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
|
||||||
|
cmd.Env = append(cmd.Env, prepareUserEnv(localUser, getUserShell(localUser.Uid))...)
|
||||||
|
for _, v := range session.Environ() {
|
||||||
|
if acceptEnv(v) {
|
||||||
|
cmd.Env = append(cmd.Env, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
file, err := pty.Start(cmd)
|
file, err := pty.Start(cmd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed starting SSH server %v", err)
|
log.Errorf("failed starting SSH server %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for win := range winCh {
|
for win := range winCh {
|
||||||
setWinSize(file, win.Width, win.Height)
|
setWinSize(file, win.Width, win.Height)
|
||||||
@ -181,3 +232,19 @@ func (srv *DefaultServer) Start() error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getUserShell(userID string) string {
|
||||||
|
if runtime.GOOS == "linux" {
|
||||||
|
output, _ := exec.Command("getent", "passwd", userID).Output()
|
||||||
|
line := strings.SplitN(string(output), ":", 10)
|
||||||
|
if len(line) > 6 {
|
||||||
|
return strings.TrimSpace(line[6])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
shell := os.Getenv("SHELL")
|
||||||
|
if shell == "" {
|
||||||
|
shell = "/bin/sh"
|
||||||
|
}
|
||||||
|
return shell
|
||||||
|
}
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
package util
|
package util
|
||||||
|
|
||||||
|
import "os"
|
||||||
|
|
||||||
// SliceDiff returns the elements in slice `x` that are not in slice `y`
|
// SliceDiff returns the elements in slice `x` that are not in slice `y`
|
||||||
func SliceDiff(x, y []string) []string {
|
func SliceDiff(x, y []string) []string {
|
||||||
mapY := make(map[string]struct{}, len(y))
|
mapY := make(map[string]struct{}, len(y))
|
||||||
@ -14,3 +16,9 @@ func SliceDiff(x, y []string) []string {
|
|||||||
}
|
}
|
||||||
return diff
|
return diff
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FileExists returns true if specified file exists
|
||||||
|
func FileExists(path string) bool {
|
||||||
|
_, err := os.Stat(path)
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user