streamrpc now requires net.Conn => use it instead of rwc everywhere

This commit is contained in:
Christian Schwarz 2018-08-08 13:09:51 +02:00
parent 1826535e6f
commit a0b320bfeb
8 changed files with 164 additions and 123 deletions

76
cmd/adaptors.go Normal file
View File

@ -0,0 +1,76 @@
package cmd
import (
"context"
"io"
"net"
"time"
"github.com/problame/go-streamrpc"
"github.com/zrepl/zrepl/util"
)
type logNetConnConnecter struct {
streamrpc.Connecter
ReadDump, WriteDump string
}
var _ streamrpc.Connecter = logNetConnConnecter{}
func (l logNetConnConnecter) Connect(ctx context.Context) (net.Conn, error) {
conn, err := l.Connecter.Connect(ctx)
if err != nil {
return nil, err
}
return util.NewNetConnLogger(conn, l.ReadDump, l.WriteDump)
}
type logListenerFactory struct {
ListenerFactory
ReadDump, WriteDump string
}
var _ ListenerFactory = logListenerFactory{}
type logListener struct {
net.Listener
ReadDump, WriteDump string
}
var _ net.Listener = logListener{}
func (m logListenerFactory) Listen() (net.Listener, error) {
l, err := m.ListenerFactory.Listen()
if err != nil {
return nil, err
}
return logListener{l, m.ReadDump, m.WriteDump}, nil
}
func (l logListener) Accept() (net.Conn, error) {
conn, err := l.Listener.Accept()
if err != nil {
return nil, err
}
return util.NewNetConnLogger(conn, l.ReadDump, l.WriteDump)
}
type netsshAddr struct{}
func (netsshAddr) Network() string { return "netssh" }
func (netsshAddr) String() string { return "???" }
type netsshConnToNetConnAdatper struct {
io.ReadWriteCloser // works for both netssh.SSHConn and netssh.ServeConn
}
func (netsshConnToNetConnAdatper) LocalAddr() net.Addr { return netsshAddr{} }
func (netsshConnToNetConnAdatper) RemoteAddr() net.Addr { return netsshAddr{} }
func (netsshConnToNetConnAdatper) SetDeadline(t time.Time) error { return nil }
func (netsshConnToNetConnAdatper) SetReadDeadline(t time.Time) error { return nil }
func (netsshConnToNetConnAdatper) SetWriteDeadline(t time.Time) error { return nil }

View File

@ -1,7 +1,7 @@
package cmd package cmd
import ( import (
"io" "net"
"fmt" "fmt"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -43,16 +43,8 @@ type JobDebugSettings struct {
} }
} }
type RWCConnecter interface { type ListenerFactory interface {
Connect() (io.ReadWriteCloser, error) Listen() (net.Listener, error)
}
type AuthenticatedChannelListenerFactory interface {
Listen() (AuthenticatedChannelListener, error)
}
type AuthenticatedChannelListener interface {
Accept() (ch io.ReadWriteCloser, err error)
Close() (err error)
} }
type SSHStdinServerConnectDescr struct { type SSHStdinServerConnectDescr struct {

View File

@ -2,13 +2,14 @@ package cmd
import ( import (
"fmt" "fmt"
"io" "net"
"context" "context"
"github.com/jinzhu/copier" "github.com/jinzhu/copier"
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/problame/go-netssh" "github.com/problame/go-netssh"
"github.com/problame/go-streamrpc"
"time" "time"
) )
@ -24,6 +25,8 @@ type SSHStdinserverConnecter struct {
dialTimeout time.Duration dialTimeout time.Duration
} }
var _ streamrpc.Connecter = &SSHStdinserverConnecter{}
func parseSSHStdinserverConnecter(i map[string]interface{}) (c *SSHStdinserverConnecter, err error) { func parseSSHStdinserverConnecter(i map[string]interface{}) (c *SSHStdinserverConnecter, err error) {
c = &SSHStdinserverConnecter{} c = &SSHStdinserverConnecter{}
@ -46,21 +49,28 @@ func parseSSHStdinserverConnecter(i map[string]interface{}) (c *SSHStdinserverCo
} }
func (c *SSHStdinserverConnecter) Connect() (rwc io.ReadWriteCloser, err error) { type netsshConnToConn struct { *netssh.SSHConn }
var _ net.Conn = netsshConnToConn{}
func (netsshConnToConn) SetDeadline(dl time.Time) error { return nil }
func (netsshConnToConn) SetReadDeadline(dl time.Time) error { return nil }
func (netsshConnToConn) SetWriteDeadline(dl time.Time) error { return nil }
func (c *SSHStdinserverConnecter) Connect(dialCtx context.Context) (net.Conn, error) {
var endpoint netssh.Endpoint var endpoint netssh.Endpoint
if err = copier.Copy(&endpoint, c); err != nil { if err := copier.Copy(&endpoint, c); err != nil {
return nil, errors.WithStack(err) return nil, errors.WithStack(err)
} }
var dialCtx context.Context dialCtx, dialCancel := context.WithTimeout(dialCtx, c.dialTimeout) // context.TODO tied to error handling below
dialCtx, dialCancel := context.WithTimeout(context.TODO(), c.dialTimeout) // context.TODO tied to error handling below
defer dialCancel() defer dialCancel()
if rwc, err = netssh.Dial(dialCtx, endpoint); err != nil { nconn, err := netssh.Dial(dialCtx, endpoint)
if err != nil {
if err == context.DeadlineExceeded { if err == context.DeadlineExceeded {
err = errors.Errorf("dial_timeout of %s exceeded", c.dialTimeout) err = errors.Errorf("dial_timeout of %s exceeded", c.dialTimeout)
} }
err = errors.WithStack(err) return nil, err
return
} }
return return netsshConnToConn{nconn}, nil
} }

View File

@ -16,7 +16,7 @@ import (
type PullJob struct { type PullJob struct {
Name string Name string
Connect RWCConnecter Connect streamrpc.Connecter
Interval time.Duration Interval time.Duration
Mapping *DatasetMapFilter Mapping *DatasetMapFilter
// constructed from mapping during parsing // constructed from mapping during parsing
@ -90,6 +90,15 @@ func parsePullJob(c JobParsingContext, name string, i map[string]interface{}) (j
return return
} }
if j.Debug.Conn.ReadDump != "" || j.Debug.Conn.WriteDump != "" {
logConnecter := logNetConnConnecter{
Connecter: j.Connect,
ReadDump: j.Debug.Conn.ReadDump,
WriteDump: j.Debug.Conn.WriteDump,
}
j.Connect = logConnecter
}
return return
} }
@ -132,56 +141,12 @@ var STREAMRPC_CONFIG = &streamrpc.ConnConfig{ // FIXME oversight and configurabi
RxStructuredMaxLen: 4096 * 4096, RxStructuredMaxLen: 4096 * 4096,
RxStreamMaxChunkSize: 4096 * 4096, RxStreamMaxChunkSize: 4096 * 4096,
TxChunkSize: 4096 * 4096, TxChunkSize: 4096 * 4096,
} RxTimeout: streamrpc.Timeout{
Progress: 10*time.Second,
type streamrpcRWCToNetConnAdatper struct { },
io.ReadWriteCloser TxTimeout: streamrpc.Timeout{
} Progress: 10*time.Second,
},
func (streamrpcRWCToNetConnAdatper) LocalAddr() net.Addr {
panic("implement me")
}
func (streamrpcRWCToNetConnAdatper) RemoteAddr() net.Addr {
panic("implement me")
}
func (streamrpcRWCToNetConnAdatper) SetDeadline(t time.Time) error {
panic("implement me")
}
func (streamrpcRWCToNetConnAdatper) SetReadDeadline(t time.Time) error {
panic("implement me")
}
func (streamrpcRWCToNetConnAdatper) SetWriteDeadline(t time.Time) error {
panic("implement me")
}
type streamrpcRWCConnecterToNetConnAdapter struct {
RWCConnecter
ReadDump, WriteDump string
}
func (s streamrpcRWCConnecterToNetConnAdapter) Connect(ctx context.Context) (net.Conn, error) {
rwc, err := s.RWCConnecter.Connect()
if err != nil {
return nil, err
}
rwc, err = util.NewReadWriteCloserLogger(rwc, s.ReadDump, s.WriteDump)
if err != nil {
rwc.Close()
return nil, err
}
return streamrpcRWCToNetConnAdatper{rwc}, nil
}
type tcpConnecter struct {
d net.Dialer
}
func (t *tcpConnecter) Connect(ctx context.Context) (net.Conn, error) {
return t.d.DialContext(ctx, "tcp", "192.168.122.128:8888")
} }
func (j *PullJob) doRun(ctx context.Context) { func (j *PullJob) doRun(ctx context.Context) {
@ -189,25 +154,12 @@ func (j *PullJob) doRun(ctx context.Context) {
j.task.Enter("run") j.task.Enter("run")
defer j.task.Finish() defer j.task.Finish()
//connecter := streamrpcRWCConnecterToNetConnAdapter{
// RWCConnecter: j.Connect,
// ReadDump: j.Debug.Conn.ReadDump,
// WriteDump: j.Debug.Conn.WriteDump,
//}
// FIXME // FIXME
connecter := &tcpConnecter{net.Dialer{
Timeout: 2*time.Second,
}}
clientConf := &streamrpc.ClientConfig{ clientConf := &streamrpc.ClientConfig{
MaxConnectAttempts: 5, // FIXME
ReconnectBackoffBase: 1*time.Second,
ReconnectBackoffFactor: 2,
ConnConfig: STREAMRPC_CONFIG, ConnConfig: STREAMRPC_CONFIG,
} }
client, err := streamrpc.NewClient(connecter, clientConf) client, err := streamrpc.NewClient(j.Connect, clientConf)
defer client.Close() defer client.Close()
j.task.Enter("pull") j.task.Enter("pull")

View File

@ -2,19 +2,17 @@ package cmd
import ( import (
"context" "context"
"io"
"time" "time"
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/zrepl/zrepl/util"
"github.com/problame/go-streamrpc" "github.com/problame/go-streamrpc"
"net" "net"
) )
type SourceJob struct { type SourceJob struct {
Name string Name string
Serve AuthenticatedChannelListenerFactory Serve ListenerFactory
Filesystems *DatasetMapFilter Filesystems *DatasetMapFilter
SnapshotPrefix string SnapshotPrefix string
Interval time.Duration Interval time.Duration
@ -70,6 +68,15 @@ func parseSourceJob(c JobParsingContext, name string, i map[string]interface{})
return return
} }
if j.Debug.Conn.ReadDump != "" || j.Debug.Conn.WriteDump != "" {
logServe := logListenerFactory{
ListenerFactory: j.Serve,
ReadDump: j.Debug.Conn.ReadDump,
WriteDump: j.Debug.Conn.WriteDump,
}
j.Serve = logServe
}
return return
} }
@ -139,19 +146,17 @@ func (j *SourceJob) Pruner(task *Task, side PrunePolicySide, dryRun bool) (p Pru
func (j *SourceJob) serve(ctx context.Context, task *Task) { func (j *SourceJob) serve(ctx context.Context, task *Task) {
//listener, err := j.Serve.Listen() listener, err := j.Serve.Listen()
// FIXME
listener, err := net.Listen("tcp", "192.168.122.128:8888")
if err != nil { if err != nil {
task.Log().WithError(err).Error("error listening") task.Log().WithError(err).Error("error listening")
return return
} }
type rwcChanMsg struct { type connChanMsg struct {
rwc io.ReadWriteCloser conn net.Conn
err error err error
} }
rwcChan := make(chan rwcChanMsg) connChan := make(chan connChanMsg)
// Serve connections until interrupted or error // Serve connections until interrupted or error
outer: outer:
@ -160,23 +165,23 @@ outer:
go func() { go func() {
rwc, err := listener.Accept() rwc, err := listener.Accept()
if err != nil { if err != nil {
rwcChan <- rwcChanMsg{rwc, err} connChan <- connChanMsg{rwc, err}
close(rwcChan) close(connChan)
return return
} }
rwcChan <- rwcChanMsg{rwc, err} connChan <- connChanMsg{rwc, err}
}() }()
select { select {
case rwcMsg := <-rwcChan: case rwcMsg := <-connChan:
if rwcMsg.err != nil { if rwcMsg.err != nil {
task.Log().WithError(err).Error("error accepting connection") task.Log().WithError(err).Error("error accepting connection")
break outer break outer
} }
j.handleConnection(rwcMsg.rwc, task) j.handleConnection(rwcMsg.conn, task)
case <-ctx.Done(): case <-ctx.Done():
task.Log().WithError(ctx.Err()).Info("context") task.Log().WithError(ctx.Err()).Info("context")
@ -197,17 +202,13 @@ outer:
} }
func (j *SourceJob) handleConnection(rwc io.ReadWriteCloser, task *Task) { func (j *SourceJob) handleConnection(conn net.Conn, task *Task) {
task.Enter("handle_connection") task.Enter("handle_connection")
defer task.Finish() defer task.Finish()
task.Log().Info("handling client connection") task.Log().Info("handling client connection")
rwc, err := util.NewReadWriteCloserLogger(rwc, j.Debug.Conn.ReadDump, j.Debug.Conn.WriteDump)
if err != nil {
panic(err)
}
senderEP := NewSenderEndpoint(j.Filesystems, NewPrefixFilter(j.SnapshotPrefix)) senderEP := NewSenderEndpoint(j.Filesystems, NewPrefixFilter(j.SnapshotPrefix))

View File

@ -11,6 +11,7 @@ import (
"regexp" "regexp"
"strconv" "strconv"
"time" "time"
"github.com/problame/go-streamrpc"
) )
var ConfigFileDefaultLocations []string = []string{ var ConfigFileDefaultLocations []string = []string{
@ -208,7 +209,7 @@ func parseJob(c JobParsingContext, i map[string]interface{}) (j Job, err error)
} }
func parseConnect(i map[string]interface{}) (c RWCConnecter, err error) { func parseConnect(i map[string]interface{}) (c streamrpc.Connecter, err error) {
t, err := extractStringField(i, "type", true) t, err := extractStringField(i, "type", true)
if err != nil { if err != nil {
@ -266,7 +267,7 @@ func parsePrunePolicy(v map[string]interface{}, willSeeBookmarks bool) (p PruneP
} }
} }
func parseAuthenticatedChannelListenerFactory(c JobParsingContext, v map[string]interface{}) (p AuthenticatedChannelListenerFactory, err error) { func parseAuthenticatedChannelListenerFactory(c JobParsingContext, v map[string]interface{}) (p ListenerFactory, err error) {
t, err := extractStringField(v, "type", true) t, err := extractStringField(v, "type", true)
if err != nil { if err != nil {

View File

@ -4,7 +4,7 @@ import (
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/problame/go-netssh" "github.com/problame/go-netssh"
"io" "net"
"path" "path"
) )
@ -30,9 +30,9 @@ func parseStdinserverListenerFactory(c JobParsingContext, i map[string]interface
return return
} }
func (f *StdinserverListenerFactory) Listen() (al AuthenticatedChannelListener, err error) { func (f *StdinserverListenerFactory) Listen() (net.Listener, error) {
if err = PreparePrivateSockpath(f.sockpath); err != nil { if err := PreparePrivateSockpath(f.sockpath); err != nil {
return nil, err return nil, err
} }
@ -47,8 +47,16 @@ type StdinserverListener struct {
l *netssh.Listener l *netssh.Listener
} }
func (l StdinserverListener) Accept() (ch io.ReadWriteCloser, err error) { func (l StdinserverListener) Addr() net.Addr {
return l.l.Accept() return netsshAddr{}
}
func (l StdinserverListener) Accept() (net.Conn, error) {
c, err := l.l.Accept()
if err != nil {
return nil, err
}
return netsshConnToNetConnAdatper{c}, nil
} }
func (l StdinserverListener) Close() (err error) { func (l StdinserverListener) Close() (err error) {

View File

@ -2,18 +2,19 @@ package util
import ( import (
"io" "io"
"net"
"os" "os"
) )
type ReadWriteCloserLogger struct { type NetConnLogger struct {
RWC io.ReadWriteCloser net.Conn
ReadFile *os.File ReadFile *os.File
WriteFile *os.File WriteFile *os.File
} }
func NewReadWriteCloserLogger(rwc io.ReadWriteCloser, readlog, writelog string) (l *ReadWriteCloserLogger, err error) { func NewNetConnLogger(conn net.Conn, readlog, writelog string) (l *NetConnLogger, err error) {
l = &ReadWriteCloserLogger{ l = &NetConnLogger{
RWC: rwc, Conn: conn,
} }
flags := os.O_CREATE | os.O_WRONLY flags := os.O_CREATE | os.O_WRONLY
if readlog != "" { if readlog != "" {
@ -29,8 +30,8 @@ func NewReadWriteCloserLogger(rwc io.ReadWriteCloser, readlog, writelog string)
return return
} }
func (c *ReadWriteCloserLogger) Read(buf []byte) (n int, err error) { func (c *NetConnLogger) Read(buf []byte) (n int, err error) {
n, err = c.RWC.Read(buf) n, err = c.Conn.Read(buf)
if c.WriteFile != nil { if c.WriteFile != nil {
if _, writeErr := c.ReadFile.Write(buf[0:n]); writeErr != nil { if _, writeErr := c.ReadFile.Write(buf[0:n]); writeErr != nil {
panic(writeErr) panic(writeErr)
@ -39,8 +40,8 @@ func (c *ReadWriteCloserLogger) Read(buf []byte) (n int, err error) {
return return
} }
func (c *ReadWriteCloserLogger) Write(buf []byte) (n int, err error) { func (c *NetConnLogger) Write(buf []byte) (n int, err error) {
n, err = c.RWC.Write(buf) n, err = c.Conn.Write(buf)
if c.ReadFile != nil { if c.ReadFile != nil {
if _, writeErr := c.WriteFile.Write(buf[0:n]); writeErr != nil { if _, writeErr := c.WriteFile.Write(buf[0:n]); writeErr != nil {
panic(writeErr) panic(writeErr)
@ -48,8 +49,8 @@ func (c *ReadWriteCloserLogger) Write(buf []byte) (n int, err error) {
} }
return return
} }
func (c *ReadWriteCloserLogger) Close() (err error) { func (c *NetConnLogger) Close() (err error) {
err = c.RWC.Close() err = c.Conn.Close()
if err != nil { if err != nil {
return return
} }