mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-23 19:21:23 +02:00
Merge pull request #983 from netbirdio/fix/ssh_connection_freeze
Fix ssh connection freeze
This commit is contained in:
commit
c000c05435
@ -73,7 +73,8 @@ var sshCmd = &cobra.Command{
|
|||||||
go func() {
|
go func() {
|
||||||
// blocking
|
// blocking
|
||||||
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
|
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
|
||||||
log.Print(err)
|
log.Debug(err)
|
||||||
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
cancel()
|
cancel()
|
||||||
}()
|
}()
|
||||||
@ -92,12 +93,10 @@ func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command)
|
|||||||
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey)
|
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cmd.Printf("Error: %v\n", err)
|
cmd.Printf("Error: %v\n", err)
|
||||||
cmd.Printf("Couldn't connect. " +
|
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
|
||||||
"You might be disconnected from the NetBird network, or the NetBird agent isn't running.\n" +
|
"You can verify the connection by running:\n\n" +
|
||||||
"Run the status command: \n\n" +
|
" netbird status\n\n")
|
||||||
" netbird status\n\n" +
|
return err
|
||||||
"It might also be that the SSH server is disabled on the agent you are trying to connect to.\n")
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
|
@ -2,9 +2,6 @@ package ssh
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/creack/pty"
|
|
||||||
"github.com/gliderlabs/ssh"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
@ -13,11 +10,22 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/creack/pty"
|
||||||
|
"github.com/gliderlabs/ssh"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server
|
// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server
|
||||||
const DefaultSSHPort = 44338
|
const DefaultSSHPort = 44338
|
||||||
|
|
||||||
|
// TerminalTimeout is the timeout for terminal session to be ready
|
||||||
|
const TerminalTimeout = 10 * time.Second
|
||||||
|
|
||||||
|
// TerminalBackoffDelay is the delay between terminal session readiness checks
|
||||||
|
const TerminalBackoffDelay = 500 * time.Millisecond
|
||||||
|
|
||||||
// DefaultSSHServer is a function that creates DefaultServer
|
// DefaultSSHServer is a function that creates DefaultServer
|
||||||
func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) {
|
func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) {
|
||||||
return newDefaultServer(hostKeyPEM, addr)
|
return newDefaultServer(hostKeyPEM, addr)
|
||||||
@ -137,6 +145,8 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
log.Infof("Establishing SSH session for %s from host %s", session.User(), session.RemoteAddr().String())
|
||||||
|
|
||||||
localUser, err := userNameLookup(session.User())
|
localUser, err := userNameLookup(session.User())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_, err = fmt.Fprintf(session, "remote SSH server couldn't find local user %s\n", session.User()) //nolint
|
_, err = fmt.Fprintf(session, "remote SSH server couldn't find local user %s\n", session.User()) //nolint
|
||||||
@ -172,6 +182,7 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Debugf("Login command: %s", cmd.String())
|
||||||
file, err := pty.Start(cmd)
|
file, err := pty.Start(cmd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed starting SSH server %v", err)
|
log.Errorf("failed starting SSH server %v", err)
|
||||||
@ -199,6 +210,7 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
log.Debugf("SSH session ended")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) {
|
func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) {
|
||||||
@ -206,17 +218,29 @@ func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) {
|
|||||||
// stdin
|
// stdin
|
||||||
_, err := io.Copy(file, session)
|
_, err := io.Copy(file, session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
_ = session.Exit(1)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
// AWS Linux 2 machines need some time to open the terminal so we need to wait for it
|
||||||
// stdout
|
timer := time.NewTimer(TerminalTimeout)
|
||||||
_, err := io.Copy(session, file)
|
for {
|
||||||
if err != nil {
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
_, _ = session.Write([]byte("Reached timeout while opening connection\n"))
|
||||||
|
_ = session.Exit(1)
|
||||||
return
|
return
|
||||||
|
default:
|
||||||
|
// stdout
|
||||||
|
writtenBytes, err := io.Copy(session, file)
|
||||||
|
if err != nil && writtenBytes != 0 {
|
||||||
|
_ = session.Exit(0)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
time.Sleep(TerminalBackoffDelay)
|
||||||
}
|
}
|
||||||
}()
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start starts SSH server. Blocking
|
// Start starts SSH server. Blocking
|
||||||
|
Loading…
x
Reference in New Issue
Block a user