Merge pull request #983 from netbirdio/fix/ssh_connection_freeze

Fix ssh connection freeze
This commit is contained in:
pascal-fischer 2023-06-27 18:10:30 +02:00 committed by GitHub
commit c000c05435
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 38 additions and 15 deletions

View File

@ -73,7 +73,8 @@ var sshCmd = &cobra.Command{
go func() {
// blocking
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
log.Print(err)
log.Debug(err)
os.Exit(1)
}
cancel()
}()
@ -92,12 +93,10 @@ func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command)
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey)
if err != nil {
cmd.Printf("Error: %v\n", err)
cmd.Printf("Couldn't connect. " +
"You might be disconnected from the NetBird network, or the NetBird agent isn't running.\n" +
"Run the status command: \n\n" +
" netbird status\n\n" +
"It might also be that the SSH server is disabled on the agent you are trying to connect to.\n")
return nil
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
"You can verify the connection by running:\n\n" +
" netbird status\n\n")
return err
}
go func() {
<-ctx.Done()

View File

@ -2,9 +2,6 @@ package ssh
import (
"fmt"
"github.com/creack/pty"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
"io"
"net"
"os"
@ -13,11 +10,22 @@ import (
"runtime"
"strings"
"sync"
"time"
"github.com/creack/pty"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
)
// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server
const DefaultSSHPort = 44338
// TerminalTimeout is the timeout for terminal session to be ready
const TerminalTimeout = 10 * time.Second
// TerminalBackoffDelay is the delay between terminal session readiness checks
const TerminalBackoffDelay = 500 * time.Millisecond
// DefaultSSHServer is a function that creates DefaultServer
func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) {
return newDefaultServer(hostKeyPEM, addr)
@ -137,6 +145,8 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
}
}()
log.Infof("Establishing SSH session for %s from host %s", session.User(), session.RemoteAddr().String())
localUser, err := userNameLookup(session.User())
if err != nil {
_, err = fmt.Fprintf(session, "remote SSH server couldn't find local user %s\n", session.User()) //nolint
@ -172,6 +182,7 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
}
}
log.Debugf("Login command: %s", cmd.String())
file, err := pty.Start(cmd)
if err != nil {
log.Errorf("failed starting SSH server %v", err)
@ -199,6 +210,7 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
return
}
}
log.Debugf("SSH session ended")
}
func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) {
@ -206,17 +218,29 @@ func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) {
// stdin
_, err := io.Copy(file, session)
if err != nil {
_ = session.Exit(1)
return
}
}()
go func() {
// AWS Linux 2 machines need some time to open the terminal so we need to wait for it
timer := time.NewTimer(TerminalTimeout)
for {
select {
case <-timer.C:
_, _ = session.Write([]byte("Reached timeout while opening connection\n"))
_ = session.Exit(1)
return
default:
// stdout
_, err := io.Copy(session, file)
if err != nil {
writtenBytes, err := io.Copy(session, file)
if err != nil && writtenBytes != 0 {
_ = session.Exit(0)
return
}
}()
time.Sleep(TerminalBackoffDelay)
}
}
}
// Start starts SSH server. Blocking