zrepl/internal/rpc/versionhandshake/versionhandshake.go
2024-10-18 19:21:17 +02:00

240 lines
8.0 KiB
Go

// Package versionhandshake wraps a transport.{Connecter,AuthenticatedListener}
// to add an exchange of protocol version information on connection establishment.
//
// The protocol version information (banner) is plain text, thus making it
// easy to diagnose issues with standard tools.
package versionhandshake
import (
"bytes"
"fmt"
"io"
"net"
"strings"
"time"
"unicode/utf8"
)
type HandshakeMessage struct {
ProtocolVersion int
Extensions []string
}
// A HandshakeError describes what went wrong during the handshake.
// It implements net.Error and is always temporary.
type HandshakeError struct {
msg string
// If not nil, the underlying IO error that caused the handshake to fail.
IOError error
isAcceptError bool
}
var _ net.Error = &HandshakeError{}
func (e HandshakeError) Error() string { return e.msg }
// When a net.Listener.Accept() returns an error, the server must
// decide whether to retry calling Accept() or not.
// On some platforms (e.g., Linux), Accept() can return errors
// related to the specific protocol connection that was supposed
// to be returned as asocket FD. Obviously, we want to ignore,
// maybe log, those errors and retry Accept() immediately to
// serve other connections.
// But there are also conditions where we get Accept() errors because
// the process has run out of file descriptors. In that case, retrying
// won't help. We need to close some file descriptor to make progress.
// Note that there could be lots of open file descriptors because we
// have accepted, and not yet closed, lots of connections in the past.
// And then, of course there can be errors where we just want
// to return, e.g., if there's a programming error and we're getting
// an EBADFD or whatever.
//
// So, the serve loops in net/http.Server.Serve() or gRPC's server.Serve()
// must inspect the error and decide what to do.
// The vehicle for this is the
//
// interface { Temporary() bool }
//
// Behavior in both of the aforementioned Serve() loops:
//
// - if the error doesn't implement the interface, stop serving and return
// - `Temporary() == true`: retry with back-off
// - `Temporary() == false`: stop serving and return
//
// So, to make this package's HandshakeListener work with these
// Serve() loops, we return Temporary() == true if the handshake fails.
// In the aforementioned categories, that's the case of a per-connection
// protocol error.
//
// Note: the net.Error interface has deprecated the Temporary() method
// in go.dev/issue/45729, but there is no replacement for users of .Accept().
// Existing users of .Accept() continue to check for the interface.
// So, we need to continue supporting Temporary() until there's a different
// mechanism for serve loops to decide whether to retry or not.
// The following mailing list post proposes to eliminate the retries
// completely, but it seems like the effort has stalled.
// https://groups.google.com/g/golang-nuts/c/-JcZzOkyqYI/m/xwaZzjCgAwAJ
func (e HandshakeError) Temporary() bool {
if e.isAcceptError {
return true
}
te, ok := e.IOError.(interface{ Temporary() bool })
return ok && te.Temporary()
}
// If the underlying IOError was net.Error.Timeout(), Timeout() returns that value.
// Otherwise false.
func (e HandshakeError) Timeout() bool {
if neterr, ok := e.IOError.(net.Error); ok {
return neterr.Timeout()
}
return false
}
func hsErr(msg string) *HandshakeError {
return &HandshakeError{msg: msg}
}
func hsIOErr(err error, format string, args ...interface{}) *HandshakeError {
return &HandshakeError{IOError: err, msg: fmt.Sprintf(format, args...)}
}
// MaxProtocolVersion is the maximum allowed protocol version.
// This is a protocol constant, changing it may break the wire format.
const MaxProtocolVersion = 9999
// Only returns *HandshakeError as error.
func (m *HandshakeMessage) Encode() ([]byte, error) {
if m.ProtocolVersion <= 0 || m.ProtocolVersion > MaxProtocolVersion {
return nil, hsErr(fmt.Sprintf("protocol version must be in [1, %d]", MaxProtocolVersion))
}
if len(m.Extensions) >= MaxProtocolVersion {
return nil, hsErr(fmt.Sprintf("protocol only supports [0, %d] extensions", MaxProtocolVersion))
}
// EXTENSIONS is a count of subsequent \n separated lines that contain protocol extensions
var extensions strings.Builder
for i, ext := range m.Extensions {
if strings.ContainsAny(ext, "\n") {
return nil, hsErr(fmt.Sprintf("Extension #%d contains forbidden newline character", i))
}
if !utf8.ValidString(ext) {
return nil, hsErr(fmt.Sprintf("Extension #%d is not valid UTF-8", i))
}
extensions.WriteString(ext)
extensions.WriteString("\n")
}
withoutLen := fmt.Sprintf("ZREPL_ZFS_REPLICATION PROTOVERSION=%04d EXTENSIONS=%04d\n%s",
m.ProtocolVersion, len(m.Extensions), extensions.String())
withLen := fmt.Sprintf("%010d %s", len(withoutLen), withoutLen)
return []byte(withLen), nil
}
func (m *HandshakeMessage) DecodeReader(r io.Reader, maxLen int) error {
var lenAndSpace [11]byte
if _, err := io.ReadFull(r, lenAndSpace[:]); err != nil {
return hsIOErr(err, "error reading protocol banner length: %s", err)
}
if !utf8.Valid(lenAndSpace[:]) {
return hsErr("invalid start of handshake message: not valid UTF-8")
}
var followLen int
n, err := fmt.Sscanf(string(lenAndSpace[:]), "%010d ", &followLen)
if n != 1 || err != nil {
return hsErr("could not parse handshake message length")
}
if followLen > maxLen {
return hsErr(fmt.Sprintf("handshake message length exceeds max length (%d vs %d)",
followLen, maxLen))
}
var buf bytes.Buffer
_, err = io.Copy(&buf, io.LimitReader(r, int64(followLen)))
if err != nil {
return hsIOErr(err, "error reading protocol banner body: %s", err)
}
var (
protoVersion, extensionCount int
)
n, err = fmt.Fscanf(&buf, "ZREPL_ZFS_REPLICATION PROTOVERSION=%04d EXTENSIONS=%4d\n",
&protoVersion, &extensionCount)
if n != 2 || err != nil {
return hsErr(fmt.Sprintf("could not parse handshake message: %s", err))
}
if protoVersion < 1 {
return hsErr(fmt.Sprintf("invalid protocol version %q", protoVersion))
}
m.ProtocolVersion = protoVersion
if extensionCount < 0 {
return hsErr(fmt.Sprintf("invalid extension count %q", extensionCount))
}
if extensionCount == 0 {
if buf.Len() != 0 {
return hsErr("unexpected data trailing after header")
}
m.Extensions = nil
return nil
}
s := buf.String()
if strings.Count(s, "\n") != extensionCount {
return hsErr(fmt.Sprintf("inconsistent extension count: found %d, header says %d", len(m.Extensions), extensionCount))
}
exts := strings.Split(s, "\n")
if exts[len(exts)-1] != "" {
return hsErr("unexpected data trailing after last extension newline")
}
m.Extensions = exts[0 : len(exts)-1]
return nil
}
func DoHandshakeCurrentVersion(conn net.Conn, deadline time.Time) *HandshakeError {
// current protocol version is hardcoded here
return DoHandshakeVersion(conn, deadline, 7)
}
const HandshakeMessageMaxLen = 16 * 4096
func DoHandshakeVersion(conn net.Conn, deadline time.Time, version int) (rErr *HandshakeError) {
ours := HandshakeMessage{
ProtocolVersion: version,
Extensions: nil,
}
hsb, err := ours.Encode()
if err != nil {
return hsErr(fmt.Sprintf("could not encode protocol banner: %s", err))
}
err = conn.SetDeadline(deadline)
if err != nil {
return hsErr(fmt.Sprintf("could not set deadline for protocol banner handshake: %s", err))
}
defer func() {
if rErr != nil {
return
}
err := conn.SetDeadline(time.Time{})
if err != nil {
rErr = hsErr(fmt.Sprintf("could not reset deadline after protocol banner handshake: %s", err))
}
}()
_, err = io.Copy(conn, bytes.NewBuffer(hsb))
if err != nil {
return hsErr(fmt.Sprintf("could not send protocol banner: %s", err))
}
theirs := HandshakeMessage{}
if err := theirs.DecodeReader(conn, HandshakeMessageMaxLen); err != nil {
return hsErr(fmt.Sprintf("could not decode protocol banner: %s", err))
}
if theirs.ProtocolVersion != ours.ProtocolVersion {
return hsErr(fmt.Sprintf("protocol versions do not match: ours is %d, theirs is %d",
ours.ProtocolVersion, theirs.ProtocolVersion))
}
// ignore extensions, we don't use them
return nil
}