package ssh import ( "fmt" "net" "os" "time" "golang.org/x/crypto/ssh" "golang.org/x/term" ) // Client wraps crypto/ssh Client to simplify usage type Client struct { client *ssh.Client } // Close closes the wrapped SSH Client func (c *Client) Close() error { return c.client.Close() } // OpenTerminal starts an interactive terminal session with the remote SSH server func (c *Client) OpenTerminal() error { session, err := c.client.NewSession() if err != nil { return fmt.Errorf("failed to open new session: %v", err) } defer func() { err := session.Close() if err != nil { return } }() fd := int(os.Stdout.Fd()) state, err := term.MakeRaw(fd) if err != nil { return fmt.Errorf("failed to run raw terminal: %s", err) } defer func() { err := term.Restore(fd, state) if err != nil { return } }() w, h, err := term.GetSize(fd) if err != nil { return fmt.Errorf("terminal get size: %s", err) } modes := ssh.TerminalModes{ ssh.ECHO: 1, ssh.TTY_OP_ISPEED: 14400, ssh.TTY_OP_OSPEED: 14400, } terminal := os.Getenv("TERM") if terminal == "" { terminal = "xterm-256color" } if err := session.RequestPty(terminal, h, w, modes); err != nil { return fmt.Errorf("failed requesting pty session with xterm: %s", err) } session.Stdout = os.Stdout session.Stderr = os.Stderr session.Stdin = os.Stdin if err := session.Shell(); err != nil { return fmt.Errorf("failed to start login shell on the remote host: %s", err) } if err := session.Wait(); err != nil { if e, ok := err.(*ssh.ExitError); ok { if e.ExitStatus() == 130 { return nil } } return fmt.Errorf("failed running SSH session: %s", err) } return nil } // DialWithKey connects to the remote SSH server with a provided private key file (PEM). func DialWithKey(addr, user string, privateKey []byte) (*Client, error) { signer, err := ssh.ParsePrivateKey(privateKey) if err != nil { return nil, err } config := &ssh.ClientConfig{ User: user, Timeout: 5 * time.Second, Auth: []ssh.AuthMethod{ ssh.PublicKeys(signer), }, HostKeyCallback: ssh.HostKeyCallback(func(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil }), } return Dial("tcp", addr, config) } // Dial connects to the remote SSH server. func Dial(network, addr string, config *ssh.ClientConfig) (*Client, error) { client, err := ssh.Dial(network, addr, config) if err != nil { return nil, err } return &Client{ client: client, }, nil }