mirror of
https://github.com/zrepl/zrepl.git
synced 2025-01-27 00:30:40 +01:00
streamrpc now requires net.Conn => use it instead of rwc everywhere
This commit is contained in:
parent
1826535e6f
commit
a0b320bfeb
76
cmd/adaptors.go
Normal file
76
cmd/adaptors.go
Normal file
@ -0,0 +1,76 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/problame/go-streamrpc"
|
||||
"github.com/zrepl/zrepl/util"
|
||||
)
|
||||
|
||||
type logNetConnConnecter struct {
|
||||
streamrpc.Connecter
|
||||
ReadDump, WriteDump string
|
||||
}
|
||||
|
||||
var _ streamrpc.Connecter = logNetConnConnecter{}
|
||||
|
||||
func (l logNetConnConnecter) Connect(ctx context.Context) (net.Conn, error) {
|
||||
conn, err := l.Connecter.Connect(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return util.NewNetConnLogger(conn, l.ReadDump, l.WriteDump)
|
||||
}
|
||||
|
||||
type logListenerFactory struct {
|
||||
ListenerFactory
|
||||
ReadDump, WriteDump string
|
||||
}
|
||||
|
||||
var _ ListenerFactory = logListenerFactory{}
|
||||
|
||||
type logListener struct {
|
||||
net.Listener
|
||||
ReadDump, WriteDump string
|
||||
}
|
||||
|
||||
var _ net.Listener = logListener{}
|
||||
|
||||
func (m logListenerFactory) Listen() (net.Listener, error) {
|
||||
l, err := m.ListenerFactory.Listen()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return logListener{l, m.ReadDump, m.WriteDump}, nil
|
||||
}
|
||||
|
||||
func (l logListener) Accept() (net.Conn, error) {
|
||||
conn, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return util.NewNetConnLogger(conn, l.ReadDump, l.WriteDump)
|
||||
}
|
||||
|
||||
|
||||
type netsshAddr struct{}
|
||||
|
||||
func (netsshAddr) Network() string { return "netssh" }
|
||||
func (netsshAddr) String() string { return "???" }
|
||||
|
||||
type netsshConnToNetConnAdatper struct {
|
||||
io.ReadWriteCloser // works for both netssh.SSHConn and netssh.ServeConn
|
||||
}
|
||||
|
||||
func (netsshConnToNetConnAdatper) LocalAddr() net.Addr { return netsshAddr{} }
|
||||
|
||||
func (netsshConnToNetConnAdatper) RemoteAddr() net.Addr { return netsshAddr{} }
|
||||
|
||||
func (netsshConnToNetConnAdatper) SetDeadline(t time.Time) error { return nil }
|
||||
|
||||
func (netsshConnToNetConnAdatper) SetReadDeadline(t time.Time) error { return nil }
|
||||
|
||||
func (netsshConnToNetConnAdatper) SetWriteDeadline(t time.Time) error { return nil }
|
@ -1,7 +1,7 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"fmt"
|
||||
"github.com/pkg/errors"
|
||||
@ -43,16 +43,8 @@ type JobDebugSettings struct {
|
||||
}
|
||||
}
|
||||
|
||||
type RWCConnecter interface {
|
||||
Connect() (io.ReadWriteCloser, error)
|
||||
}
|
||||
type AuthenticatedChannelListenerFactory interface {
|
||||
Listen() (AuthenticatedChannelListener, error)
|
||||
}
|
||||
|
||||
type AuthenticatedChannelListener interface {
|
||||
Accept() (ch io.ReadWriteCloser, err error)
|
||||
Close() (err error)
|
||||
type ListenerFactory interface {
|
||||
Listen() (net.Listener, error)
|
||||
}
|
||||
|
||||
type SSHStdinServerConnectDescr struct {
|
||||
|
@ -2,13 +2,14 @@ package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"context"
|
||||
"github.com/jinzhu/copier"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/problame/go-netssh"
|
||||
"github.com/problame/go-streamrpc"
|
||||
"time"
|
||||
)
|
||||
|
||||
@ -24,6 +25,8 @@ type SSHStdinserverConnecter struct {
|
||||
dialTimeout time.Duration
|
||||
}
|
||||
|
||||
var _ streamrpc.Connecter = &SSHStdinserverConnecter{}
|
||||
|
||||
func parseSSHStdinserverConnecter(i map[string]interface{}) (c *SSHStdinserverConnecter, err error) {
|
||||
|
||||
c = &SSHStdinserverConnecter{}
|
||||
@ -46,21 +49,28 @@ func parseSSHStdinserverConnecter(i map[string]interface{}) (c *SSHStdinserverCo
|
||||
|
||||
}
|
||||
|
||||
func (c *SSHStdinserverConnecter) Connect() (rwc io.ReadWriteCloser, err error) {
|
||||
type netsshConnToConn struct { *netssh.SSHConn }
|
||||
|
||||
var _ net.Conn = netsshConnToConn{}
|
||||
|
||||
func (netsshConnToConn) SetDeadline(dl time.Time) error { return nil }
|
||||
func (netsshConnToConn) SetReadDeadline(dl time.Time) error { return nil }
|
||||
func (netsshConnToConn) SetWriteDeadline(dl time.Time) error { return nil }
|
||||
|
||||
func (c *SSHStdinserverConnecter) Connect(dialCtx context.Context) (net.Conn, error) {
|
||||
|
||||
var endpoint netssh.Endpoint
|
||||
if err = copier.Copy(&endpoint, c); err != nil {
|
||||
if err := copier.Copy(&endpoint, c); err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
var dialCtx context.Context
|
||||
dialCtx, dialCancel := context.WithTimeout(context.TODO(), c.dialTimeout) // context.TODO tied to error handling below
|
||||
dialCtx, dialCancel := context.WithTimeout(dialCtx, c.dialTimeout) // context.TODO tied to error handling below
|
||||
defer dialCancel()
|
||||
if rwc, err = netssh.Dial(dialCtx, endpoint); err != nil {
|
||||
nconn, err := netssh.Dial(dialCtx, endpoint)
|
||||
if err != nil {
|
||||
if err == context.DeadlineExceeded {
|
||||
err = errors.Errorf("dial_timeout of %s exceeded", c.dialTimeout)
|
||||
}
|
||||
err = errors.WithStack(err)
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
return
|
||||
return netsshConnToConn{nconn}, nil
|
||||
}
|
||||
|
@ -16,7 +16,7 @@ import (
|
||||
|
||||
type PullJob struct {
|
||||
Name string
|
||||
Connect RWCConnecter
|
||||
Connect streamrpc.Connecter
|
||||
Interval time.Duration
|
||||
Mapping *DatasetMapFilter
|
||||
// constructed from mapping during parsing
|
||||
@ -90,6 +90,15 @@ func parsePullJob(c JobParsingContext, name string, i map[string]interface{}) (j
|
||||
return
|
||||
}
|
||||
|
||||
if j.Debug.Conn.ReadDump != "" || j.Debug.Conn.WriteDump != "" {
|
||||
logConnecter := logNetConnConnecter{
|
||||
Connecter: j.Connect,
|
||||
ReadDump: j.Debug.Conn.ReadDump,
|
||||
WriteDump: j.Debug.Conn.WriteDump,
|
||||
}
|
||||
j.Connect = logConnecter
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@ -132,56 +141,12 @@ var STREAMRPC_CONFIG = &streamrpc.ConnConfig{ // FIXME oversight and configurabi
|
||||
RxStructuredMaxLen: 4096 * 4096,
|
||||
RxStreamMaxChunkSize: 4096 * 4096,
|
||||
TxChunkSize: 4096 * 4096,
|
||||
}
|
||||
|
||||
type streamrpcRWCToNetConnAdatper struct {
|
||||
io.ReadWriteCloser
|
||||
}
|
||||
|
||||
func (streamrpcRWCToNetConnAdatper) LocalAddr() net.Addr {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (streamrpcRWCToNetConnAdatper) RemoteAddr() net.Addr {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (streamrpcRWCToNetConnAdatper) SetDeadline(t time.Time) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (streamrpcRWCToNetConnAdatper) SetReadDeadline(t time.Time) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (streamrpcRWCToNetConnAdatper) SetWriteDeadline(t time.Time) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
type streamrpcRWCConnecterToNetConnAdapter struct {
|
||||
RWCConnecter
|
||||
ReadDump, WriteDump string
|
||||
}
|
||||
|
||||
func (s streamrpcRWCConnecterToNetConnAdapter) Connect(ctx context.Context) (net.Conn, error) {
|
||||
rwc, err := s.RWCConnecter.Connect()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rwc, err = util.NewReadWriteCloserLogger(rwc, s.ReadDump, s.WriteDump)
|
||||
if err != nil {
|
||||
rwc.Close()
|
||||
return nil, err
|
||||
}
|
||||
return streamrpcRWCToNetConnAdatper{rwc}, nil
|
||||
}
|
||||
|
||||
type tcpConnecter struct {
|
||||
d net.Dialer
|
||||
}
|
||||
|
||||
func (t *tcpConnecter) Connect(ctx context.Context) (net.Conn, error) {
|
||||
return t.d.DialContext(ctx, "tcp", "192.168.122.128:8888")
|
||||
RxTimeout: streamrpc.Timeout{
|
||||
Progress: 10*time.Second,
|
||||
},
|
||||
TxTimeout: streamrpc.Timeout{
|
||||
Progress: 10*time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
func (j *PullJob) doRun(ctx context.Context) {
|
||||
@ -189,25 +154,12 @@ func (j *PullJob) doRun(ctx context.Context) {
|
||||
j.task.Enter("run")
|
||||
defer j.task.Finish()
|
||||
|
||||
//connecter := streamrpcRWCConnecterToNetConnAdapter{
|
||||
// RWCConnecter: j.Connect,
|
||||
// ReadDump: j.Debug.Conn.ReadDump,
|
||||
// WriteDump: j.Debug.Conn.WriteDump,
|
||||
//}
|
||||
|
||||
// FIXME
|
||||
connecter := &tcpConnecter{net.Dialer{
|
||||
Timeout: 2*time.Second,
|
||||
}}
|
||||
|
||||
clientConf := &streamrpc.ClientConfig{
|
||||
MaxConnectAttempts: 5, // FIXME
|
||||
ReconnectBackoffBase: 1*time.Second,
|
||||
ReconnectBackoffFactor: 2,
|
||||
ConnConfig: STREAMRPC_CONFIG,
|
||||
}
|
||||
|
||||
client, err := streamrpc.NewClient(connecter, clientConf)
|
||||
client, err := streamrpc.NewClient(j.Connect, clientConf)
|
||||
defer client.Close()
|
||||
|
||||
j.task.Enter("pull")
|
||||
|
@ -2,19 +2,17 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/zrepl/zrepl/util"
|
||||
"github.com/problame/go-streamrpc"
|
||||
"net"
|
||||
)
|
||||
|
||||
type SourceJob struct {
|
||||
Name string
|
||||
Serve AuthenticatedChannelListenerFactory
|
||||
Serve ListenerFactory
|
||||
Filesystems *DatasetMapFilter
|
||||
SnapshotPrefix string
|
||||
Interval time.Duration
|
||||
@ -70,6 +68,15 @@ func parseSourceJob(c JobParsingContext, name string, i map[string]interface{})
|
||||
return
|
||||
}
|
||||
|
||||
if j.Debug.Conn.ReadDump != "" || j.Debug.Conn.WriteDump != "" {
|
||||
logServe := logListenerFactory{
|
||||
ListenerFactory: j.Serve,
|
||||
ReadDump: j.Debug.Conn.ReadDump,
|
||||
WriteDump: j.Debug.Conn.WriteDump,
|
||||
}
|
||||
j.Serve = logServe
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@ -139,19 +146,17 @@ func (j *SourceJob) Pruner(task *Task, side PrunePolicySide, dryRun bool) (p Pru
|
||||
|
||||
func (j *SourceJob) serve(ctx context.Context, task *Task) {
|
||||
|
||||
//listener, err := j.Serve.Listen()
|
||||
// FIXME
|
||||
listener, err := net.Listen("tcp", "192.168.122.128:8888")
|
||||
listener, err := j.Serve.Listen()
|
||||
if err != nil {
|
||||
task.Log().WithError(err).Error("error listening")
|
||||
return
|
||||
}
|
||||
|
||||
type rwcChanMsg struct {
|
||||
rwc io.ReadWriteCloser
|
||||
err error
|
||||
type connChanMsg struct {
|
||||
conn net.Conn
|
||||
err error
|
||||
}
|
||||
rwcChan := make(chan rwcChanMsg)
|
||||
connChan := make(chan connChanMsg)
|
||||
|
||||
// Serve connections until interrupted or error
|
||||
outer:
|
||||
@ -160,23 +165,23 @@ outer:
|
||||
go func() {
|
||||
rwc, err := listener.Accept()
|
||||
if err != nil {
|
||||
rwcChan <- rwcChanMsg{rwc, err}
|
||||
close(rwcChan)
|
||||
connChan <- connChanMsg{rwc, err}
|
||||
close(connChan)
|
||||
return
|
||||
}
|
||||
rwcChan <- rwcChanMsg{rwc, err}
|
||||
connChan <- connChanMsg{rwc, err}
|
||||
}()
|
||||
|
||||
select {
|
||||
|
||||
case rwcMsg := <-rwcChan:
|
||||
case rwcMsg := <-connChan:
|
||||
|
||||
if rwcMsg.err != nil {
|
||||
task.Log().WithError(err).Error("error accepting connection")
|
||||
break outer
|
||||
}
|
||||
|
||||
j.handleConnection(rwcMsg.rwc, task)
|
||||
j.handleConnection(rwcMsg.conn, task)
|
||||
|
||||
case <-ctx.Done():
|
||||
task.Log().WithError(ctx.Err()).Info("context")
|
||||
@ -197,17 +202,13 @@ outer:
|
||||
|
||||
}
|
||||
|
||||
func (j *SourceJob) handleConnection(rwc io.ReadWriteCloser, task *Task) {
|
||||
func (j *SourceJob) handleConnection(conn net.Conn, task *Task) {
|
||||
|
||||
task.Enter("handle_connection")
|
||||
defer task.Finish()
|
||||
|
||||
task.Log().Info("handling client connection")
|
||||
|
||||
rwc, err := util.NewReadWriteCloserLogger(rwc, j.Debug.Conn.ReadDump, j.Debug.Conn.WriteDump)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
senderEP := NewSenderEndpoint(j.Filesystems, NewPrefixFilter(j.SnapshotPrefix))
|
||||
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"regexp"
|
||||
"strconv"
|
||||
"time"
|
||||
"github.com/problame/go-streamrpc"
|
||||
)
|
||||
|
||||
var ConfigFileDefaultLocations []string = []string{
|
||||
@ -208,7 +209,7 @@ func parseJob(c JobParsingContext, i map[string]interface{}) (j Job, err error)
|
||||
|
||||
}
|
||||
|
||||
func parseConnect(i map[string]interface{}) (c RWCConnecter, err error) {
|
||||
func parseConnect(i map[string]interface{}) (c streamrpc.Connecter, err error) {
|
||||
|
||||
t, err := extractStringField(i, "type", true)
|
||||
if err != nil {
|
||||
@ -266,7 +267,7 @@ func parsePrunePolicy(v map[string]interface{}, willSeeBookmarks bool) (p PruneP
|
||||
}
|
||||
}
|
||||
|
||||
func parseAuthenticatedChannelListenerFactory(c JobParsingContext, v map[string]interface{}) (p AuthenticatedChannelListenerFactory, err error) {
|
||||
func parseAuthenticatedChannelListenerFactory(c JobParsingContext, v map[string]interface{}) (p ListenerFactory, err error) {
|
||||
|
||||
t, err := extractStringField(v, "type", true)
|
||||
if err != nil {
|
||||
|
@ -4,7 +4,7 @@ import (
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/problame/go-netssh"
|
||||
"io"
|
||||
"net"
|
||||
"path"
|
||||
)
|
||||
|
||||
@ -30,9 +30,9 @@ func parseStdinserverListenerFactory(c JobParsingContext, i map[string]interface
|
||||
return
|
||||
}
|
||||
|
||||
func (f *StdinserverListenerFactory) Listen() (al AuthenticatedChannelListener, err error) {
|
||||
func (f *StdinserverListenerFactory) Listen() (net.Listener, error) {
|
||||
|
||||
if err = PreparePrivateSockpath(f.sockpath); err != nil {
|
||||
if err := PreparePrivateSockpath(f.sockpath); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -47,8 +47,16 @@ type StdinserverListener struct {
|
||||
l *netssh.Listener
|
||||
}
|
||||
|
||||
func (l StdinserverListener) Accept() (ch io.ReadWriteCloser, err error) {
|
||||
return l.l.Accept()
|
||||
func (l StdinserverListener) Addr() net.Addr {
|
||||
return netsshAddr{}
|
||||
}
|
||||
|
||||
func (l StdinserverListener) Accept() (net.Conn, error) {
|
||||
c, err := l.l.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return netsshConnToNetConnAdatper{c}, nil
|
||||
}
|
||||
|
||||
func (l StdinserverListener) Close() (err error) {
|
||||
|
23
util/io.go
23
util/io.go
@ -2,18 +2,19 @@ package util
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
)
|
||||
|
||||
type ReadWriteCloserLogger struct {
|
||||
RWC io.ReadWriteCloser
|
||||
type NetConnLogger struct {
|
||||
net.Conn
|
||||
ReadFile *os.File
|
||||
WriteFile *os.File
|
||||
}
|
||||
|
||||
func NewReadWriteCloserLogger(rwc io.ReadWriteCloser, readlog, writelog string) (l *ReadWriteCloserLogger, err error) {
|
||||
l = &ReadWriteCloserLogger{
|
||||
RWC: rwc,
|
||||
func NewNetConnLogger(conn net.Conn, readlog, writelog string) (l *NetConnLogger, err error) {
|
||||
l = &NetConnLogger{
|
||||
Conn: conn,
|
||||
}
|
||||
flags := os.O_CREATE | os.O_WRONLY
|
||||
if readlog != "" {
|
||||
@ -29,8 +30,8 @@ func NewReadWriteCloserLogger(rwc io.ReadWriteCloser, readlog, writelog string)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *ReadWriteCloserLogger) Read(buf []byte) (n int, err error) {
|
||||
n, err = c.RWC.Read(buf)
|
||||
func (c *NetConnLogger) Read(buf []byte) (n int, err error) {
|
||||
n, err = c.Conn.Read(buf)
|
||||
if c.WriteFile != nil {
|
||||
if _, writeErr := c.ReadFile.Write(buf[0:n]); writeErr != nil {
|
||||
panic(writeErr)
|
||||
@ -39,8 +40,8 @@ func (c *ReadWriteCloserLogger) Read(buf []byte) (n int, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (c *ReadWriteCloserLogger) Write(buf []byte) (n int, err error) {
|
||||
n, err = c.RWC.Write(buf)
|
||||
func (c *NetConnLogger) Write(buf []byte) (n int, err error) {
|
||||
n, err = c.Conn.Write(buf)
|
||||
if c.ReadFile != nil {
|
||||
if _, writeErr := c.WriteFile.Write(buf[0:n]); writeErr != nil {
|
||||
panic(writeErr)
|
||||
@ -48,8 +49,8 @@ func (c *ReadWriteCloserLogger) Write(buf []byte) (n int, err error) {
|
||||
}
|
||||
return
|
||||
}
|
||||
func (c *ReadWriteCloserLogger) Close() (err error) {
|
||||
err = c.RWC.Close()
|
||||
func (c *NetConnLogger) Close() (err error) {
|
||||
err = c.Conn.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user