2017-09-11 13:48:07 +02:00
|
|
|
package cmd
|
|
|
|
|
|
|
|
import (
|
|
|
|
"fmt"
|
|
|
|
"os"
|
|
|
|
|
|
|
|
"github.com/ftrvxmtrx/fd"
|
|
|
|
"github.com/spf13/cobra"
|
|
|
|
"io"
|
2017-09-17 18:20:05 +02:00
|
|
|
"log"
|
2017-09-11 13:48:07 +02:00
|
|
|
"net"
|
|
|
|
)
|
|
|
|
|
|
|
|
var StdinserverCmd = &cobra.Command{
|
|
|
|
Use: "stdinserver CLIENT_IDENTITY",
|
|
|
|
Short: "start in stdinserver mode (from authorized_keys file)",
|
|
|
|
Run: cmdStdinServer,
|
|
|
|
}
|
|
|
|
|
|
|
|
func init() {
|
|
|
|
RootCmd.AddCommand(StdinserverCmd)
|
|
|
|
}
|
|
|
|
|
|
|
|
func cmdStdinServer(cmd *cobra.Command, args []string) {
|
|
|
|
|
2017-09-17 18:20:05 +02:00
|
|
|
log := log.New(os.Stderr, "", log.LUTC|log.Ldate|log.Ltime)
|
|
|
|
|
|
|
|
die := func() {
|
|
|
|
log.Printf("stdinserver exiting after fatal error")
|
|
|
|
os.Exit(1)
|
|
|
|
}
|
|
|
|
|
2017-09-22 14:02:07 +02:00
|
|
|
conf, err := ParseConfig(rootArgs.configFile)
|
2017-09-17 18:20:05 +02:00
|
|
|
if err != nil {
|
|
|
|
log.Printf("error parsing config: %s", err)
|
|
|
|
die()
|
|
|
|
}
|
2017-09-11 13:48:07 +02:00
|
|
|
|
|
|
|
if len(args) != 1 || args[0] == "" {
|
|
|
|
err = fmt.Errorf("must specify client_identity as positional argument")
|
2017-09-17 18:20:05 +02:00
|
|
|
die()
|
2017-09-11 13:48:07 +02:00
|
|
|
}
|
|
|
|
identity := args[0]
|
|
|
|
|
2017-09-17 18:20:05 +02:00
|
|
|
unixaddr, err := stdinserverListenerSocket(conf.Global.Serve.Stdinserver.SockDir, identity)
|
2017-09-11 13:48:07 +02:00
|
|
|
if err != nil {
|
|
|
|
log.Printf("%s", err)
|
|
|
|
os.Exit(1)
|
|
|
|
}
|
|
|
|
|
|
|
|
log.Printf("opening control connection to zrepld via %s", unixaddr)
|
|
|
|
conn, err := net.DialUnix("unix", nil, unixaddr)
|
|
|
|
if err != nil {
|
|
|
|
log.Printf("error connecting to zrepld: %s", err)
|
2017-09-17 18:20:05 +02:00
|
|
|
die()
|
2017-09-11 13:48:07 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
log.Printf("sending stdin and stdout fds to zrepld")
|
|
|
|
err = fd.Put(conn, os.Stdin, os.Stdout)
|
|
|
|
if err != nil {
|
|
|
|
log.Printf("error: %s", err)
|
2017-09-17 18:20:05 +02:00
|
|
|
die()
|
2017-09-11 13:48:07 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
log.Printf("waiting for zrepld to close control connection")
|
|
|
|
for {
|
|
|
|
|
|
|
|
var buf [64]byte
|
|
|
|
n, err := conn.Read(buf[:])
|
|
|
|
if err == nil && n != 0 {
|
|
|
|
log.Printf("protocol error: read expected to timeout or EOF returned bytes")
|
|
|
|
}
|
|
|
|
|
|
|
|
if err == io.EOF {
|
|
|
|
log.Printf("zrepld closed control connection, terminating")
|
|
|
|
break
|
|
|
|
}
|
|
|
|
|
|
|
|
neterr, ok := err.(net.Error)
|
|
|
|
if !ok {
|
|
|
|
log.Printf("received unexpected error type: %T %s", err, err)
|
2017-09-17 18:20:05 +02:00
|
|
|
die()
|
2017-09-11 13:48:07 +02:00
|
|
|
}
|
|
|
|
if !neterr.Timeout() {
|
|
|
|
log.Printf("receivd unexpected net.Error (not a timeout): %s", neterr)
|
2017-09-17 18:20:05 +02:00
|
|
|
die()
|
2017-09-11 13:48:07 +02:00
|
|
|
}
|
|
|
|
// Read timed out, as expected
|
|
|
|
}
|
|
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
}
|