streamrpc now requires net.Conn => use it instead of rwc everywhere

This commit is contained in:
Christian Schwarz 2018-08-08 13:09:51 +02:00
parent 1826535e6f
commit a0b320bfeb
8 changed files with 164 additions and 123 deletions

76
cmd/adaptors.go Normal file
View File

@ -0,0 +1,76 @@
package cmd
import (
"context"
"io"
"net"
"time"
"github.com/problame/go-streamrpc"
"github.com/zrepl/zrepl/util"
)
type logNetConnConnecter struct {
streamrpc.Connecter
ReadDump, WriteDump string
}
var _ streamrpc.Connecter = logNetConnConnecter{}
func (l logNetConnConnecter) Connect(ctx context.Context) (net.Conn, error) {
conn, err := l.Connecter.Connect(ctx)
if err != nil {
return nil, err
}
return util.NewNetConnLogger(conn, l.ReadDump, l.WriteDump)
}
type logListenerFactory struct {
ListenerFactory
ReadDump, WriteDump string
}
var _ ListenerFactory = logListenerFactory{}
type logListener struct {
net.Listener
ReadDump, WriteDump string
}
var _ net.Listener = logListener{}
func (m logListenerFactory) Listen() (net.Listener, error) {
l, err := m.ListenerFactory.Listen()
if err != nil {
return nil, err
}
return logListener{l, m.ReadDump, m.WriteDump}, nil
}
func (l logListener) Accept() (net.Conn, error) {
conn, err := l.Listener.Accept()
if err != nil {
return nil, err
}
return util.NewNetConnLogger(conn, l.ReadDump, l.WriteDump)
}
type netsshAddr struct{}
func (netsshAddr) Network() string { return "netssh" }
func (netsshAddr) String() string { return "???" }
type netsshConnToNetConnAdatper struct {
io.ReadWriteCloser // works for both netssh.SSHConn and netssh.ServeConn
}
func (netsshConnToNetConnAdatper) LocalAddr() net.Addr { return netsshAddr{} }
func (netsshConnToNetConnAdatper) RemoteAddr() net.Addr { return netsshAddr{} }
func (netsshConnToNetConnAdatper) SetDeadline(t time.Time) error { return nil }
func (netsshConnToNetConnAdatper) SetReadDeadline(t time.Time) error { return nil }
func (netsshConnToNetConnAdatper) SetWriteDeadline(t time.Time) error { return nil }

View File

@ -1,7 +1,7 @@
package cmd
import (
"io"
"net"
"fmt"
"github.com/pkg/errors"
@ -43,16 +43,8 @@ type JobDebugSettings struct {
}
}
type RWCConnecter interface {
Connect() (io.ReadWriteCloser, error)
}
type AuthenticatedChannelListenerFactory interface {
Listen() (AuthenticatedChannelListener, error)
}
type AuthenticatedChannelListener interface {
Accept() (ch io.ReadWriteCloser, err error)
Close() (err error)
type ListenerFactory interface {
Listen() (net.Listener, error)
}
type SSHStdinServerConnectDescr struct {

View File

@ -2,13 +2,14 @@ package cmd
import (
"fmt"
"io"
"net"
"context"
"github.com/jinzhu/copier"
"github.com/mitchellh/mapstructure"
"github.com/pkg/errors"
"github.com/problame/go-netssh"
"github.com/problame/go-streamrpc"
"time"
)
@ -24,6 +25,8 @@ type SSHStdinserverConnecter struct {
dialTimeout time.Duration
}
var _ streamrpc.Connecter = &SSHStdinserverConnecter{}
func parseSSHStdinserverConnecter(i map[string]interface{}) (c *SSHStdinserverConnecter, err error) {
c = &SSHStdinserverConnecter{}
@ -46,21 +49,28 @@ func parseSSHStdinserverConnecter(i map[string]interface{}) (c *SSHStdinserverCo
}
func (c *SSHStdinserverConnecter) Connect() (rwc io.ReadWriteCloser, err error) {
type netsshConnToConn struct { *netssh.SSHConn }
var _ net.Conn = netsshConnToConn{}
func (netsshConnToConn) SetDeadline(dl time.Time) error { return nil }
func (netsshConnToConn) SetReadDeadline(dl time.Time) error { return nil }
func (netsshConnToConn) SetWriteDeadline(dl time.Time) error { return nil }
func (c *SSHStdinserverConnecter) Connect(dialCtx context.Context) (net.Conn, error) {
var endpoint netssh.Endpoint
if err = copier.Copy(&endpoint, c); err != nil {
if err := copier.Copy(&endpoint, c); err != nil {
return nil, errors.WithStack(err)
}
var dialCtx context.Context
dialCtx, dialCancel := context.WithTimeout(context.TODO(), c.dialTimeout) // context.TODO tied to error handling below
dialCtx, dialCancel := context.WithTimeout(dialCtx, c.dialTimeout) // context.TODO tied to error handling below
defer dialCancel()
if rwc, err = netssh.Dial(dialCtx, endpoint); err != nil {
nconn, err := netssh.Dial(dialCtx, endpoint)
if err != nil {
if err == context.DeadlineExceeded {
err = errors.Errorf("dial_timeout of %s exceeded", c.dialTimeout)
}
err = errors.WithStack(err)
return
return nil, err
}
return
return netsshConnToConn{nconn}, nil
}

View File

@ -16,7 +16,7 @@ import (
type PullJob struct {
Name string
Connect RWCConnecter
Connect streamrpc.Connecter
Interval time.Duration
Mapping *DatasetMapFilter
// constructed from mapping during parsing
@ -90,6 +90,15 @@ func parsePullJob(c JobParsingContext, name string, i map[string]interface{}) (j
return
}
if j.Debug.Conn.ReadDump != "" || j.Debug.Conn.WriteDump != "" {
logConnecter := logNetConnConnecter{
Connecter: j.Connect,
ReadDump: j.Debug.Conn.ReadDump,
WriteDump: j.Debug.Conn.WriteDump,
}
j.Connect = logConnecter
}
return
}
@ -132,56 +141,12 @@ var STREAMRPC_CONFIG = &streamrpc.ConnConfig{ // FIXME oversight and configurabi
RxStructuredMaxLen: 4096 * 4096,
RxStreamMaxChunkSize: 4096 * 4096,
TxChunkSize: 4096 * 4096,
}
type streamrpcRWCToNetConnAdatper struct {
io.ReadWriteCloser
}
func (streamrpcRWCToNetConnAdatper) LocalAddr() net.Addr {
panic("implement me")
}
func (streamrpcRWCToNetConnAdatper) RemoteAddr() net.Addr {
panic("implement me")
}
func (streamrpcRWCToNetConnAdatper) SetDeadline(t time.Time) error {
panic("implement me")
}
func (streamrpcRWCToNetConnAdatper) SetReadDeadline(t time.Time) error {
panic("implement me")
}
func (streamrpcRWCToNetConnAdatper) SetWriteDeadline(t time.Time) error {
panic("implement me")
}
type streamrpcRWCConnecterToNetConnAdapter struct {
RWCConnecter
ReadDump, WriteDump string
}
func (s streamrpcRWCConnecterToNetConnAdapter) Connect(ctx context.Context) (net.Conn, error) {
rwc, err := s.RWCConnecter.Connect()
if err != nil {
return nil, err
}
rwc, err = util.NewReadWriteCloserLogger(rwc, s.ReadDump, s.WriteDump)
if err != nil {
rwc.Close()
return nil, err
}
return streamrpcRWCToNetConnAdatper{rwc}, nil
}
type tcpConnecter struct {
d net.Dialer
}
func (t *tcpConnecter) Connect(ctx context.Context) (net.Conn, error) {
return t.d.DialContext(ctx, "tcp", "192.168.122.128:8888")
RxTimeout: streamrpc.Timeout{
Progress: 10*time.Second,
},
TxTimeout: streamrpc.Timeout{
Progress: 10*time.Second,
},
}
func (j *PullJob) doRun(ctx context.Context) {
@ -189,25 +154,12 @@ func (j *PullJob) doRun(ctx context.Context) {
j.task.Enter("run")
defer j.task.Finish()
//connecter := streamrpcRWCConnecterToNetConnAdapter{
// RWCConnecter: j.Connect,
// ReadDump: j.Debug.Conn.ReadDump,
// WriteDump: j.Debug.Conn.WriteDump,
//}
// FIXME
connecter := &tcpConnecter{net.Dialer{
Timeout: 2*time.Second,
}}
clientConf := &streamrpc.ClientConfig{
MaxConnectAttempts: 5, // FIXME
ReconnectBackoffBase: 1*time.Second,
ReconnectBackoffFactor: 2,
ConnConfig: STREAMRPC_CONFIG,
}
client, err := streamrpc.NewClient(connecter, clientConf)
client, err := streamrpc.NewClient(j.Connect, clientConf)
defer client.Close()
j.task.Enter("pull")

View File

@ -2,19 +2,17 @@ package cmd
import (
"context"
"io"
"time"
"github.com/mitchellh/mapstructure"
"github.com/pkg/errors"
"github.com/zrepl/zrepl/util"
"github.com/problame/go-streamrpc"
"net"
)
type SourceJob struct {
Name string
Serve AuthenticatedChannelListenerFactory
Serve ListenerFactory
Filesystems *DatasetMapFilter
SnapshotPrefix string
Interval time.Duration
@ -70,6 +68,15 @@ func parseSourceJob(c JobParsingContext, name string, i map[string]interface{})
return
}
if j.Debug.Conn.ReadDump != "" || j.Debug.Conn.WriteDump != "" {
logServe := logListenerFactory{
ListenerFactory: j.Serve,
ReadDump: j.Debug.Conn.ReadDump,
WriteDump: j.Debug.Conn.WriteDump,
}
j.Serve = logServe
}
return
}
@ -139,19 +146,17 @@ func (j *SourceJob) Pruner(task *Task, side PrunePolicySide, dryRun bool) (p Pru
func (j *SourceJob) serve(ctx context.Context, task *Task) {
//listener, err := j.Serve.Listen()
// FIXME
listener, err := net.Listen("tcp", "192.168.122.128:8888")
listener, err := j.Serve.Listen()
if err != nil {
task.Log().WithError(err).Error("error listening")
return
}
type rwcChanMsg struct {
rwc io.ReadWriteCloser
err error
type connChanMsg struct {
conn net.Conn
err error
}
rwcChan := make(chan rwcChanMsg)
connChan := make(chan connChanMsg)
// Serve connections until interrupted or error
outer:
@ -160,23 +165,23 @@ outer:
go func() {
rwc, err := listener.Accept()
if err != nil {
rwcChan <- rwcChanMsg{rwc, err}
close(rwcChan)
connChan <- connChanMsg{rwc, err}
close(connChan)
return
}
rwcChan <- rwcChanMsg{rwc, err}
connChan <- connChanMsg{rwc, err}
}()
select {
case rwcMsg := <-rwcChan:
case rwcMsg := <-connChan:
if rwcMsg.err != nil {
task.Log().WithError(err).Error("error accepting connection")
break outer
}
j.handleConnection(rwcMsg.rwc, task)
j.handleConnection(rwcMsg.conn, task)
case <-ctx.Done():
task.Log().WithError(ctx.Err()).Info("context")
@ -197,17 +202,13 @@ outer:
}
func (j *SourceJob) handleConnection(rwc io.ReadWriteCloser, task *Task) {
func (j *SourceJob) handleConnection(conn net.Conn, task *Task) {
task.Enter("handle_connection")
defer task.Finish()
task.Log().Info("handling client connection")
rwc, err := util.NewReadWriteCloserLogger(rwc, j.Debug.Conn.ReadDump, j.Debug.Conn.WriteDump)
if err != nil {
panic(err)
}
senderEP := NewSenderEndpoint(j.Filesystems, NewPrefixFilter(j.SnapshotPrefix))

View File

@ -11,6 +11,7 @@ import (
"regexp"
"strconv"
"time"
"github.com/problame/go-streamrpc"
)
var ConfigFileDefaultLocations []string = []string{
@ -208,7 +209,7 @@ func parseJob(c JobParsingContext, i map[string]interface{}) (j Job, err error)
}
func parseConnect(i map[string]interface{}) (c RWCConnecter, err error) {
func parseConnect(i map[string]interface{}) (c streamrpc.Connecter, err error) {
t, err := extractStringField(i, "type", true)
if err != nil {
@ -266,7 +267,7 @@ func parsePrunePolicy(v map[string]interface{}, willSeeBookmarks bool) (p PruneP
}
}
func parseAuthenticatedChannelListenerFactory(c JobParsingContext, v map[string]interface{}) (p AuthenticatedChannelListenerFactory, err error) {
func parseAuthenticatedChannelListenerFactory(c JobParsingContext, v map[string]interface{}) (p ListenerFactory, err error) {
t, err := extractStringField(v, "type", true)
if err != nil {

View File

@ -4,7 +4,7 @@ import (
"github.com/mitchellh/mapstructure"
"github.com/pkg/errors"
"github.com/problame/go-netssh"
"io"
"net"
"path"
)
@ -30,9 +30,9 @@ func parseStdinserverListenerFactory(c JobParsingContext, i map[string]interface
return
}
func (f *StdinserverListenerFactory) Listen() (al AuthenticatedChannelListener, err error) {
func (f *StdinserverListenerFactory) Listen() (net.Listener, error) {
if err = PreparePrivateSockpath(f.sockpath); err != nil {
if err := PreparePrivateSockpath(f.sockpath); err != nil {
return nil, err
}
@ -47,8 +47,16 @@ type StdinserverListener struct {
l *netssh.Listener
}
func (l StdinserverListener) Accept() (ch io.ReadWriteCloser, err error) {
return l.l.Accept()
func (l StdinserverListener) Addr() net.Addr {
return netsshAddr{}
}
func (l StdinserverListener) Accept() (net.Conn, error) {
c, err := l.l.Accept()
if err != nil {
return nil, err
}
return netsshConnToNetConnAdatper{c}, nil
}
func (l StdinserverListener) Close() (err error) {

View File

@ -2,18 +2,19 @@ package util
import (
"io"
"net"
"os"
)
type ReadWriteCloserLogger struct {
RWC io.ReadWriteCloser
type NetConnLogger struct {
net.Conn
ReadFile *os.File
WriteFile *os.File
}
func NewReadWriteCloserLogger(rwc io.ReadWriteCloser, readlog, writelog string) (l *ReadWriteCloserLogger, err error) {
l = &ReadWriteCloserLogger{
RWC: rwc,
func NewNetConnLogger(conn net.Conn, readlog, writelog string) (l *NetConnLogger, err error) {
l = &NetConnLogger{
Conn: conn,
}
flags := os.O_CREATE | os.O_WRONLY
if readlog != "" {
@ -29,8 +30,8 @@ func NewReadWriteCloserLogger(rwc io.ReadWriteCloser, readlog, writelog string)
return
}
func (c *ReadWriteCloserLogger) Read(buf []byte) (n int, err error) {
n, err = c.RWC.Read(buf)
func (c *NetConnLogger) Read(buf []byte) (n int, err error) {
n, err = c.Conn.Read(buf)
if c.WriteFile != nil {
if _, writeErr := c.ReadFile.Write(buf[0:n]); writeErr != nil {
panic(writeErr)
@ -39,8 +40,8 @@ func (c *ReadWriteCloserLogger) Read(buf []byte) (n int, err error) {
return
}
func (c *ReadWriteCloserLogger) Write(buf []byte) (n int, err error) {
n, err = c.RWC.Write(buf)
func (c *NetConnLogger) Write(buf []byte) (n int, err error) {
n, err = c.Conn.Write(buf)
if c.ReadFile != nil {
if _, writeErr := c.WriteFile.Write(buf[0:n]); writeErr != nil {
panic(writeErr)
@ -48,8 +49,8 @@ func (c *ReadWriteCloserLogger) Write(buf []byte) (n int, err error) {
}
return
}
func (c *ReadWriteCloserLogger) Close() (err error) {
err = c.RWC.Close()
func (c *NetConnLogger) Close() (err error) {
err = c.Conn.Close()
if err != nil {
return
}