This commit is contained in:
Pascal Fischer 2023-06-23 12:20:14 +02:00
parent b524a9d49d
commit 2691e729cd
2 changed files with 46 additions and 18 deletions

View File

@ -9,7 +9,6 @@ import (
"strings"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal"
@ -73,7 +72,8 @@ var sshCmd = &cobra.Command{
go func() {
// blocking
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
log.Print(err)
os.Exit(1)
// log.Print(err)
}
cancel()
}()
@ -92,11 +92,9 @@ func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command)
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey)
if err != nil {
cmd.Printf("Error: %v\n", err)
cmd.Printf("Couldn't connect. " +
"You might be disconnected from the NetBird network, or the NetBird agent isn't running.\n" +
"Run the status command: \n\n" +
" netbird status\n\n" +
"It might also be that the SSH server is disabled on the agent you are trying to connect to.\n")
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
"You can verify the connection by running:\n\n" +
" netbird status\n\n")
return nil
}
go func() {

View File

@ -2,9 +2,6 @@ package ssh
import (
"fmt"
"github.com/creack/pty"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
"io"
"net"
"os"
@ -13,6 +10,11 @@ import (
"runtime"
"strings"
"sync"
"time"
"github.com/creack/pty"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
)
// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server
@ -137,6 +139,8 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
}
}()
log.Infof("Establishing SSH session for %s from host %s", session.User(), session.RemoteAddr().String())
localUser, err := userNameLookup(session.User())
if err != nil {
_, err = fmt.Fprintf(session, "remote SSH server couldn't find local user %s\n", session.User()) //nolint
@ -172,6 +176,7 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
}
}
log.Debugf("Login command: %s", cmd.String())
file, err := pty.Start(cmd)
if err != nil {
log.Errorf("failed starting SSH server %v", err)
@ -199,24 +204,49 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
return
}
}
log.Debugf("SSH session ended")
}
func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) {
go func() {
// stdin
_, err := io.Copy(file, session)
if err != nil {
return
}
io.Copy(file, session)
}()
// For nodes on AWS the terminal takes a while to be ready so we need to wait
terminalIsReady := make(chan bool)
go func() {
// stdout
_, err := io.Copy(session, file)
if err != nil {
return
for {
log.Debugf("Checking if terminal is ready")
if checkIfFileIsReady(file) {
terminalIsReady <- true
}
time.Sleep(100 * time.Millisecond)
}
}()
timer := time.NewTimer(30 * time.Second)
for {
select {
case <-timer.C:
session.Write([]byte("Reached timeout while opening connection\n"))
session.Exit(1)
case <-terminalIsReady:
// stdout
io.Copy(session, file)
session.Exit(0)
}
}
}
func checkIfFileIsReady(file *os.File) bool {
buffer := make([]byte, 0)
_, err := file.Read(buffer)
// _, err := file.Stat()
// log.Infof("file stat: %v", err)
if err == nil {
return true
}
return false
}
// Start starts SSH server. Blocking