mirror of
https://github.com/zrepl/zrepl.git
synced 2024-11-24 17:35:01 +01:00
implement transport protocol handshake (even before streamrpc handshake)
This commit is contained in:
parent
be962998ba
commit
1fb59c953a
@ -1,12 +1,39 @@
|
||||
package connecter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/problame/go-streamrpc"
|
||||
"github.com/zrepl/zrepl/config"
|
||||
"github.com/zrepl/zrepl/daemon/streamrpcconfig"
|
||||
"github.com/zrepl/zrepl/daemon/transport"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
||||
type HandshakeConnecter struct {
|
||||
connecter streamrpc.Connecter
|
||||
}
|
||||
|
||||
func (c HandshakeConnecter) Connect(ctx context.Context) (net.Conn, error) {
|
||||
conn, err := c.connecter.Connect(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dl, ok := ctx.Deadline()
|
||||
if !ok {
|
||||
dl = time.Now().Add(10 * time.Second) // FIXME constant
|
||||
}
|
||||
if err := transport.DoHandshakeCurrentVersion(conn, dl); err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
|
||||
|
||||
func FromConfig(g *config.Global, in config.ConnectEnum) (*ClientFactory, error) {
|
||||
var (
|
||||
connecter streamrpc.Connecter
|
||||
@ -41,6 +68,9 @@ func FromConfig(g *config.Global, in config.ConnectEnum) (*ClientFactory, error)
|
||||
if err := config.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
connecter = HandshakeConnecter{connecter}
|
||||
|
||||
return &ClientFactory{connecter: connecter, config: &config}, nil
|
||||
}
|
||||
|
||||
|
136
daemon/transport/handshake.go
Normal file
136
daemon/transport/handshake.go
Normal file
@ -0,0 +1,136 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
type HandshakeMessage struct {
|
||||
ProtocolVersion int
|
||||
Extensions []string
|
||||
}
|
||||
|
||||
func (m *HandshakeMessage) Encode() ([]byte, error) {
|
||||
if m.ProtocolVersion <= 0 || m.ProtocolVersion > 9999 {
|
||||
return nil, fmt.Errorf("protocol version must be in [1, 9999]")
|
||||
}
|
||||
if len(m.Extensions) >= 9999 {
|
||||
return nil, fmt.Errorf("protocol only supports [0, 9999] extensions")
|
||||
}
|
||||
// EXTENSIONS is a count of subsequent \n separated lines that contain protocol extensions
|
||||
var extensions strings.Builder
|
||||
for i, ext := range m.Extensions {
|
||||
if strings.ContainsAny(ext, "\n") {
|
||||
return nil, fmt.Errorf("Extension #%d contains forbidden newline character", i)
|
||||
}
|
||||
if !utf8.ValidString(ext) {
|
||||
return nil, fmt.Errorf("Extension #%d is not valid UTF-8", i)
|
||||
}
|
||||
extensions.WriteString(ext)
|
||||
extensions.WriteString("\n")
|
||||
}
|
||||
withoutLen := fmt.Sprintf("ZREPL_ZFS_REPLICATION PROTOVERSION=%04d EXTENSIONS=%04d\n%s",
|
||||
m.ProtocolVersion, len(m.Extensions), extensions.String())
|
||||
withLen := fmt.Sprintf("%010d %s", len(withoutLen), withoutLen)
|
||||
return []byte(withLen), nil
|
||||
}
|
||||
|
||||
func (m *HandshakeMessage) DecodeReader(r io.Reader, maxLen int) error {
|
||||
var lenAndSpace [11]byte
|
||||
if _, err := io.ReadFull(r, lenAndSpace[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
if !utf8.Valid(lenAndSpace[:]) {
|
||||
return fmt.Errorf("invalid start of handshake message: not valid UTF-8")
|
||||
}
|
||||
var followLen int
|
||||
n, err := fmt.Sscanf(string(lenAndSpace[:]), "%010d ", &followLen)
|
||||
if n != 1 || err != nil {
|
||||
return fmt.Errorf("could not parse handshake message length")
|
||||
}
|
||||
if followLen > maxLen {
|
||||
return fmt.Errorf("handshake message length exceeds max length (%d vs %d)",
|
||||
followLen, maxLen)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
_, err = io.Copy(&buf, io.LimitReader(r, int64(followLen)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var (
|
||||
protoVersion, extensionCount int
|
||||
)
|
||||
n, err = fmt.Fscanf(&buf, "ZREPL_ZFS_REPLICATION PROTOVERSION=%04d EXTENSIONS=%4d\n",
|
||||
&protoVersion, &extensionCount)
|
||||
if n != 2 || err != nil {
|
||||
return fmt.Errorf("could not parse handshake message: %s", err)
|
||||
}
|
||||
if protoVersion < 1 {
|
||||
return fmt.Errorf("invalid protocol version %q", protoVersion)
|
||||
}
|
||||
m.ProtocolVersion = protoVersion
|
||||
|
||||
if extensionCount < 0 {
|
||||
return fmt.Errorf("invalid extension count %q", extensionCount)
|
||||
}
|
||||
if extensionCount == 0 {
|
||||
if buf.Len() != 0 {
|
||||
return fmt.Errorf("unexpected data trailing after header")
|
||||
}
|
||||
m.Extensions = nil
|
||||
return nil
|
||||
}
|
||||
s := buf.String()
|
||||
if strings.Count(s, "\n") != extensionCount {
|
||||
return fmt.Errorf("inconsistent extension count: found %d, header says %d", len(m.Extensions), extensionCount)
|
||||
}
|
||||
exts := strings.Split(s, "\n")
|
||||
if exts[len(exts)-1] != "" {
|
||||
return fmt.Errorf("unexpected data trailing after last extension newline")
|
||||
}
|
||||
m.Extensions = exts[0:len(exts)-1]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func DoHandshakeCurrentVersion(conn net.Conn, deadline time.Time) error {
|
||||
// current protocol version is hardcoded here
|
||||
return DoHandshakeVersion(conn, deadline, 1)
|
||||
}
|
||||
|
||||
func DoHandshakeVersion(conn net.Conn, deadline time.Time, version int) error {
|
||||
ours := HandshakeMessage{
|
||||
ProtocolVersion: version,
|
||||
Extensions: nil,
|
||||
}
|
||||
hsb, err := ours.Encode()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not encode protocol banner: %s", err)
|
||||
}
|
||||
|
||||
conn.SetDeadline(deadline)
|
||||
_, err = io.Copy(conn, bytes.NewBuffer(hsb))
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not send protocol banner: %s", err)
|
||||
}
|
||||
|
||||
theirs := HandshakeMessage{}
|
||||
if err := theirs.DecodeReader(conn, 16 * 4096); err != nil { // FIXME constant
|
||||
return fmt.Errorf("could not decode protocol banner: %s", err)
|
||||
}
|
||||
|
||||
if theirs.ProtocolVersion != ours.ProtocolVersion {
|
||||
return fmt.Errorf("protocol versions do not match: ours is %d, theirs is %d",
|
||||
ours.ProtocolVersion, theirs.ProtocolVersion)
|
||||
}
|
||||
// ignore extensions, we don't use them
|
||||
|
||||
return nil
|
||||
}
|
119
daemon/transport/handshake_test.go
Normal file
119
daemon/transport/handshake_test.go
Normal file
@ -0,0 +1,119 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zrepl/zrepl/util/socketpair"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestHandshakeMessage_Encode(t *testing.T) {
|
||||
|
||||
msg := HandshakeMessage{
|
||||
ProtocolVersion: 2342,
|
||||
}
|
||||
|
||||
encB, err := msg.Encode()
|
||||
require.NoError(t, err)
|
||||
enc := string(encB)
|
||||
t.Logf("enc: %s", enc)
|
||||
|
||||
|
||||
|
||||
assert.False(t, strings.ContainsAny(enc[0:10], " "))
|
||||
assert.True(t, enc[10] == ' ')
|
||||
|
||||
var (
|
||||
headerlen, protoversion, extensionCount int
|
||||
)
|
||||
n, err := fmt.Sscanf(enc, "%010d ZREPL_ZFS_REPLICATION PROTOVERSION=%04d EXTENSIONS=%04d\n",
|
||||
&headerlen, &protoversion, &extensionCount)
|
||||
if n != 3 || (err != nil && err != io.EOF) {
|
||||
t.Fatalf("%v %v", n, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, 2342, protoversion)
|
||||
assert.Equal(t, 0, extensionCount)
|
||||
assert.Equal(t, len(enc)-11, headerlen)
|
||||
|
||||
}
|
||||
|
||||
func TestHandshakeMessage_Encode_InvalidProtocolVersion(t *testing.T) {
|
||||
|
||||
for _, pv := range []int{-1, 0, 10000, 10001} {
|
||||
t.Logf("testing invalid protocol version = %v", pv)
|
||||
msg := HandshakeMessage{
|
||||
ProtocolVersion: pv,
|
||||
}
|
||||
b, err := msg.Encode()
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, b)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestHandshakeMessage_DecodeReader(t *testing.T) {
|
||||
|
||||
in := HandshakeMessage{
|
||||
2342,
|
||||
[]string{"foo", "bar 2342"},
|
||||
}
|
||||
|
||||
enc, err := in.Encode()
|
||||
require.NoError(t, err)
|
||||
|
||||
out := HandshakeMessage{}
|
||||
err = out.DecodeReader(bytes.NewReader([]byte(enc)), 4 * 4096)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2342, out.ProtocolVersion)
|
||||
assert.Equal(t, 2, len(out.Extensions))
|
||||
assert.Equal(t, "foo", out.Extensions[0])
|
||||
assert.Equal(t, "bar 2342", out.Extensions[1])
|
||||
|
||||
}
|
||||
|
||||
func TestDoHandshakeVersion_ErrorOnDifferentVersions(t *testing.T) {
|
||||
srv, client, err := socketpair.SocketPair()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer srv.Close()
|
||||
defer client.Close()
|
||||
|
||||
srvErrCh := make(chan error)
|
||||
go func() {
|
||||
srvErrCh <- DoHandshakeVersion(srv, time.Now().Add(2*time.Second), 1)
|
||||
}()
|
||||
err = DoHandshakeVersion(client, time.Now().Add(2*time.Second), 2)
|
||||
t.Log(err)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, strings.Contains(err.Error(), "version"))
|
||||
|
||||
srvErr := <-srvErrCh
|
||||
t.Log(srvErr)
|
||||
assert.Error(t, srvErr)
|
||||
assert.True(t, strings.Contains(srvErr.Error(), "version"))
|
||||
}
|
||||
|
||||
func TestDoHandshakeCurrentVersion(t *testing.T) {
|
||||
srv, client, err := socketpair.SocketPair()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer srv.Close()
|
||||
defer client.Close()
|
||||
|
||||
srvErrCh := make(chan error)
|
||||
go func() {
|
||||
srvErrCh <- DoHandshakeVersion(srv, time.Now().Add(2*time.Second), 1)
|
||||
}()
|
||||
err = DoHandshakeVersion(client, time.Now().Add(2*time.Second), 1)
|
||||
assert.Nil(t, err)
|
||||
assert.Nil(t, <-srvErrCh)
|
||||
|
||||
}
|
@ -3,12 +3,14 @@ package serve
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
"github.com/zrepl/zrepl/config"
|
||||
"github.com/zrepl/zrepl/daemon/transport"
|
||||
"net"
|
||||
"github.com/zrepl/zrepl/daemon/streamrpcconfig"
|
||||
"github.com/problame/go-streamrpc"
|
||||
"context"
|
||||
"github.com/zrepl/zrepl/logger"
|
||||
"github.com/zrepl/zrepl/zfs"
|
||||
"time"
|
||||
)
|
||||
|
||||
type contextKey int
|
||||
@ -71,6 +73,42 @@ type ListenerFactory interface {
|
||||
Listen() (AuthenticatedListener, error)
|
||||
}
|
||||
|
||||
type HandshakeListenerFactory struct {
|
||||
lf ListenerFactory
|
||||
}
|
||||
|
||||
func (lf HandshakeListenerFactory) Listen() (AuthenticatedListener, error) {
|
||||
l, err := lf.lf.Listen()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return HandshakeListener{l}, nil
|
||||
}
|
||||
|
||||
type HandshakeListener struct {
|
||||
l AuthenticatedListener
|
||||
}
|
||||
|
||||
func (l HandshakeListener) Addr() (net.Addr) { return l.l.Addr() }
|
||||
|
||||
func (l HandshakeListener) Close() error { return l.l.Close() }
|
||||
|
||||
func (l HandshakeListener) Accept(ctx context.Context) (AuthenticatedConn, error) {
|
||||
conn, err := l.l.Accept(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dl, ok := ctx.Deadline()
|
||||
if !ok {
|
||||
dl = time.Now().Add(10*time.Second) // FIXME constant
|
||||
}
|
||||
if err := transport.DoHandshakeCurrentVersion(conn, dl); err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func FromConfig(g *config.Global, in config.ServeEnum) (lf ListenerFactory, conf *streamrpc.ConnConfig, _ error) {
|
||||
|
||||
var (
|
||||
@ -100,6 +138,8 @@ func FromConfig(g *config.Global, in config.ServeEnum) (lf ListenerFactory, conf
|
||||
return nil, nil, rpcErr
|
||||
}
|
||||
|
||||
lf = HandshakeListenerFactory{lf}
|
||||
|
||||
return lf, conf, nil
|
||||
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user