zrepl/rpc/versionhandshake/versionhandshake.go
Christian Schwarz 2d8c3692ec rework resume token validation to allow resuming from raw sends of unencrypted datasets
Before this change, resuming from an unencrypted dataset with
send.raw=true specified wouldn't work with zrepl due to overly
restrictive resume token checking.

An initial PR to fix this was made in https://github.com/zrepl/zrepl/pull/503
but it didn't address the core of the problem.
The core of the problem was that zrepl assumed that if a resume token
contained `rawok=true, compressok=true`, the resulting send would be
encrypted. But if the sender dataset was unencrypted, such a resume would
actually result in an unencrypted send.
Which could be totally legitimate but zrepl failed to recognize that.

BACKGROUND
==========

The following snippets of OpenZFS code are insightful regarding how the
various ${X}ok values in the resume token are handled:

- https://github.com/openzfs/zfs/blob/6c3c5fcfbe/module/zfs/dmu_send.c#L1947-L2012
- https://github.com/openzfs/zfs/blob/6c3c5fcfbe/module/zfs/dmu_recv.c#L877-L891
- https://github.com/openzfs/zfs/blob/6c3c5fc/lib/libzfs/libzfs_sendrecv.c#L1663-L1672

Basically, some zfs send flags make the DMU send code set some DMU send
stream featureflags, although it's not a pure mapping, i.e, which DMU
send stream flags are used depends somewhat on the dataset (e.g., is it
encrypted or not, or, does it use zstd or not).

Then, the receiver looks at some (but not all) feature flags and maps
them to ${X}ok dataset zap attributes.

These are funnelled back to the sender 1:1 through the resume_token.

And the sender turns them into lzc flags.

As an example, let's look at zfs send --raw.
If the sender requests a raw send on an unencrypted dataset, the send
stream (and hence the resume token) will not have the raw stream
featureflag set, and hence the resume token will not have the rawok
field set. Instead, it will have compressok, embedok, and depending
on whether large blocks are present in the dataset, largeblockok set.

WHAT'S ZREPL'S ROLE IN THIS?
============================

zrepl provides a virtual encrypted sendflag that is like `raw`,
but further ensures that we only send encrypted datasets.

For any other resume token stuff, it shouldn't do any checking,
because it's a futile effort to keep up with ZFS send/recv features
that are orthogonal to encryption.

CHANGES MADE IN THIS COMMIT
===========================

- Rip out a bunch of needless checking that zrepl would do during
  planning. These checks were there to give better error messages,
  but actually, the error messages created by the endpoint.Sender.Send
  RPC upon send args validation failure are good enough.
- Add platformtests to validate all combinations of
  (Unencrypted/Encrypted FS) x (send.encrypted = true | false) x (send.raw = true | false)
  for cases both non-resuming and resuming send.

Additional manual testing done:
1. With zrepl 0.5, setup with unencrypted dataset, send.raw=true specified, no send.encrypted specified.
2. Observe that regular non-resuming send works, but resuming doesn't work.
3. Upgrade zrepl to this change.
4. Observe that both regular and resuming send works.

closes https://github.com/zrepl/zrepl/pull/613
2022-09-25 17:32:02 +02:00

201 lines
5.9 KiB
Go

// Package versionhandshake wraps a transport.{Connecter,AuthenticatedListener}
// to add an exchange of protocol version information on connection establishment.
//
// The protocol version information (banner) is plain text, thus making it
// easy to diagnose issues with standard tools.
package versionhandshake
import (
"bytes"
"fmt"
"io"
"net"
"strings"
"time"
"unicode/utf8"
)
// HandshakeMessage is the protocol banner exchanged during the version
// handshake. Encode serializes it to the wire format and DecodeReader
// parses it back.
type HandshakeMessage struct {
// ProtocolVersion is the protocol version announced in the banner.
// Encode only accepts values in [1, MaxProtocolVersion].
ProtocolVersion int
// Extensions are the protocol extensions listed in the banner, one per
// line on the wire. Encode rejects entries that contain a newline or
// are not valid UTF-8.
Extensions []string
}
// A HandshakeError describes what went wrong during the handshake.
// It implements net.Error. It reports itself as temporary when it is flagged
// as an accept error or when the underlying IOError reports being temporary.
type HandshakeError struct {
// msg is the human-readable description returned by Error.
msg string
// If not nil, the underlying IO error that caused the handshake to fail.
IOError error
// isAcceptError, when set, makes Temporary report true unconditionally.
// NOTE(review): not set anywhere in the code visible here; presumably set
// by the listener wrapper elsewhere in the package — confirm.
isAcceptError bool
}
// Compile-time check that *HandshakeError satisfies net.Error.
var _ net.Error = &HandshakeError{}
// Error returns the description of the handshake failure, implementing the
// error interface.
func (e HandshakeError) Error() string {
	return e.msg
}
// Temporary reports whether the handshake failure is transient.
//
// Like with net.OpError (Go issue 6163), a client failing to handshake
// should be a temporary Accept error toward the Listener, so errors flagged
// as accept errors are always temporary. Otherwise the answer is delegated
// to the underlying IOError if it exposes a Temporary method, and is false
// if it does not.
func (e HandshakeError) Temporary() bool {
	type temporary interface {
		Temporary() bool
	}
	if e.isAcceptError {
		return true
	}
	if t, ok := e.IOError.(temporary); ok {
		return t.Temporary()
	}
	return false
}
// Timeout reports whether the handshake failed because of a timeout.
// If the underlying IOError is a net.Error, its Timeout() value is returned.
// Otherwise the result is false.
func (e HandshakeError) Timeout() bool {
	netErr, isNetErr := e.IOError.(net.Error)
	return isNetErr && netErr.Timeout()
}
// hsErr builds a HandshakeError that has no underlying IO error. The message
// is produced from format and args in fmt.Sprintf style.
func hsErr(format string, args ...interface{}) *HandshakeError {
	e := new(HandshakeError)
	e.msg = fmt.Sprintf(format, args...)
	return e
}
// hsIOErr builds a HandshakeError that records err as the underlying IO
// error. The message is produced from format and args in fmt.Sprintf style.
func hsIOErr(err error, format string, args ...interface{}) *HandshakeError {
	e := new(HandshakeError)
	e.msg = fmt.Sprintf(format, args...)
	e.IOError = err
	return e
}
// MaxProtocolVersion is the maximum allowed protocol version.
// This is a protocol constant, changing it may break the wire format:
// the banner encodes the version as a zero-padded 4-digit field (%04d).
// Encode also uses it as the bound on the number of extensions.
const MaxProtocolVersion = 9999
// Encode serializes the message into the wire format of the protocol banner:
//
//	<10-digit length> ZREPL_ZFS_REPLICATION PROTOVERSION=<4 digits> EXTENSIONS=<4 digits>\n
//	<extension>\n   (one line per extension)
//
// The length prefix counts every byte that follows the single space after it.
//
// Encode fails if ProtocolVersion is outside [1, MaxProtocolVersion], if
// there are more than MaxProtocolVersion extensions, or if an extension
// contains a newline or is not valid UTF-8.
//
// Only returns *HandshakeError as error.
func (m *HandshakeMessage) Encode() ([]byte, error) {
	if m.ProtocolVersion <= 0 || m.ProtocolVersion > MaxProtocolVersion {
		return nil, hsErr("protocol version must be in [1, %d]", MaxProtocolVersion)
	}
	// The extension count is written as a 4-digit field, just like the
	// protocol version, so it shares the same inclusive upper bound.
	if len(m.Extensions) > MaxProtocolVersion {
		return nil, hsErr("protocol only supports [0, %d] extensions", MaxProtocolVersion)
	}

	// Each extension occupies its own \n-terminated line after the header,
	// so an embedded newline would corrupt the framing.
	var extensions strings.Builder
	for i, ext := range m.Extensions {
		if strings.Contains(ext, "\n") {
			return nil, hsErr("Extension #%d contains forbidden newline character", i)
		}
		if !utf8.ValidString(ext) {
			return nil, hsErr("Extension #%d is not valid UTF-8", i)
		}
		extensions.WriteString(ext)
		extensions.WriteString("\n")
	}

	// EXTENSIONS is a count of subsequent \n separated lines that contain
	// protocol extensions.
	withoutLen := fmt.Sprintf("ZREPL_ZFS_REPLICATION PROTOVERSION=%04d EXTENSIONS=%04d\n%s",
		m.ProtocolVersion, len(m.Extensions), extensions.String())
	withLen := fmt.Sprintf("%010d %s", len(withoutLen), withoutLen)
	return []byte(withLen), nil
}
// DecodeReader reads one protocol banner from r and stores the result in m.
//
// It first reads the fixed-size length prefix (10 digits and a space),
// rejects the message if the announced length is negative or exceeds maxLen,
// and then reads exactly that many bytes as the banner body. On success,
// m.ProtocolVersion and m.Extensions hold the decoded values (Extensions is
// nil if the banner lists none).
//
// The returned error, if any, is a *HandshakeError; its IOError field is set
// when reading from r failed or ended prematurely. On error, m may have been
// partially updated and must not be used.
func (m *HandshakeMessage) DecodeReader(r io.Reader, maxLen int) error {
	var lenAndSpace [11]byte
	if _, err := io.ReadFull(r, lenAndSpace[:]); err != nil {
		return hsIOErr(err, "error reading protocol banner length: %s", err)
	}
	if !utf8.Valid(lenAndSpace[:]) {
		return hsErr("invalid start of handshake message: not valid UTF-8")
	}
	var followLen int
	n, err := fmt.Sscanf(string(lenAndSpace[:]), "%010d ", &followLen)
	if n != 1 || err != nil {
		return hsErr("could not parse handshake message length")
	}
	// The integer scanner accepts a sign, so a broken or malicious peer
	// could announce a negative length.
	if followLen < 0 {
		return hsErr("invalid handshake message length %d", followLen)
	}
	if followLen > maxLen {
		return hsErr("handshake message length exceeds max length (%d vs %d)",
			followLen, maxLen)
	}

	var buf bytes.Buffer
	copied, err := io.Copy(&buf, io.LimitReader(r, int64(followLen)))
	if err != nil {
		return hsIOErr(err, "error reading protocol banner body: %s", err)
	}
	// io.Copy treats EOF as success, so a peer that stops sending before the
	// announced length must be detected by comparing byte counts.
	if copied != int64(followLen) {
		return hsIOErr(io.ErrUnexpectedEOF, "error reading protocol banner body: %s", io.ErrUnexpectedEOF)
	}

	var protoVersion, extensionCount int
	n, err = fmt.Fscanf(&buf, "ZREPL_ZFS_REPLICATION PROTOVERSION=%04d EXTENSIONS=%4d\n",
		&protoVersion, &extensionCount)
	if n != 2 || err != nil {
		return hsErr("could not parse handshake message: %s", err)
	}
	if protoVersion < 1 {
		return hsErr("invalid protocol version %d", protoVersion)
	}
	m.ProtocolVersion = protoVersion
	if extensionCount < 0 {
		return hsErr("invalid extension count %d", extensionCount)
	}

	if extensionCount == 0 {
		if buf.Len() != 0 {
			return hsErr("unexpected data trailing after header")
		}
		m.Extensions = nil
		return nil
	}

	// What remains in buf after the header are the extension lines, each
	// terminated by a newline.
	s := buf.String()
	if found := strings.Count(s, "\n"); found != extensionCount {
		return hsErr("inconsistent extension count: found %d, header says %d", found, extensionCount)
	}
	exts := strings.Split(s, "\n")
	// Splitting a newline-terminated list yields a trailing empty element;
	// anything else means there is data after the last extension.
	if exts[len(exts)-1] != "" {
		return hsErr("unexpected data trailing after last extension newline")
	}
	m.Extensions = exts[0 : len(exts)-1]
	return nil
}
// DoHandshakeCurrentVersion performs the banner handshake on conn with the
// protocol version that is hardcoded in this function. It delegates to
// DoHandshakeVersion and returns its result.
func DoHandshakeCurrentVersion(conn net.Conn, deadline time.Time) *HandshakeError {
	// current protocol version is hardcoded here
	const currentProtocolVersion = 7
	return DoHandshakeVersion(conn, deadline, currentProtocolVersion)
}
// HandshakeMessageMaxLen is the maximum banner body length, in bytes, that
// DoHandshakeVersion accepts from the peer (it is passed to DecodeReader as
// maxLen).
const HandshakeMessageMaxLen = 16 * 4096
// DoHandshakeVersion exchanges protocol banners with the peer on conn and
// verifies that both sides announce the same protocol version.
//
// It sends our banner (carrying version and no extensions), then reads and
// decodes the peer's banner. The whole exchange is bounded by deadline, which
// is installed on conn for the duration of the handshake and cleared again
// only if the handshake succeeds.
//
// It returns nil on success. On failure, the returned *HandshakeError carries
// the underlying error in IOError whenever the failure originated from an
// operation on conn, so that Timeout() and Temporary() reflect it.
func DoHandshakeVersion(conn net.Conn, deadline time.Time, version int) (rErr *HandshakeError) {
	ours := HandshakeMessage{
		ProtocolVersion: version,
		Extensions:      nil,
	}
	hsb, err := ours.Encode()
	if err != nil {
		return hsErr("could not encode protocol banner: %s", err)
	}

	if err := conn.SetDeadline(deadline); err != nil {
		return hsIOErr(err, "could not set deadline for protocol banner handshake: %s", err)
	}
	// Clear the deadline only after a successful handshake; a failure is
	// reported as-is and must not be masked by a deadline-reset error.
	defer func() {
		if rErr != nil {
			return
		}
		if err := conn.SetDeadline(time.Time{}); err != nil {
			rErr = hsIOErr(err, "could not reset deadline after protocol banner handshake: %s", err)
		}
	}()

	if _, err := io.Copy(conn, bytes.NewBuffer(hsb)); err != nil {
		return hsIOErr(err, "could not send protocol banner: %s", err)
	}

	theirs := HandshakeMessage{}
	if err := theirs.DecodeReader(conn, HandshakeMessageMaxLen); err != nil {
		wrapped := hsErr("could not decode protocol banner: %s", err)
		// Carry over the decoder's underlying IO error (if any) so that a
		// read timeout is still visible through Timeout()/Temporary().
		if he, ok := err.(*HandshakeError); ok {
			wrapped.IOError = he.IOError
		}
		return wrapped
	}

	if theirs.ProtocolVersion != ours.ProtocolVersion {
		return hsErr("protocol versions do not match: ours is %d, theirs is %d",
			ours.ProtocolVersion, theirs.ProtocolVersion)
	}
	// ignore extensions, we don't use them
	return nil
}