Merge pull request #983 from netbirdio/fix/ssh_connection_freeze

Fix ssh connection freeze
This commit is contained in:
pascal-fischer 2023-06-27 18:10:30 +02:00 committed by GitHub
commit c000c05435
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 38 additions and 15 deletions

View File

@ -73,7 +73,8 @@ var sshCmd = &cobra.Command{
go func() { go func() {
// blocking // blocking
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil { if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
log.Print(err) log.Debug(err)
os.Exit(1)
} }
cancel() cancel()
}() }()
@ -92,12 +93,10 @@ func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command)
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey) c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey)
if err != nil { if err != nil {
cmd.Printf("Error: %v\n", err) cmd.Printf("Error: %v\n", err)
cmd.Printf("Couldn't connect. " + cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
"You might be disconnected from the NetBird network, or the NetBird agent isn't running.\n" + "You can verify the connection by running:\n\n" +
"Run the status command: \n\n" + " netbird status\n\n")
" netbird status\n\n" + return err
"It might also be that the SSH server is disabled on the agent you are trying to connect to.\n")
return nil
} }
go func() { go func() {
<-ctx.Done() <-ctx.Done()

View File

@ -2,9 +2,6 @@ package ssh
import ( import (
"fmt" "fmt"
"github.com/creack/pty"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
"io" "io"
"net" "net"
"os" "os"
@ -13,11 +10,22 @@ import (
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
"time"
"github.com/creack/pty"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
) )
// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server // DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server
const DefaultSSHPort = 44338 const DefaultSSHPort = 44338
// TerminalTimeout is the timeout for terminal session to be ready
const TerminalTimeout = 10 * time.Second
// TerminalBackoffDelay is the delay between terminal session readiness checks
const TerminalBackoffDelay = 500 * time.Millisecond
// DefaultSSHServer is a function that creates DefaultServer // DefaultSSHServer is a function that creates DefaultServer
func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) { func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) {
return newDefaultServer(hostKeyPEM, addr) return newDefaultServer(hostKeyPEM, addr)
@ -137,6 +145,8 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
} }
}() }()
log.Infof("Establishing SSH session for %s from host %s", session.User(), session.RemoteAddr().String())
localUser, err := userNameLookup(session.User()) localUser, err := userNameLookup(session.User())
if err != nil { if err != nil {
_, err = fmt.Fprintf(session, "remote SSH server couldn't find local user %s\n", session.User()) //nolint _, err = fmt.Fprintf(session, "remote SSH server couldn't find local user %s\n", session.User()) //nolint
@ -172,6 +182,7 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
} }
} }
log.Debugf("Login command: %s", cmd.String())
file, err := pty.Start(cmd) file, err := pty.Start(cmd)
if err != nil { if err != nil {
log.Errorf("failed starting SSH server %v", err) log.Errorf("failed starting SSH server %v", err)
@ -199,6 +210,7 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
return return
} }
} }
log.Debugf("SSH session ended")
} }
func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) { func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) {
@ -206,17 +218,29 @@ func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) {
// stdin // stdin
_, err := io.Copy(file, session) _, err := io.Copy(file, session)
if err != nil { if err != nil {
_ = session.Exit(1)
return return
} }
}() }()
go func() { // AWS Linux 2 machines need some time to open the terminal so we need to wait for it
// stdout timer := time.NewTimer(TerminalTimeout)
_, err := io.Copy(session, file) for {
if err != nil { select {
case <-timer.C:
_, _ = session.Write([]byte("Reached timeout while opening connection\n"))
_ = session.Exit(1)
return return
default:
// stdout
writtenBytes, err := io.Copy(session, file)
if err != nil && writtenBytes != 0 {
_ = session.Exit(0)
return
}
time.Sleep(TerminalBackoffDelay)
} }
}() }
} }
// Start starts SSH server. Blocking // Start starts SSH server. Blocking