From a0b320bfeba81bbe02d31e50d971c6e182786183 Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Wed, 8 Aug 2018 13:09:51 +0200 Subject: [PATCH] streamrpc now requires net.Conn => use it instead of rwc everywhere --- cmd/adaptors.go | 76 ++++++++++++++++++++++++++++++ cmd/config.go | 14 ++---- cmd/config_connect.go | 28 +++++++---- cmd/config_job_pull.go | 82 +++++++-------------------------- cmd/config_job_source.go | 41 +++++++++-------- cmd/config_parse.go | 5 +- cmd/config_serve_stdinserver.go | 18 ++++++-- util/io.go | 23 ++++----- 8 files changed, 164 insertions(+), 123 deletions(-) create mode 100644 cmd/adaptors.go diff --git a/cmd/adaptors.go b/cmd/adaptors.go new file mode 100644 index 0000000..6423518 --- /dev/null +++ b/cmd/adaptors.go @@ -0,0 +1,76 @@ +package cmd + +import ( + "context" + "io" + "net" + "time" + + "github.com/problame/go-streamrpc" + "github.com/zrepl/zrepl/util" +) + +type logNetConnConnecter struct { + streamrpc.Connecter + ReadDump, WriteDump string +} + +var _ streamrpc.Connecter = logNetConnConnecter{} + +func (l logNetConnConnecter) Connect(ctx context.Context) (net.Conn, error) { + conn, err := l.Connecter.Connect(ctx) + if err != nil { + return nil, err + } + return util.NewNetConnLogger(conn, l.ReadDump, l.WriteDump) +} + +type logListenerFactory struct { + ListenerFactory + ReadDump, WriteDump string +} + +var _ ListenerFactory = logListenerFactory{} + +type logListener struct { + net.Listener + ReadDump, WriteDump string +} + +var _ net.Listener = logListener{} + +func (m logListenerFactory) Listen() (net.Listener, error) { + l, err := m.ListenerFactory.Listen() + if err != nil { + return nil, err + } + return logListener{l, m.ReadDump, m.WriteDump}, nil +} + +func (l logListener) Accept() (net.Conn, error) { + conn, err := l.Listener.Accept() + if err != nil { + return nil, err + } + return util.NewNetConnLogger(conn, l.ReadDump, l.WriteDump) +} + + +type netsshAddr struct{} + +func (netsshAddr) Network() string { return "netssh" } +func (netsshAddr) String() string { return "???" } + +type netsshConnToNetConnAdatper struct { + io.ReadWriteCloser // works for both netssh.SSHConn and netssh.ServeConn +} + +func (netsshConnToNetConnAdatper) LocalAddr() net.Addr { return netsshAddr{} } + +func (netsshConnToNetConnAdatper) RemoteAddr() net.Addr { return netsshAddr{} } + +func (netsshConnToNetConnAdatper) SetDeadline(t time.Time) error { return nil } + +func (netsshConnToNetConnAdatper) SetReadDeadline(t time.Time) error { return nil } + +func (netsshConnToNetConnAdatper) SetWriteDeadline(t time.Time) error { return nil } diff --git a/cmd/config.go b/cmd/config.go index 38a95be..1361d75 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -1,7 +1,7 @@ package cmd import ( - "io" + "net" "fmt" "github.com/pkg/errors" @@ -43,16 +43,8 @@ type JobDebugSettings struct { } } -type RWCConnecter interface { - Connect() (io.ReadWriteCloser, error) -} -type AuthenticatedChannelListenerFactory interface { - Listen() (AuthenticatedChannelListener, error) -} - -type AuthenticatedChannelListener interface { - Accept() (ch io.ReadWriteCloser, err error) - Close() (err error) +type ListenerFactory interface { + Listen() (net.Listener, error) } type SSHStdinServerConnectDescr struct { diff --git a/cmd/config_connect.go b/cmd/config_connect.go index b6d8435..d94f065 100644 --- a/cmd/config_connect.go +++ b/cmd/config_connect.go @@ -2,13 +2,14 @@ package cmd import ( "fmt" - "io" + "net" "context" "github.com/jinzhu/copier" "github.com/mitchellh/mapstructure" "github.com/pkg/errors" "github.com/problame/go-netssh" + "github.com/problame/go-streamrpc" "time" ) @@ -24,6 +25,8 @@ type SSHStdinserverConnecter struct { dialTimeout time.Duration } +var _ streamrpc.Connecter = &SSHStdinserverConnecter{} + func parseSSHStdinserverConnecter(i map[string]interface{}) (c *SSHStdinserverConnecter, err error) { c = &SSHStdinserverConnecter{} @@ -46,21 +49,28 @@ func parseSSHStdinserverConnecter(i map[string]interface{}) (c *SSHStdinserverCo } -func (c *SSHStdinserverConnecter) Connect() (rwc io.ReadWriteCloser, err error) { +type netsshConnToConn struct { *netssh.SSHConn } + +var _ net.Conn = netsshConnToConn{} + +func (netsshConnToConn) SetDeadline(dl time.Time) error { return nil } +func (netsshConnToConn) SetReadDeadline(dl time.Time) error { return nil } +func (netsshConnToConn) SetWriteDeadline(dl time.Time) error { return nil } + +func (c *SSHStdinserverConnecter) Connect(dialCtx context.Context) (net.Conn, error) { var endpoint netssh.Endpoint - if err = copier.Copy(&endpoint, c); err != nil { + if err := copier.Copy(&endpoint, c); err != nil { return nil, errors.WithStack(err) } - var dialCtx context.Context - dialCtx, dialCancel := context.WithTimeout(context.TODO(), c.dialTimeout) // context.TODO tied to error handling below + dialCtx, dialCancel := context.WithTimeout(dialCtx, c.dialTimeout) // context.TODO tied to error handling below defer dialCancel() - if rwc, err = netssh.Dial(dialCtx, endpoint); err != nil { + nconn, err := netssh.Dial(dialCtx, endpoint) + if err != nil { if err == context.DeadlineExceeded { err = errors.Errorf("dial_timeout of %s exceeded", c.dialTimeout) } - err = errors.WithStack(err) - return + return nil, err } - return + return netsshConnToConn{nconn}, nil } diff --git a/cmd/config_job_pull.go b/cmd/config_job_pull.go index cb56ee5..5b4c112 100644 --- a/cmd/config_job_pull.go +++ b/cmd/config_job_pull.go @@ -16,7 +16,7 @@ import ( type PullJob struct { Name string - Connect RWCConnecter + Connect streamrpc.Connecter Interval time.Duration Mapping *DatasetMapFilter // constructed from mapping during parsing @@ -90,6 +90,15 @@ func parsePullJob(c JobParsingContext, name string, i map[string]interface{}) (j return } + if j.Debug.Conn.ReadDump != "" || j.Debug.Conn.WriteDump != "" { + logConnecter := logNetConnConnecter{ + Connecter: j.Connect, + ReadDump: j.Debug.Conn.ReadDump, + WriteDump: j.Debug.Conn.WriteDump, + } + j.Connect = logConnecter + } + return } @@ -132,56 +141,12 @@ var STREAMRPC_CONFIG = &streamrpc.ConnConfig{ // FIXME oversight and configurabi RxStructuredMaxLen: 4096 * 4096, RxStreamMaxChunkSize: 4096 * 4096, TxChunkSize: 4096 * 4096, -} - -type streamrpcRWCToNetConnAdatper struct { - io.ReadWriteCloser -} - -func (streamrpcRWCToNetConnAdatper) LocalAddr() net.Addr { - panic("implement me") -} - -func (streamrpcRWCToNetConnAdatper) RemoteAddr() net.Addr { - panic("implement me") -} - -func (streamrpcRWCToNetConnAdatper) SetDeadline(t time.Time) error { - panic("implement me") -} - -func (streamrpcRWCToNetConnAdatper) SetReadDeadline(t time.Time) error { - panic("implement me") -} - -func (streamrpcRWCToNetConnAdatper) SetWriteDeadline(t time.Time) error { - panic("implement me") -} - -type streamrpcRWCConnecterToNetConnAdapter struct { - RWCConnecter - ReadDump, WriteDump string -} - -func (s streamrpcRWCConnecterToNetConnAdapter) Connect(ctx context.Context) (net.Conn, error) { - rwc, err := s.RWCConnecter.Connect() - if err != nil { - return nil, err - } - rwc, err = util.NewReadWriteCloserLogger(rwc, s.ReadDump, s.WriteDump) - if err != nil { - rwc.Close() - return nil, err - } - return streamrpcRWCToNetConnAdatper{rwc}, nil -} - -type tcpConnecter struct { - d net.Dialer -} - -func (t *tcpConnecter) Connect(ctx context.Context) (net.Conn, error) { - return t.d.DialContext(ctx, "tcp", "192.168.122.128:8888") + RxTimeout: streamrpc.Timeout{ + Progress: 10*time.Second, + }, + TxTimeout: streamrpc.Timeout{ + Progress: 10*time.Second, + }, } func (j *PullJob) doRun(ctx context.Context) { @@ -189,25 +154,12 @@ func (j *PullJob) doRun(ctx context.Context) { j.task.Enter("run") defer j.task.Finish() - //connecter := streamrpcRWCConnecterToNetConnAdapter{ - // RWCConnecter: j.Connect, - // ReadDump: j.Debug.Conn.ReadDump, - // WriteDump: j.Debug.Conn.WriteDump, - //} - // FIXME - connecter := &tcpConnecter{net.Dialer{ - Timeout: 2*time.Second, - }} - clientConf := &streamrpc.ClientConfig{ - MaxConnectAttempts: 5, // FIXME - ReconnectBackoffBase: 1*time.Second, - ReconnectBackoffFactor: 2, ConnConfig: STREAMRPC_CONFIG, } - client, err := streamrpc.NewClient(connecter, clientConf) + client, err := streamrpc.NewClient(j.Connect, clientConf) defer client.Close() j.task.Enter("pull") diff --git a/cmd/config_job_source.go b/cmd/config_job_source.go index 6c98cab..d212fa5 100644 --- a/cmd/config_job_source.go +++ b/cmd/config_job_source.go @@ -2,19 +2,17 @@ package cmd import ( "context" - "io" "time" "github.com/mitchellh/mapstructure" "github.com/pkg/errors" - "github.com/zrepl/zrepl/util" "github.com/problame/go-streamrpc" "net" ) type SourceJob struct { Name string - Serve AuthenticatedChannelListenerFactory + Serve ListenerFactory Filesystems *DatasetMapFilter SnapshotPrefix string Interval time.Duration @@ -70,6 +68,15 @@ func parseSourceJob(c JobParsingContext, name string, i map[string]interface{}) return } + if j.Debug.Conn.ReadDump != "" || j.Debug.Conn.WriteDump != "" { + logServe := logListenerFactory{ + ListenerFactory: j.Serve, + ReadDump: j.Debug.Conn.ReadDump, + WriteDump: j.Debug.Conn.WriteDump, + } + j.Serve = logServe + } + return } @@ -139,19 +146,17 @@ func (j *SourceJob) Pruner(task *Task, side PrunePolicySide, dryRun bool) (p Pru func (j *SourceJob) serve(ctx context.Context, task *Task) { - //listener, err := j.Serve.Listen() - // FIXME - listener, err := net.Listen("tcp", "192.168.122.128:8888") + listener, err := j.Serve.Listen() if err != nil { task.Log().WithError(err).Error("error listening") return } - type rwcChanMsg struct { - rwc io.ReadWriteCloser - err error + type connChanMsg struct { + conn net.Conn + err error } - rwcChan := make(chan rwcChanMsg) + connChan := make(chan connChanMsg) // Serve connections until interrupted or error outer: @@ -160,23 +165,23 @@ outer: go func() { rwc, err := listener.Accept() if err != nil { - rwcChan <- rwcChanMsg{rwc, err} - close(rwcChan) + connChan <- connChanMsg{rwc, err} + close(connChan) return } - rwcChan <- rwcChanMsg{rwc, err} + connChan <- connChanMsg{rwc, err} }() select { - case rwcMsg := <-rwcChan: + case rwcMsg := <-connChan: if rwcMsg.err != nil { task.Log().WithError(err).Error("error accepting connection") break outer } - j.handleConnection(rwcMsg.rwc, task) + j.handleConnection(rwcMsg.conn, task) case <-ctx.Done(): task.Log().WithError(ctx.Err()).Info("context") @@ -197,17 +202,13 @@ outer: } -func (j *SourceJob) handleConnection(rwc io.ReadWriteCloser, task *Task) { +func (j *SourceJob) handleConnection(conn net.Conn, task *Task) { task.Enter("handle_connection") defer task.Finish() task.Log().Info("handling client connection") - rwc, err := util.NewReadWriteCloserLogger(rwc, j.Debug.Conn.ReadDump, j.Debug.Conn.WriteDump) - if err != nil { - panic(err) - } senderEP := NewSenderEndpoint(j.Filesystems, NewPrefixFilter(j.SnapshotPrefix)) diff --git a/cmd/config_parse.go b/cmd/config_parse.go index cd733eb..6f6784c 100644 --- a/cmd/config_parse.go +++ b/cmd/config_parse.go @@ -11,6 +11,7 @@ import ( "regexp" "strconv" "time" + "github.com/problame/go-streamrpc" ) var ConfigFileDefaultLocations []string = []string{ @@ -208,7 +209,7 @@ func parseJob(c JobParsingContext, i map[string]interface{}) (j Job, err error) } -func parseConnect(i map[string]interface{}) (c RWCConnecter, err error) { +func parseConnect(i map[string]interface{}) (c streamrpc.Connecter, err error) { t, err := extractStringField(i, "type", true) if err != nil { @@ -266,7 +267,7 @@ func parsePrunePolicy(v map[string]interface{}, willSeeBookmarks bool) (p PruneP } } -func parseAuthenticatedChannelListenerFactory(c JobParsingContext, v map[string]interface{}) (p AuthenticatedChannelListenerFactory, err error) { +func parseAuthenticatedChannelListenerFactory(c JobParsingContext, v map[string]interface{}) (p ListenerFactory, err error) { t, err := extractStringField(v, "type", true) if err != nil { diff --git a/cmd/config_serve_stdinserver.go b/cmd/config_serve_stdinserver.go index ed7b68d..2380cc8 100644 --- a/cmd/config_serve_stdinserver.go +++ b/cmd/config_serve_stdinserver.go @@ -4,7 +4,7 @@ import ( "github.com/mitchellh/mapstructure" "github.com/pkg/errors" "github.com/problame/go-netssh" - "io" + "net" "path" ) @@ -30,9 +30,9 @@ func parseStdinserverListenerFactory(c JobParsingContext, i map[string]interface return } -func (f *StdinserverListenerFactory) Listen() (al AuthenticatedChannelListener, err error) { +func (f *StdinserverListenerFactory) Listen() (net.Listener, error) { - if err = PreparePrivateSockpath(f.sockpath); err != nil { + if err := PreparePrivateSockpath(f.sockpath); err != nil { return nil, err } @@ -47,8 +47,16 @@ type StdinserverListener struct { l *netssh.Listener } -func (l StdinserverListener) Accept() (ch io.ReadWriteCloser, err error) { - return l.l.Accept() +func (l StdinserverListener) Addr() net.Addr { + return netsshAddr{} +} + +func (l StdinserverListener) Accept() (net.Conn, error) { + c, err := l.l.Accept() + if err != nil { + return nil, err + } + return netsshConnToNetConnAdatper{c}, nil } func (l StdinserverListener) Close() (err error) { diff --git a/util/io.go b/util/io.go index 68ae286..857df26 100644 --- a/util/io.go +++ b/util/io.go @@ -2,18 +2,19 @@ package util import ( "io" + "net" "os" ) -type ReadWriteCloserLogger struct { - RWC io.ReadWriteCloser +type NetConnLogger struct { + net.Conn ReadFile *os.File WriteFile *os.File } -func NewReadWriteCloserLogger(rwc io.ReadWriteCloser, readlog, writelog string) (l *ReadWriteCloserLogger, err error) { - l = &ReadWriteCloserLogger{ - RWC: rwc, +func NewNetConnLogger(conn net.Conn, readlog, writelog string) (l *NetConnLogger, err error) { + l = &NetConnLogger{ + Conn: conn, } flags := os.O_CREATE | os.O_WRONLY if readlog != "" { @@ -29,8 +30,8 @@ func NewReadWriteCloserLogger(rwc io.ReadWriteCloser, readlog, writelog string) return } -func (c *ReadWriteCloserLogger) Read(buf []byte) (n int, err error) { - n, err = c.RWC.Read(buf) +func (c *NetConnLogger) Read(buf []byte) (n int, err error) { + n, err = c.Conn.Read(buf) if c.WriteFile != nil { if _, writeErr := c.ReadFile.Write(buf[0:n]); writeErr != nil { panic(writeErr) @@ -39,8 +40,8 @@ func (c *ReadWriteCloserLogger) Read(buf []byte) (n int, err error) { return } -func (c *ReadWriteCloserLogger) Write(buf []byte) (n int, err error) { - n, err = c.RWC.Write(buf) +func (c *NetConnLogger) Write(buf []byte) (n int, err error) { + n, err = c.Conn.Write(buf) if c.ReadFile != nil { if _, writeErr := c.WriteFile.Write(buf[0:n]); writeErr != nil { panic(writeErr) @@ -48,8 +49,8 @@ func (c *ReadWriteCloserLogger) Write(buf []byte) (n int, err error) { } return } -func (c *ReadWriteCloserLogger) Close() (err error) { - err = c.RWC.Close() +func (c *NetConnLogger) Close() (err error) { + err = c.Conn.Close() if err != nil { return }