This commit is contained in:
Pascal Fischer
2023-06-23 12:20:14 +02:00
parent b524a9d49d
commit 2691e729cd
2 changed files with 46 additions and 18 deletions

View File

@@ -2,9 +2,6 @@ package ssh
import (
"fmt"
"github.com/creack/pty"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
"io"
"net"
"os"
@@ -13,6 +10,11 @@ import (
"runtime"
"strings"
"sync"
"time"
"github.com/creack/pty"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
)
// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server
@@ -137,6 +139,8 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
}
}()
log.Infof("Establishing SSH session for %s from host %s", session.User(), session.RemoteAddr().String())
localUser, err := userNameLookup(session.User())
if err != nil {
_, err = fmt.Fprintf(session, "remote SSH server couldn't find local user %s\n", session.User()) //nolint
@@ -172,6 +176,7 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
}
}
log.Debugf("Login command: %s", cmd.String())
file, err := pty.Start(cmd)
if err != nil {
log.Errorf("failed starting SSH server %v", err)
@@ -199,24 +204,49 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
return
}
}
log.Debugf("SSH session ended")
}
func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) {
go func() {
// stdin
_, err := io.Copy(file, session)
if err != nil {
return
}
io.Copy(file, session)
}()
// For nodes on AWS the terminal takes a while to be ready so we need to wait
terminalIsReady := make(chan bool)
go func() {
// stdout
_, err := io.Copy(session, file)
if err != nil {
return
for {
log.Debugf("Checking if terminal is ready")
if checkIfFileIsReady(file) {
terminalIsReady <- true
}
time.Sleep(100 * time.Millisecond)
}
}()
timer := time.NewTimer(30 * time.Second)
for {
select {
case <-timer.C:
session.Write([]byte("Reached timeout while opening connection\n"))
session.Exit(1)
case <-terminalIsReady:
// stdout
io.Copy(session, file)
session.Exit(0)
}
}
}
func checkIfFileIsReady(file *os.File) bool {
buffer := make([]byte, 0)
_, err := file.Read(buffer)
// _, err := file.Stat()
// log.Infof("file stat: %v", err)
if err == nil {
return true
}
return false
}
// Start starts SSH server. Blocking