diff --git a/go.mod b/go.mod index 95d5e85..798b084 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/go-sql-driver/mysql v1.4.1-0.20190907122137-b2c03bcae3d4 github.com/golang/protobuf v1.3.2 github.com/google/uuid v1.1.1 + github.com/hashicorp/yamux v0.0.0-20200609203250-aecfd211c9ce github.com/jinzhu/copier v0.0.0-20170922082739-db4671f3a9b8 github.com/kr/pretty v0.1.0 github.com/lib/pq v1.2.0 @@ -33,6 +34,7 @@ require ( github.com/yudai/gojsondiff v0.0.0-20170107030110-7b1b7adf999d github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82 // go1.12 thinks it needs this github.com/zrepl/yaml-config v0.0.0-20191220194647-cbb6b0cf4bdd + golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 golang.org/x/net v0.0.0-20190613194153-d28f0bde5980 golang.org/x/sync v0.0.0-20190423024810-112230192c58 golang.org/x/sys v0.0.0-20191026070338-33540a1f6037 @@ -40,3 +42,5 @@ require ( gonum.org/v1/gonum v0.7.0 // indirect google.golang.org/grpc v1.17.0 ) + +replace github.com/problame/go-netssh => /home/cs/zrepl/go-netssh diff --git a/go.sum b/go.sum index 0db95db..13c4ed2 100644 --- a/go.sum +++ b/go.sum @@ -120,6 +120,8 @@ github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+ github.com/gostaticanalysis/analysisutil v0.0.0-20190318220348-4088753ea4d3/go.mod h1:eEOZF4jCKGi+aprrirO9e7WKB3beBRtWgqGunKl6pKE= github.com/hashicorp/hcl v0.0.0-20180404174102-ef8a98b0bbce/go.mod h1:oZtUIOe8dh44I2q6ScRibXws4Ajl+d+nod3AaR9vL5w= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/hashicorp/yamux v0.0.0-20200609203250-aecfd211c9ce h1:7UnVY3T/ZnHUrfviiAgIUjg2PXxsQfs5bphsG8F7Keo= +github.com/hashicorp/yamux v0.0.0-20200609203250-aecfd211c9ce/go.mod h1:+NfK9FKeTrX5uv1uIXGdwYDTeHna2qgaIlx54MXqjAM= github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= @@ -300,6 +302,7 @@ github.com/zrepl/yaml-config v0.0.0-20191220194647-cbb6b0cf4bdd/go.mod h1:JmNwis github.com/zrepl/zrepl v0.2.0/go.mod h1:M3Zv2IGSO8iYpUjsZD6ayZ2LHy7zyMfzet9XatKOrZ8= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190313024323-a1f597ede03a/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= diff --git a/zync/README.md b/zync/README.md new file mode 100644 index 0000000..e125ad1 --- /dev/null +++ b/zync/README.md @@ -0,0 +1,7 @@ +A tool that re-used zrepl abstractions to be an rsync-like sync tool for ZFS. + +Test environment: + +``` +go build -o zync && rsync -a ./zync root@192.168.124.233:/usr/local/bin/ && sudo ./zync local:///rpool/zrepltlstest/src ssh://root:%2Fhome%2Fcs%2Fzrepl%2Fzrepl%2Fzync%2Ftestid@192.168.124.233/p1/zync_sink +``` diff --git a/zync/transport/sshdirect/dial.go b/zync/transport/sshdirect/dial.go new file mode 100644 index 0000000..f148ac1 --- /dev/null +++ b/zync/transport/sshdirect/dial.go @@ -0,0 +1,406 @@ +package sshdirect + +import ( + "bytes" + "context" + "fmt" + "io" + "net" + "os/exec" + "sync" + "syscall" + "time" + + "github.com/hashicorp/yamux" + "github.com/zrepl/zrepl/transport" + "github.com/zrepl/zrepl/util/circlog" +) + +type Endpoint struct { + Host string + User string + Port uint16 + IdentityFile string + SSHCommand string + Options []string + RunCommand []string +} + +func (e Endpoint) CmdArgs() (cmd string, args []string, env []string) { + + if e.SSHCommand != "" { + cmd = e.SSHCommand + } else { + cmd = "ssh" + } + + args = make([]string, 0, 2*len(e.Options)+4) + args = append(args, + "-p", fmt.Sprintf("%d", e.Port), + "-T", + "-i", e.IdentityFile, + "-o", "BatchMode=yes", + ) + for _, option := range e.Options { + args = append(args, "-o", option) + } + args = append(args, fmt.Sprintf("%s@%s", e.User, e.Host)) + + args = append(args, e.RunCommand...) + + env = []string{} + + return +} + +type SSHConn struct { + cmd *exec.Cmd + stdin io.WriteCloser + stdout io.ReadCloser + + shutdownMtx sync.Mutex + shutdownResult *shutdownResult // TODO not used anywhere + cmdCancel context.CancelFunc +} + +const go_network string = "netssh" + +type clientAddr struct { + pid int +} + +func (a clientAddr) Network() string { + return go_network +} + +func (a clientAddr) String() string { + return fmt.Sprintf("pid=%d", a.pid) +} + +func (conn *SSHConn) LocalAddr() net.Addr { + proc := conn.cmd.Process + if proc == nil { + return clientAddr{-1} + } + return clientAddr{proc.Pid} +} + +func (conn *SSHConn) RemoteAddr() net.Addr { + return conn.LocalAddr() +} + +// Read implements io.Reader. +// It returns *IOError for any non-nil error that is != io.EOF. +func (conn *SSHConn) Read(p []byte) (int, error) { + n, err := conn.stdout.Read(p) + if err != nil && err != io.EOF { + return n, &IOError{err} + } + return n, err +} + +// Write implements io.Writer. +// It returns *IOError for any error != nil. +func (conn *SSHConn) Write(p []byte) (int, error) { + n, err := conn.stdin.Write(p) + if err != nil { + return n, &IOError{err} + } + return n, err +} + +func (conn *SSHConn) CloseWrite() error { + return conn.stdin.Close() +} + +type deadliner interface { + SetReadDeadline(time.Time) error + SetWriteDeadline(time.Time) error +} + +func (conn *SSHConn) SetReadDeadline(t time.Time) error { + // type assertion is covered by test TestExecCmdPipesDeadlineBehavior + return conn.stdout.(deadliner).SetReadDeadline(t) +} + +func (conn *SSHConn) SetWriteDeadline(t time.Time) error { + // type assertion is covered by test TestExecCmdPipesDeadlineBehavior + return conn.stdin.(deadliner).SetWriteDeadline(t) +} + +func (conn *SSHConn) SetDeadline(t time.Time) error { + // try both + rerr := conn.SetReadDeadline(t) + werr := conn.SetWriteDeadline(t) + if rerr != nil { + return rerr + } + if werr != nil { + return werr + } + return nil +} + +func (conn *SSHConn) Close() error { + conn.shutdownProcess() + return nil // FIXME: waitError will be non-zero because we signaled it, shutdownProcess needs to distinguish that +} + +type shutdownResult struct { + waitErr error +} + +func (conn *SSHConn) shutdownProcess() *shutdownResult { + conn.shutdownMtx.Lock() + defer conn.shutdownMtx.Unlock() + + if conn.shutdownResult != nil { + return conn.shutdownResult + } + + wait := make(chan error, 1) + go func() { + if err := conn.cmd.Process.Signal(syscall.SIGTERM); err != nil { + // TODO log error + return + } + wait <- conn.cmd.Wait() + }() + + timeout := time.NewTimer(1 * time.Second) // FIXME const + defer timeout.Stop() + + select { + case waitErr := <-wait: + conn.shutdownResult = &shutdownResult{waitErr} + case <-timeout.C: + conn.cmdCancel() + waitErr := <-wait // reuse existing Wait invocation, must not call twice + conn.shutdownResult = &shutdownResult{waitErr} + } + return conn.shutdownResult +} + +// Cmd returns the underlying *exec.Cmd (the ssh client process) +// Use read-only, should not be necessary for regular users. +func (conn *SSHConn) Cmd() *exec.Cmd { + return conn.cmd +} + +// CmdCancel bypasses the normal shutdown mechanism of SSHConn +// (that is, calling Close) and cancels the process's context, +// which usually results in SIGKILL being sent to the process. +// Intended for integration tests, regular users shouldn't use it. +func (conn *SSHConn) CmdCancel() { + conn.cmdCancel() +} + +const bannerMessageLen = 31 + +var messages = make(map[string][]byte) + +func mustMessage(str string) []byte { + if len(str) > bannerMessageLen { + panic("message length must be smaller than bannerMessageLen") + } + if _, ok := messages[str]; ok { + panic("duplicate message") + } + var buf bytes.Buffer + n, _ := buf.WriteString(str) + if n != len(str) { + panic("message must only contain ascii / 8-bit chars") + } + buf.Write(bytes.Repeat([]byte{0}, bannerMessageLen-n)) + return buf.Bytes() +} + +var banner_msg = mustMessage("SSDIRECTHCON_HELO") +var proxy_error_msg = mustMessage("SSDIRECTHCON_PROXY_ERROR") /* FIXME irrelevant, was copy-pasta */ +var begin_msg = mustMessage("SSDIRECTHCON_BEGIN") + +type SSHError struct { + RWCError error + WhileActivity string +} + +// Error() will try to present a one-line error message unless ssh stderr output is longer than one line +func (e *SSHError) Error() string { + + exitErr, ok := e.RWCError.(*exec.ExitError) + if !ok { + return fmt.Sprintf("ssh: %s", e.RWCError) + } + + ws := exitErr.ProcessState.Sys().(syscall.WaitStatus) + var wsmsg string + if ws.Exited() { + wsmsg = fmt.Sprintf("(exit status %d)", ws.ExitStatus()) + } else { + wsmsg = fmt.Sprintf("(%s)", ws.Signal()) + } + + haveSSHMessage := len(exitErr.Stderr) > 0 + sshOnelineStderr := false + if i := bytes.Index(exitErr.Stderr, []byte("\n")); i == len(exitErr.Stderr)-1 { + sshOnelineStderr = true + } + stderr := bytes.TrimSpace(exitErr.Stderr) + + if haveSSHMessage { + if sshOnelineStderr { + return fmt.Sprintf("ssh: '%s' %s", stderr, wsmsg) // FIXME proper single-quoting + } else { + return fmt.Sprintf("ssh %s\n%s", wsmsg, stderr) + } + } + + return fmt.Sprintf("ssh terminated without stderr output %s", wsmsg) + +} + +type ProtocolError struct { + What string +} + +func (e ProtocolError) Error() string { + return e.What +} + +// Dial connects to the remote endpoint where it expects a command executing Proxy(). +// Dial performs a handshake consisting of the exchange of banner messages before returning the connection. +// If the handshake cannot be completed before dialCtx is Done(), the underlying ssh command is killed +// and the dialCtx.Err() returned. +// If the handshake completes, dialCtx's deadline does not affect the returned connection. +// +// Errors returned are either dialCtx.Err(), or intances of ProtocolError or *SSHError +func Dial(dialCtx context.Context, endpoint Endpoint) (*SSHConn, error) { + + sshCmd, sshArgs, sshEnv := endpoint.CmdArgs() + commandCtx, commandCancel := context.WithCancel(context.Background()) + cmd := exec.CommandContext(commandCtx, sshCmd, sshArgs...) + cmd.Env = sshEnv + stdin, err := cmd.StdinPipe() + if err != nil { + return nil, err + } + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, err + } + + stderrBuf, err := circlog.NewCircularLog(1 << 15) + if err != nil { + panic(err) // wrong API usage + } + cmd.Stderr = stderrBuf + + if err = cmd.Start(); err != nil { + return nil, err + } + cmdWaitErrOrIOErr := func(ioErr error, what string) *SSHError { + werr := cmd.Wait() + if werr, ok := werr.(*exec.ExitError); ok { + werr.Stderr = []byte(stderrBuf.String()) + return &SSHError{werr, what} + } + return &SSHError{ioErr, what} + } + + confErrChan := make(chan error, 1) + go func() { + defer close(confErrChan) + var buf bytes.Buffer + if _, err := io.CopyN(&buf, stdout, int64(len(banner_msg))); err != nil { + confErrChan <- cmdWaitErrOrIOErr(err, "read banner") + return + } + resp := buf.Bytes() + switch { + case bytes.Equal(resp, banner_msg): + break + case bytes.Equal(resp, proxy_error_msg): + _ = cmdWaitErrOrIOErr(nil, "") + confErrChan <- ProtocolError{"proxy error, check remote configuration"} + return + default: + _ = cmdWaitErrOrIOErr(nil, "") + confErrChan <- ProtocolError{fmt.Sprintf("unknown banner message: %v", resp)} + return + } + buf.Reset() + buf.Write(begin_msg) + if _, err := io.Copy(stdin, &buf); err != nil { + confErrChan <- cmdWaitErrOrIOErr(err, "send begin message") + return + } + }() + + select { + case <-dialCtx.Done(): + + commandCancel() + // cancelling will make one of the calls in above goroutine fail, + // and the goroutine will send the error to confErrChan + // + // ignore the error and return the cancellation cause + + // draining always terminates because we know the channel is always closed + for _ = range confErrChan { + } + + // TODO collect stderr in this case + // can probably extend *SSHError for this but need to implement net.Error + + return nil, dialCtx.Err() + + case err := <-confErrChan: + if err != nil { + commandCancel() + return nil, err + } + } + + return &SSHConn{ + cmd: cmd, + stdin: stdin, + stdout: stdout, + cmdCancel: commandCancel, + }, nil +} + +type Connecter struct { + s *yamux.Session + endpoint Endpoint +} + +var _ transport.Connecter = (*Connecter)(nil) + +func NewConnecter(ctx context.Context, endpoint Endpoint) (*Connecter, error) { + conn, err := Dial(ctx, endpoint) + if err != nil { + return nil, err + } + s, err := yamux.Client(conn, nil) + if err != nil { + return nil, err + } + return &Connecter{ + s: s, + endpoint: endpoint, + }, nil +} + +type fakeWire struct { + net.Conn +} + +func (w *fakeWire) CloseWrite() error { + time.Sleep(1*time.Second) // HACKY + return fmt.Errorf("fakeWire does not support CloseWrite") +} + +func (c *Connecter) Connect(ctx context.Context) (transport.Wire, error) { + conn, err := c.s.Open() + return &fakeWire{conn}, err +} diff --git a/zync/transport/sshdirect/error.go b/zync/transport/sshdirect/error.go new file mode 100644 index 0000000..7d1a8b0 --- /dev/null +++ b/zync/transport/sshdirect/error.go @@ -0,0 +1,48 @@ +package sshdirect + +import ( + "fmt" + "net" + "os" + "syscall" +) + +type timeouter interface { + Timeout() bool +} + +var _ timeouter = &os.PathError{} + +type IOError struct { + Cause error +} + +var _ net.Error = &IOError{} + +func (e IOError) GoString() string { + return fmt.Sprintf("ServeConnIOError:%#v", e.Cause) +} + +func (e IOError) Error() string { + // following case found by experiment + if pathErr, ok := e.Cause.(*os.PathError); ok { + if pathErr.Err == syscall.EPIPE { + return fmt.Sprintf("netssh %s: %s (likely: connection reset by peer)", + pathErr.Op, pathErr.Err, + ) + } + return fmt.Sprintf("netssh: %s: %s", pathErr.Op, pathErr.Err) + } + return fmt.Sprintf("netssh: %s", e.Cause.Error()) +} + +func (e IOError) Timeout() bool { + if to, ok := e.Cause.(timeouter); ok { + return to.Timeout() + } + return false +} + +func (e IOError) Temporary() bool { + return false +} diff --git a/zync/transport/sshdirect/serve.go b/zync/transport/sshdirect/serve.go new file mode 100644 index 0000000..a161a33 --- /dev/null +++ b/zync/transport/sshdirect/serve.go @@ -0,0 +1,91 @@ +package sshdirect + +import ( + "bytes" + "io" + "log" + "net" + "os" + "time" + + "github.com/hashicorp/yamux" +) + +type ServeConn struct { + stdin, stdout *os.File +} + +var _ net.Conn = (*ServeConn)(nil) + +func ServeStdin() (net.Listener, error) { + + conn := &ServeConn{ + stdin: os.Stdin, + stdout: os.Stdout, + } + + var buf bytes.Buffer + buf.Write(banner_msg) + if _, err := io.Copy(conn, &buf); err != nil { + log.Printf("error sending confirm message: %s", err) + conn.Close() + return nil, err + } + buf.Reset() + if _, err := io.CopyN(&buf, conn, int64(len(begin_msg))); err != nil { + log.Printf("error reading begin message: %s", err) + conn.Close() + return nil, err + } + + return yamux.Server(conn, nil) +} + +func (c *ServeConn) Read(p []byte) (int, error) { + return c.stdin.Read(p) +} + +func (c *ServeConn) Write(p []byte) (int, error) { + return c.stdout.Write(p) +} + +func (f *ServeConn) Close() (err error) { + e1 := f.stdin.Close() + e2 := f.stdout.Close() + // FIXME merge errors + if e1 != nil { + return e1 + } + return e2 +} + +func (f *ServeConn) SetReadDeadline(t time.Time) error { + return f.stdin.SetReadDeadline(t) +} + +func (f *ServeConn) SetWriteDeadline(t time.Time) error { + return f.stdout.SetReadDeadline(t) +} + +func (f *ServeConn) SetDeadline(t time.Time) error { + // try both... + werr := f.SetWriteDeadline(t) + rerr := f.SetReadDeadline(t) + if werr != nil { + return werr + } + if rerr != nil { + return rerr + } + return nil +} + +type serveAddr struct{} + +const GoNetwork string = "sshdirect" + +func (serveAddr) Network() string { return GoNetwork } +func (serveAddr) String() string { return "???" } + +func (f *ServeConn) LocalAddr() net.Addr { return serveAddr{} } +func (f *ServeConn) RemoteAddr() net.Addr { return serveAddr{} } diff --git a/zync/transport/transportlistenerfromnetlistener/transportlistenerfromnetlistener.go b/zync/transport/transportlistenerfromnetlistener/transportlistenerfromnetlistener.go new file mode 100644 index 0000000..1def1ec --- /dev/null +++ b/zync/transport/transportlistenerfromnetlistener/transportlistenerfromnetlistener.go @@ -0,0 +1,47 @@ +package transportlistenerfromnetlistener + +import ( + "context" + "fmt" + "net" + "time" + + "github.com/zrepl/zrepl/transport" +) + +type wrapFixed struct { + id string + l net.Listener +} + +var _ transport.AuthenticatedListener = (*wrapFixed)(nil) + +func WrapFixed(l net.Listener, identity string) transport.AuthenticatedListener { + return &wrapFixed{identity, l} +} + +func (w *wrapFixed) Addr() net.Addr { + return w.l.Addr() +} + +type fakeWire struct { + net.Conn +} + +func (w *fakeWire) CloseWrite() error { + time.Sleep(1*time.Second) // HACKY + return fmt.Errorf("fakeWire does not support CloseWrite") +} + +func (w *wrapFixed) Accept(ctx context.Context) (*transport.AuthConn, error) { + nc, err := w.l.Accept() + if err != nil { + return nil, err + } + + return transport.NewAuthConn(&fakeWire{nc}, w.id), nil +} + +func (w *wrapFixed) Close() error { + return w.l.Close() +} diff --git a/zync/zync.go b/zync/zync.go new file mode 100644 index 0000000..6cceeb9 --- /dev/null +++ b/zync/zync.go @@ -0,0 +1,333 @@ +package main + +import ( + "context" + "flag" + "fmt" + "io/ioutil" + "net/url" + "os" + "os/signal" + "runtime" + "strconv" + "syscall" + "time" + + "github.com/kr/pretty" + "github.com/pkg/errors" + "github.com/zrepl/zrepl/daemon/filters" + "github.com/zrepl/zrepl/daemon/logging" + "github.com/zrepl/zrepl/daemon/logging/trace" + "github.com/zrepl/zrepl/endpoint" + "github.com/zrepl/zrepl/logger" + "github.com/zrepl/zrepl/replication" + "github.com/zrepl/zrepl/replication/logic" + "github.com/zrepl/zrepl/replication/logic/pdu" + "github.com/zrepl/zrepl/rpc" + "github.com/zrepl/zrepl/transport" + "github.com/zrepl/zrepl/zfs" + "github.com/zrepl/zrepl/zync/transport/sshdirect" + "github.com/zrepl/zrepl/zync/transport/transportlistenerfromnetlistener" +) + +var flagStdinserver = flag.String("stdinserver", "", "") +var flagStderrToFile = flag.String("stderrtofile", "", "") + +const ( + modeSender = "sender" + modeReceiver = "receiver" +) + +func connecter(ctx context.Context, u *url.URL, mode, fs string) (transport.Connecter, error) { + switch u.Scheme { + case "ssh": + return connecterSSH(ctx, u, mode, fs) + default: + panic(fmt.Sprintf("unknown scheme %q", u.Scheme)) + } +} + +func connecterSSH(ctx context.Context, u *url.URL, mode, fs string) (transport.Connecter, error) { + var port uint16 + if u.Port() == "" { + port = 22 + } else { + portU64, err := strconv.ParseUint(u.Port(), 10, 16) + if err != nil { + return nil, errors.Wrap(err, "invalid port") + } + port = uint16(portU64) + } + + fmt.Println(u.User) + idFilePath, hasIdFile := u.User.Password() + if hasIdFile { + _, err := ioutil.ReadFile(idFilePath) + if err != nil { + fmt.Println(err) + hasIdFile = false + } + } + if !hasIdFile { + return nil, errors.New("must set password to identity file path") + } + + ep := sshdirect.Endpoint{ + Host: u.Hostname(), + Port: port, + User: u.User.Username(), + IdentityFile: idFilePath, + RunCommand: []string{"zync", "-stderrtofile", "/tmp/zync_server.log", "-stdinserver", mode, fs}, + } + return sshdirect.NewConnecter(ctx, ep) +} + +func onefsfilter(osname string) zfs.DatasetFilter { + f, err := filters.DatasetMapFilterFromConfig(map[string]bool{ + osname: true, + }) + if err != nil { + panic(err) + } + return f +} + +func makeEndpoint(mode, fs string) (interface{}, error) { + switch mode { + case modeSender: + sc := endpoint.SenderConfig{ + FSF: onefsfilter(fs), + Encrypt: &zfs.NilBool{B: true}, + JobID: endpoint.MustMakeJobID("sender"), + } + return endpoint.NewSender(sc), nil + case modeReceiver: + rpath, err := zfs.NewDatasetPath(fs) + if err != nil { + panic(err) + } + if err != nil { + panic(err) + } + return endpoint.NewReceiver(endpoint.ReceiverConfig{ + JobID: endpoint.MustMakeJobID("receiver"), + AppendClientIdentity: false, + RootWithoutClientComponent: rpath, + }), nil + default: + return nil, fmt.Errorf("unknown mode %q", mode) + } +} + +func serve(ctx context.Context, handler rpc.Handler) error { + + // copy-pasta from passive.go + ctxInterceptor := func(handlerCtx context.Context, info rpc.HandlerContextInterceptorData, handler func(ctx context.Context)) { + // the handlerCtx is clean => need to inherit logging and tracing config from job context + handlerCtx = logging.WithInherit(handlerCtx, ctx) + handlerCtx = trace.WithInherit(handlerCtx, ctx) + + handlerCtx, endTask := trace.WithTaskAndSpan(handlerCtx, "handler", fmt.Sprintf("method=%q", info.FullMethod())) + defer endTask() + handler(handlerCtx) + } + + l, err := sshdirect.ServeStdin() + if err != nil { + panic(err) + } + + srv := rpc.NewServer(handler, rpc.GetLoggersOrPanic(ctx), ctxInterceptor) + srv.Serve(ctx, transportlistenerfromnetlistener.WrapFixed(l, "fakeclientidentitymustnotbeempty")) + return nil +} + +func parseFSArg(ctx context.Context, mode, arg string) (interface{}, error) { + u, err := url.Parse(arg) + if err != nil { + return nil, err + } + if !u.IsAbs() { + return nil, fmt.Errorf("URL must be absolute, got %q", arg) + } + if u.Path[0] != '/' { + panic("impl error: expecting leading /") + } + fs := u.Path[1:] + switch u.Scheme { + case "local": + if u.Host != "" { + panic("hostname must be empty for 'local' scheme") + } + return makeEndpoint(mode, fs) + default: + cn, err := connecter(ctx, u, mode, fs) + if err != nil { + return nil, err + } + return rpc.NewClient(cn, rpc.GetLoggersOrPanic(ctx)), nil + } +} + +// func parseEndpoints(s, r string) (sender, receiver logic.Endpoint, _ error) { +// sUrl, err := url.Parse(s) +// if err != nil { +// return nil, nil, err +// } +// rUrl, err := url.Parse((r)) +// if err != nil { +// return nil, nil, err +// } +// if !sUrl.IsAbs() || !rUrl.IsAbs() { +// return nil, nil, fmt.Errorf("must have a scheme") +// } + +// if sUrl.Scheme == "local" {&& rUrl.Scheme == "local" { +// var err error +// sender, err = makeEndpoint(modeSender, sUrl.Path) +// if err == nil { +// receiver, err = makeEndpoint(modeReceiver, rUrl.Path) +// } +// return sender, receiver, err +// } else { + +// } +// } + +func main() { + + cancelSigs := make(chan os.Signal) + signal.Notify(cancelSigs, os.Interrupt, syscall.SIGTERM) + + ctx := context.Background() + trace.WithTaskFromStackUpdateCtx(&ctx) + + ctx = logging.WithLoggers(ctx, logging.SubsystemLoggersWithUniversalLogger(logger.NewStderrDebugLogger())) + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + go func() { + for { + select { + case <-cancelSigs: + cancel() + } + } + }() + + flag.Parse() + + if *flagStderrToFile != "" { + f, err := os.Create(*flagStderrToFile) + if err != nil { + panic(err) + } + syscall.Dup2(int(f.Fd()), int(os.Stderr.Fd())) + runtime.KeepAlive(f) + // enough? + } + + if *flagStdinserver != "" { + if flag.NArg() != 1 { + panic("usage: -stdinserver MODE FS") + } + h, err := makeEndpoint(*flagStdinserver, flag.Arg(0)) + if err != nil { + panic(err) + } + err = serve(ctx, h.(rpc.Handler)) + if err != nil { + panic(err) + } + return + } + + if flag.NArg() != 2 { + panic("usage: zync SENDER RECEIVER") + } + + sender, err := parseFSArg(ctx, modeSender, flag.Arg(0)) + if err != nil { + panic(err) + } + + receiver, err := parseFSArg(ctx, modeReceiver, flag.Arg(1)) + if err != nil { + panic(err) + } + + pp := logic.PlannerPolicy{ + EncryptedSend: logic.TriFromBool(true), // FIXME add flag + ReplicationConfig: pdu.ReplicationConfig{ + Protection: &pdu.ReplicationConfigProtection{ + Initial: pdu.ReplicationGuaranteeKind_GuaranteeNothing, + Incremental: pdu.ReplicationGuaranteeKind_GuaranteeNothing, + }, + }, + } + p := logic.NewPlanner(nil, nil, sender.(logic.Sender), receiver.(logic.Receiver), pp) + _, wait := replication.Do(ctx, p) + defer wait(true) + +} + +func local() { + + cancelSigs := make(chan os.Signal) + signal.Notify(cancelSigs, os.Interrupt, syscall.SIGTERM) + + ctx := context.Background() + trace.WithTaskFromStackUpdateCtx(&ctx) + + ctx = logging.WithLoggers(ctx, logging.SubsystemLoggersWithUniversalLogger(logger.NewStderrDebugLogger())) + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + go func() { + for { + select { + case <-cancelSigs: + cancel() + } + } + }() + + sc := endpoint.SenderConfig{ + FSF: onefsfilter(os.Args[1]), + Encrypt: &zfs.NilBool{B: true}, + JobID: endpoint.MustMakeJobID("sender"), + } + s := endpoint.NewSender(sc) + pretty.Println(s) + rpath, err := zfs.NewDatasetPath(os.Args[2]) + if err != nil { + panic(err) + } + r := endpoint.NewReceiver(endpoint.ReceiverConfig{ + JobID: endpoint.MustMakeJobID("receiver"), + AppendClientIdentity: false, + RootWithoutClientComponent: rpath, + }) + pretty.Println(r) + pp := logic.PlannerPolicy{ + EncryptedSend: logic.TriFromBool(sc.Encrypt.B), + ReplicationConfig: pdu.ReplicationConfig{ + Protection: &pdu.ReplicationConfigProtection{ + Initial: pdu.ReplicationGuaranteeKind_GuaranteeNothing, + Incremental: pdu.ReplicationGuaranteeKind_GuaranteeNothing, + }, + }, + } + p := logic.NewPlanner(nil, nil, s, r, pp) + report, wait := replication.Do(ctx, p) + defer wait(true) + + ticker := time.NewTicker(2 * time.Second) + for !wait(false) { + select { + case <-ticker.C: + // pretty.Println(report()) + } + } + pretty.Println(report()) +}