zrepl/daemon/transport/handshake.go

137 lines
3.8 KiB
Go

package transport
import (
"bytes"
"fmt"
"io"
"net"
"strings"
"time"
"unicode/utf8"
)
// HandshakeMessage is the protocol banner: a protocol version plus a list of
// extension strings.
//
// Wire format, as produced by Encode and consumed by DecodeReader:
//
//	LLLLLLLLLL ZREPL_ZFS_REPLICATION PROTOVERSION=VVVV EXTENSIONS=NNNN\n
//	<extension 1>\n
//	...
//	<extension N>\n
//
// LLLLLLLLLL is the zero-padded decimal byte length of everything that follows
// the single space after it.
type HandshakeMessage struct {
	// ProtocolVersion must be in [1, 9999] to be encodable.
	ProtocolVersion int
	// Extensions are written one per line; at most 9999 entries, each valid
	// UTF-8 without newline characters.
	Extensions []string
}

// Encode serializes m into its length-prefixed wire representation.
//
// It returns an error if ProtocolVersion is outside [1, 9999], if there are
// more than 9999 extensions, or if an extension contains a newline or is not
// valid UTF-8.
func (m *HandshakeMessage) Encode() ([]byte, error) {
	if m.ProtocolVersion <= 0 || m.ProtocolVersion > 9999 {
		return nil, fmt.Errorf("protocol version must be in [1, 9999]")
	}
	// The count is serialized as a 4-digit field, so 9999 is the largest
	// representable value and must still be accepted.
	if len(m.Extensions) > 9999 {
		return nil, fmt.Errorf("protocol only supports [0, 9999] extensions")
	}
	// EXTENSIONS is a count of subsequent \n separated lines that contain protocol extensions
	var extensions strings.Builder
	for i, ext := range m.Extensions {
		if strings.Contains(ext, "\n") {
			return nil, fmt.Errorf("Extension #%d contains forbidden newline character", i)
		}
		if !utf8.ValidString(ext) {
			return nil, fmt.Errorf("Extension #%d is not valid UTF-8", i)
		}
		extensions.WriteString(ext)
		extensions.WriteString("\n")
	}
	withoutLen := fmt.Sprintf("ZREPL_ZFS_REPLICATION PROTOVERSION=%04d EXTENSIONS=%04d\n%s",
		m.ProtocolVersion, len(m.Extensions), extensions.String())
	withLen := fmt.Sprintf("%010d %s", len(withoutLen), withoutLen)
	return []byte(withLen), nil
}

// DecodeReader reads one handshake message from r and stores it in m.
//
// maxLen bounds the length announced by the 10-digit prefix; a message that
// announces more bytes is rejected before any of them are read. Exactly the
// announced number of bytes is consumed from r; a stream that ends earlier is
// an error. m is only modified when decoding succeeds.
func (m *HandshakeMessage) DecodeReader(r io.Reader, maxLen int) error {
	// Fixed-size prefix: 10 digits of length followed by one space.
	var lenAndSpace [11]byte
	if _, err := io.ReadFull(r, lenAndSpace[:]); err != nil {
		return err
	}
	if !utf8.Valid(lenAndSpace[:]) {
		return fmt.Errorf("invalid start of handshake message: not valid UTF-8")
	}
	var followLen int
	n, err := fmt.Sscanf(string(lenAndSpace[:]), "%010d ", &followLen)
	if n != 1 || err != nil {
		return fmt.Errorf("could not parse handshake message length")
	}
	// Sscanf accepts a sign, so a negative length has to be rejected here
	// before it is used as a buffer size.
	if followLen < 0 {
		return fmt.Errorf("handshake message length must not be negative (got %d)", followLen)
	}
	if followLen > maxLen {
		return fmt.Errorf("handshake message length exceeds max length (%d vs %d)",
			followLen, maxLen)
	}

	// Read exactly the announced number of bytes; a short stream surfaces as
	// io.EOF / io.ErrUnexpectedEOF instead of being parsed as if complete.
	payload := make([]byte, followLen)
	if _, err := io.ReadFull(r, payload); err != nil {
		return err
	}

	// bytes.Buffer supports UnreadRune, so Fscanf consumes only the header
	// line and leaves the extension lines in buf.
	buf := bytes.NewBuffer(payload)
	var protoVersion, extensionCount int
	n, err = fmt.Fscanf(buf, "ZREPL_ZFS_REPLICATION PROTOVERSION=%04d EXTENSIONS=%4d\n",
		&protoVersion, &extensionCount)
	if n != 2 || err != nil {
		return fmt.Errorf("could not parse handshake message: %s", err)
	}
	if protoVersion < 1 {
		return fmt.Errorf("invalid protocol version %d", protoVersion)
	}
	if extensionCount < 0 {
		return fmt.Errorf("invalid extension count %d", extensionCount)
	}

	if extensionCount == 0 {
		if buf.Len() != 0 {
			return fmt.Errorf("unexpected data trailing after header")
		}
		m.ProtocolVersion = protoVersion
		m.Extensions = nil
		return nil
	}

	rest := buf.String()
	if found := strings.Count(rest, "\n"); found != extensionCount {
		return fmt.Errorf("inconsistent extension count: found %d, header says %d", found, extensionCount)
	}
	// Every extension is newline-terminated, so splitting must yield one
	// trailing empty element; anything else is unterminated trailing data.
	exts := strings.Split(rest, "\n")
	if exts[len(exts)-1] != "" {
		return fmt.Errorf("unexpected data trailing after last extension newline")
	}
	m.ProtocolVersion = protoVersion
	m.Extensions = exts[:len(exts)-1]
	return nil
}
// DoHandshakeCurrentVersion performs the banner handshake on conn with the
// protocol version spoken by this code. It is a thin wrapper around
// DoHandshakeVersion and returns whatever that function returns.
func DoHandshakeCurrentVersion(conn net.Conn, deadline time.Time) error {
	// The current protocol version is pinned right here.
	const currentVersion = 1
	return DoHandshakeVersion(conn, deadline, currentVersion)
}
// DoHandshakeVersion sends our protocol banner for the given version over
// conn, reads the peer's banner, and checks that both versions are equal.
//
// deadline is applied to conn for the duration of the handshake. If the
// handshake succeeds the deadline is cleared again, so later I/O on conn does
// not inherit it; on failure the deadline is left as set. Extensions announced
// by the peer are ignored.
//
// An error is returned if the deadline cannot be set or cleared, if our banner
// cannot be encoded or sent, if the peer's banner cannot be decoded, or if the
// protocol versions differ.
func DoHandshakeVersion(conn net.Conn, deadline time.Time, version int) error {
	// Upper bound for the length the peer may announce for its banner.
	const maxBannerLen = 16 * 4096

	ours := HandshakeMessage{
		ProtocolVersion: version,
		Extensions:      nil,
	}
	hsb, err := ours.Encode()
	if err != nil {
		return fmt.Errorf("could not encode protocol banner: %s", err)
	}

	// Without a deadline the exchange below could block indefinitely, so a
	// failure to arm it is treated as a handshake failure.
	if err := conn.SetDeadline(deadline); err != nil {
		return fmt.Errorf("could not set deadline for protocol banner handshake: %s", err)
	}

	if _, err := conn.Write(hsb); err != nil {
		return fmt.Errorf("could not send protocol banner: %s", err)
	}

	theirs := HandshakeMessage{}
	if err := theirs.DecodeReader(conn, maxBannerLen); err != nil {
		return fmt.Errorf("could not decode protocol banner: %s", err)
	}

	if theirs.ProtocolVersion != ours.ProtocolVersion {
		return fmt.Errorf("protocol versions do not match: ours is %d, theirs is %d",
			ours.ProtocolVersion, theirs.ProtocolVersion)
	}
	// ignore extensions, we don't use them

	// The zero time.Time removes the deadline from the connection.
	if err := conn.SetDeadline(time.Time{}); err != nil {
		return fmt.Errorf("could not reset deadline after protocol banner handshake: %s", err)
	}
	return nil
}