diff --git a/client/cmd/ssh.go b/client/cmd/ssh.go index d0f506ff8..dd9407738 100644 --- a/client/cmd/ssh.go +++ b/client/cmd/ssh.go @@ -73,7 +73,8 @@ var sshCmd = &cobra.Command{ go func() { // blocking if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil { - log.Print(err) + log.Debug(err) + os.Exit(1) } cancel() }() @@ -92,12 +93,10 @@ func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command) c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey) if err != nil { cmd.Printf("Error: %v\n", err) - cmd.Printf("Couldn't connect. " + - "You might be disconnected from the NetBird network, or the NetBird agent isn't running.\n" + - "Run the status command: \n\n" + - " netbird status\n\n" + - "It might also be that the SSH server is disabled on the agent you are trying to connect to.\n") - return nil + cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" + + "You can verify the connection by running:\n\n" + + " netbird status\n\n") + return err } go func() { <-ctx.Done() diff --git a/client/ssh/server.go b/client/ssh/server.go index 5d63362b9..ae5c65c4a 100644 --- a/client/ssh/server.go +++ b/client/ssh/server.go @@ -2,9 +2,6 @@ package ssh import ( "fmt" - "github.com/creack/pty" - "github.com/gliderlabs/ssh" - log "github.com/sirupsen/logrus" "io" "net" "os" @@ -13,11 +10,22 @@ import ( "runtime" "strings" "sync" + "time" + + "github.com/creack/pty" + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" ) // DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server const DefaultSSHPort = 44338 +// TerminalTimeout is the timeout for terminal session to be ready +const TerminalTimeout = 10 * time.Second + +// TerminalBackoffDelay is the delay between terminal session readiness checks +const TerminalBackoffDelay = 500 * time.Millisecond + // DefaultSSHServer is a function that creates DefaultServer func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) { return newDefaultServer(hostKeyPEM, addr) @@ -137,6 +145,8 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) { } }() + log.Infof("Establishing SSH session for %s from host %s", session.User(), session.RemoteAddr().String()) + localUser, err := userNameLookup(session.User()) if err != nil { _, err = fmt.Fprintf(session, "remote SSH server couldn't find local user %s\n", session.User()) //nolint @@ -172,6 +182,7 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) { } } + log.Debugf("Login command: %s", cmd.String()) file, err := pty.Start(cmd) if err != nil { log.Errorf("failed starting SSH server %v", err) @@ -199,6 +210,7 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) { return } } + log.Debugf("SSH session ended") } func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) { @@ -206,17 +218,29 @@ func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) { // stdin _, err := io.Copy(file, session) if err != nil { + _ = session.Exit(1) return } }() - go func() { - // stdout - _, err := io.Copy(session, file) - if err != nil { + // AWS Linux 2 machines need some time to open the terminal so we need to wait for it + timer := time.NewTimer(TerminalTimeout) + for { + select { + case <-timer.C: + _, _ = session.Write([]byte("Reached timeout while opening connection\n")) + _ = session.Exit(1) return + default: + // stdout + writtenBytes, err := io.Copy(session, file) + if err != nil && writtenBytes != 0 { + _ = session.Exit(0) + return + } + time.Sleep(TerminalBackoffDelay) } - }() + } } // Start starts SSH server. Blocking