Merge branch 'master' into fix_peer_cert_chains

This commit is contained in:
Christian Schwarz 2019-03-15 16:34:21 +01:00 committed by GitHub
commit 5595cff6a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
106 changed files with 6669 additions and 1592 deletions

View File

@ -16,25 +16,6 @@ matrix:
zrepl_build make vendordeps release zrepl_build make vendordeps release
# all go entries vary only by go version # all go entries vary only by go version
- language: go
go:
- "1.10"
go_import_path: github.com/zrepl/zrepl
before_install:
- wget https://github.com/protocolbuffers/protobuf/releases/download/v3.6.1/protoc-3.6.1-linux-x86_64.zip
- echo "6003de742ea3fcf703cfec1cd4a3380fd143081a2eb0e559065563496af27807 protoc-3.6.1-linux-x86_64.zip" | sha256sum -c
- sudo unzip -d /usr protoc-3.6.1-linux-x86_64.zip
- ./lazy.sh godep
- make vendordeps
script:
- make
- make vet
- make test
- go test ./...
- make artifacts/zrepl-freebsd-amd64
- make artifacts/zrepl-linux-amd64
- make artifacts/zrepl-darwin-amd64
- language: go - language: go
go: go:
- "1.11" - "1.11"
@ -49,7 +30,24 @@ matrix:
- make - make
- make vet - make vet
- make test - make test
- go test ./... - make artifacts/zrepl-freebsd-amd64
- make artifacts/zrepl-linux-amd64
- make artifacts/zrepl-darwin-amd64
- language: go
go:
- "master"
go_import_path: github.com/zrepl/zrepl
before_install:
- wget https://github.com/protocolbuffers/protobuf/releases/download/v3.6.1/protoc-3.6.1-linux-x86_64.zip
- echo "6003de742ea3fcf703cfec1cd4a3380fd143081a2eb0e559065563496af27807 protoc-3.6.1-linux-x86_64.zip" | sha256sum -c
- sudo unzip -d /usr protoc-3.6.1-linux-x86_64.zip
- ./lazy.sh godep
- make vendordeps
script:
- make
- make vet
- make test
- make artifacts/zrepl-freebsd-amd64 - make artifacts/zrepl-freebsd-amd64
- make artifacts/zrepl-linux-amd64 - make artifacts/zrepl-linux-amd64
- make artifacts/zrepl-darwin-amd64 - make artifacts/zrepl-darwin-amd64

134
Gopkg.lock generated
View File

@ -80,6 +80,10 @@
"protoc-gen-go/generator/internal/remap", "protoc-gen-go/generator/internal/remap",
"protoc-gen-go/grpc", "protoc-gen-go/grpc",
"protoc-gen-go/plugin", "protoc-gen-go/plugin",
"ptypes",
"ptypes/any",
"ptypes/duration",
"ptypes/timestamp",
] ]
pruneopts = "" pruneopts = ""
revision = "aa810b61a9c79d51363740d207bb46cf8e620ed5" revision = "aa810b61a9c79d51363740d207bb46cf8e620ed5"
@ -173,6 +177,14 @@
revision = "645ef00459ed84a119197bfb8d8205042c6df63d" revision = "645ef00459ed84a119197bfb8d8205042c6df63d"
version = "v0.8.0" version = "v0.8.0"
[[projects]]
digest = "1:1cbc6b98173422a756ae79e485952cb37a0a460c710541c75d3e9961c5a60782"
name = "github.com/pkg/profile"
packages = ["."]
pruneopts = ""
revision = "5b67d428864e92711fcbd2f8629456121a56d91f"
version = "v1.2.1"
[[projects]] [[projects]]
digest = "1:256484dbbcd271f9ecebc6795b2df8cad4c458dd0f5fd82a8c2fa0c29f233411" digest = "1:256484dbbcd271f9ecebc6795b2df8cad4c458dd0f5fd82a8c2fa0c29f233411"
name = "github.com/pmezard/go-difflib" name = "github.com/pmezard/go-difflib"
@ -183,30 +195,11 @@
[[projects]] [[projects]]
branch = "master" branch = "master"
digest = "1:25559b520313b941b1395cd5d5ee66086b27dc15a1391c0f2aad29d5c2321f4b" digest = "1:fa72f780ae3b4820ed12cef7a034291ab10d83e2da4ab5ba81afa44d5bf3a529"
name = "github.com/problame/go-netssh" name = "github.com/problame/go-netssh"
packages = ["."] packages = ["."]
pruneopts = "" pruneopts = ""
revision = "c56ad38d2c91397ad3c8dd9443d7448e328a9e9e" revision = "09d6bc45d284784cb3e5aaa1998513f37eb19cc6"
[[projects]]
branch = "master"
digest = "1:8c63c44f018bd52b03ebad65c9df26aabbc6793138e421df1c8c84c285a45bc6"
name = "github.com/problame/go-rwccmd"
packages = ["."]
pruneopts = ""
revision = "391d2c78c8404a9683d79f75dd24ab53040f89f7"
[[projects]]
digest = "1:1bcbb0a7ad8d3392d446eb583ae5415ff987838a8f7331a36877789be20667e6"
name = "github.com/problame/go-streamrpc"
packages = [
".",
"internal/pdu",
]
pruneopts = ""
revision = "d5d111e014342fe1c37f0b71cc37ec5f2afdfd13"
version = "v0.5"
[[projects]] [[projects]]
branch = "master" branch = "master"
@ -279,31 +272,66 @@
revision = "69483b4bd14f5845b5a1e55bca19e954e827f1d0" revision = "69483b4bd14f5845b5a1e55bca19e954e827f1d0"
version = "v1.1.4" version = "v1.1.4"
[[projects]]
digest = "1:f80053a92d9ac874ad97d665d3005c1118ed66e9e88401361dc32defe6bef21c"
name = "github.com/theckman/goconstraint"
packages = ["go1.11/gte"]
pruneopts = ""
revision = "93babf24513d0e8277635da8169fcc5a46ae3f6a"
version = "v1.11.0"
[[projects]] [[projects]]
branch = "v2" branch = "v2"
digest = "1:9d92186f609a73744232323416ddafd56fae67cb552162cc190ab903e36900dd" digest = "1:6b8a6afafde7ed31cd0c577ba40d88ce39e8f1c5eb76d7836be7d5b74f1c534a"
name = "github.com/zrepl/yaml-config" name = "github.com/zrepl/yaml-config"
packages = ["."] packages = ["."]
pruneopts = "" pruneopts = ""
revision = "af27d27978ad95808723a62d87557d63c3ff0605" revision = "08227ad854131f7dfcdfb12579fb73dd8a38a03a"
[[projects]] [[projects]]
branch = "master" branch = "master"
digest = "1:9c286cf11d0ca56368185bada5dd6d97b6be4648fc26c354fcba8df7293718f7" digest = "1:ea539c13b066dac72a940b62f37600a20ab8e88057397c78f3197c1a48475425"
name = "golang.org/x/net"
packages = [
"context",
"http/httpguts",
"http2",
"http2/hpack",
"idna",
"internal/timeseries",
"trace",
]
pruneopts = ""
revision = "351d144fa1fc0bd934e2408202be0c29f25e35a0"
[[projects]]
branch = "master"
digest = "1:f358024b019f87eecaadcb098113a40852c94fe58ea670ef3c3e2d2c7bd93db1"
name = "golang.org/x/sys" name = "golang.org/x/sys"
packages = ["unix"] packages = ["unix"]
pruneopts = "" pruneopts = ""
revision = "bf42f188b9bc6f2cf5b8ee5a912ef1aedd0eba4c" revision = "4ed8d59d0b35e1e29334a206d1b3f38b1e5dfb31"
[[projects]] [[projects]]
digest = "1:5acd3512b047305d49e8763eef7ba423901e85d5dd2fd1e71778a0ea8de10bd4" digest = "1:5acd3512b047305d49e8763eef7ba423901e85d5dd2fd1e71778a0ea8de10bd4"
name = "golang.org/x/text" name = "golang.org/x/text"
packages = [ packages = [
"collate",
"collate/build",
"encoding", "encoding",
"encoding/internal/identifier", "encoding/internal/identifier",
"internal/colltab",
"internal/gen", "internal/gen",
"internal/tag",
"internal/triegen",
"internal/ucd",
"language",
"secure/bidirule",
"transform", "transform",
"unicode/bidi",
"unicode/cldr", "unicode/cldr",
"unicode/norm",
"unicode/rangetable",
] ]
pruneopts = "" pruneopts = ""
revision = "f21a4dfb5e38f5895301dc265a8def02365cc3d0" revision = "f21a4dfb5e38f5895301dc265a8def02365cc3d0"
@ -317,6 +345,54 @@
pruneopts = "" pruneopts = ""
revision = "d0ca3933b724e6be513276cc2edb34e10d667438" revision = "d0ca3933b724e6be513276cc2edb34e10d667438"
[[projects]]
branch = "master"
digest = "1:5fc6c317675b746d0c641b29aa0aab5fcb403c0d07afdbf0de86b0d447a0502a"
name = "google.golang.org/genproto"
packages = ["googleapis/rpc/status"]
pruneopts = ""
revision = "bd91e49a0898e27abb88c339b432fa53d7497ac0"
[[projects]]
digest = "1:d141efe4aaad714e3059c340901aab3147b6253e58c85dafbcca3dd8b0e88ad6"
name = "google.golang.org/grpc"
packages = [
".",
"balancer",
"balancer/base",
"balancer/roundrobin",
"binarylog/grpc_binarylog_v1",
"codes",
"connectivity",
"credentials",
"credentials/internal",
"encoding",
"encoding/proto",
"grpclog",
"internal",
"internal/backoff",
"internal/binarylog",
"internal/channelz",
"internal/envconfig",
"internal/grpcrand",
"internal/grpcsync",
"internal/syscall",
"internal/transport",
"keepalive",
"metadata",
"naming",
"peer",
"resolver",
"resolver/dns",
"resolver/passthrough",
"stats",
"status",
"tap",
]
pruneopts = ""
revision = "df014850f6dee74ba2fc94874043a9f3f75fbfd8"
version = "v1.17.0"
[solve-meta] [solve-meta]
analyzer-name = "dep" analyzer-name = "dep"
analyzer-version = 1 analyzer-version = 1
@ -331,9 +407,8 @@
"github.com/kr/pretty", "github.com/kr/pretty",
"github.com/mattn/go-isatty", "github.com/mattn/go-isatty",
"github.com/pkg/errors", "github.com/pkg/errors",
"github.com/pkg/profile",
"github.com/problame/go-netssh", "github.com/problame/go-netssh",
"github.com/problame/go-rwccmd",
"github.com/problame/go-streamrpc",
"github.com/prometheus/client_golang/prometheus", "github.com/prometheus/client_golang/prometheus",
"github.com/prometheus/client_golang/prometheus/promhttp", "github.com/prometheus/client_golang/prometheus/promhttp",
"github.com/spf13/cobra", "github.com/spf13/cobra",
@ -341,8 +416,13 @@
"github.com/stretchr/testify/assert", "github.com/stretchr/testify/assert",
"github.com/stretchr/testify/require", "github.com/stretchr/testify/require",
"github.com/zrepl/yaml-config", "github.com/zrepl/yaml-config",
"golang.org/x/net/context",
"golang.org/x/sys/unix", "golang.org/x/sys/unix",
"golang.org/x/tools/cmd/stringer", "golang.org/x/tools/cmd/stringer",
"google.golang.org/grpc",
"google.golang.org/grpc/credentials",
"google.golang.org/grpc/keepalive",
"google.golang.org/grpc/peer",
] ]
solver-name = "gps-cdcl" solver-name = "gps-cdcl"
solver-version = 1 solver-version = 1

View File

@ -1,8 +1,11 @@
ignored = [ "github.com/inconshreveable/mousetrap" ] ignored = [
"github.com/inconshreveable/mousetrap",
]
[[constraint]] required = [
branch = "master" "golang.org/x/tools/cmd/stringer",
name = "github.com/ftrvxmtrx/fd" "github.com/alvaroloes/enumer",
]
[[constraint]] [[constraint]]
branch = "master" branch = "master"
@ -12,14 +15,6 @@ ignored = [ "github.com/inconshreveable/mousetrap" ]
branch = "master" branch = "master"
name = "github.com/kr/pretty" name = "github.com/kr/pretty"
[[constraint]]
branch = "master"
name = "github.com/mitchellh/go-homedir"
[[constraint]]
branch = "master"
name = "github.com/mitchellh/mapstructure"
[[constraint]] [[constraint]]
name = "github.com/pkg/errors" name = "github.com/pkg/errors"
version = "0.8.0" version = "0.8.0"
@ -28,10 +23,6 @@ ignored = [ "github.com/inconshreveable/mousetrap" ]
branch = "master" branch = "master"
name = "github.com/spf13/cobra" name = "github.com/spf13/cobra"
[[constraint]]
name = "github.com/spf13/viper"
version = "1.0.0"
[[constraint]] [[constraint]]
name = "github.com/stretchr/testify" name = "github.com/stretchr/testify"
version = "1.1.4" version = "1.1.4"
@ -44,10 +35,6 @@ ignored = [ "github.com/inconshreveable/mousetrap" ]
name = "github.com/go-logfmt/logfmt" name = "github.com/go-logfmt/logfmt"
version = "*" version = "*"
[[constraint]]
name = "github.com/problame/go-rwccmd"
branch = "master"
[[constraint]] [[constraint]]
name = "github.com/problame/go-netssh" name = "github.com/problame/go-netssh"
branch = "master" branch = "master"
@ -58,26 +45,17 @@ ignored = [ "github.com/inconshreveable/mousetrap" ]
[[constraint]] [[constraint]]
name = "github.com/golang/protobuf" name = "github.com/golang/protobuf"
version = "1.2.0" version = "1"
[[constraint]] [[constraint]]
name = "github.com/fatih/color" name = "github.com/fatih/color"
version = "1.7.0" version = "1.7.0"
[[constraint]]
name = "github.com/problame/go-streamrpc"
version = "0.5.0"
[[constraint]] [[constraint]]
name = "github.com/gdamore/tcell" name = "github.com/gdamore/tcell"
version = "1.0.0" version = "1.0.0"
[[constraint]] [[constraint]]
branch = "master" name = "google.golang.org/grpc"
name = "golang.org/x/tools" version = "1"
[[constraint]]
branch = "master"
name = "github.com/alvaroloes/enumer"

View File

@ -1,45 +1,13 @@
.PHONY: generate build test vet cover release docs docs-clean clean vendordeps .PHONY: generate build test vet cover release docs docs-clean clean vendordeps
.DEFAULT_GOAL := build .DEFAULT_GOAL := build
ROOT := github.com/zrepl/zrepl
SUBPKGS += client
SUBPKGS += config
SUBPKGS += daemon
SUBPKGS += daemon/filters
SUBPKGS += daemon/job
SUBPKGS += daemon/logging
SUBPKGS += daemon/nethelpers
SUBPKGS += daemon/pruner
SUBPKGS += daemon/snapper
SUBPKGS += daemon/streamrpcconfig
SUBPKGS += daemon/transport
SUBPKGS += daemon/transport/connecter
SUBPKGS += daemon/transport/serve
SUBPKGS += endpoint
SUBPKGS += logger
SUBPKGS += pruning
SUBPKGS += pruning/retentiongrid
SUBPKGS += replication
SUBPKGS += replication/fsrep
SUBPKGS += replication/pdu
SUBPKGS += replication/internal/diff
SUBPKGS += tlsconf
SUBPKGS += util
SUBPKGS += util/socketpair
SUBPKGS += util/watchdog
SUBPKGS += util/envconst
SUBPKGS += version
SUBPKGS += zfs
_TESTPKGS := $(ROOT) $(foreach p,$(SUBPKGS),$(ROOT)/$(p))
ARTIFACTDIR := artifacts ARTIFACTDIR := artifacts
ifdef ZREPL_VERSION ifdef ZREPL_VERSION
_ZREPL_VERSION := $(ZREPL_VERSION) _ZREPL_VERSION := $(ZREPL_VERSION)
endif endif
ifndef _ZREPL_VERSION ifndef _ZREPL_VERSION
_ZREPL_VERSION := $(shell git describe --dirty 2>/dev/null || echo "ZREPL_BUILD_INVALID_VERSION" ) _ZREPL_VERSION := $(shell git describe --always --dirty 2>/dev/null || echo "ZREPL_BUILD_INVALID_VERSION" )
ifeq ($(_ZREPL_VERSION),ZREPL_BUILD_INVALID_VERSION) # can't use .SHELLSTATUS because Debian Stretch is still on gmake 4.1 ifeq ($(_ZREPL_VERSION),ZREPL_BUILD_INVALID_VERSION) # can't use .SHELLSTATUS because Debian Stretch is still on gmake 4.1
$(error cannot infer variable ZREPL_VERSION using git and variable is not overriden by make invocation) $(error cannot infer variable ZREPL_VERSION using git and variable is not overriden by make invocation)
endif endif
@ -59,35 +27,21 @@ vendordeps:
dep ensure -v -vendor-only dep ensure -v -vendor-only
generate: #not part of the build, must do that manually generate: #not part of the build, must do that manually
protoc -I=replication/pdu --go_out=replication/pdu replication/pdu/pdu.proto protoc -I=replication/pdu --go_out=plugins=grpc:replication/pdu replication/pdu/pdu.proto
@for pkg in $(_TESTPKGS); do\ go generate -x ./...
go generate "$$pkg" || exit 1; \
done;
build: build:
@echo "INFO: In case of missing dependencies, run 'make vendordeps'" @echo "INFO: In case of missing dependencies, run 'make vendordeps'"
$(GO_BUILD) -o "$(ARTIFACTDIR)/zrepl" $(GO_BUILD) -o "$(ARTIFACTDIR)/zrepl"
test: test:
@for pkg in $(_TESTPKGS); do \ go test ./...
echo "Testing $$pkg"; \
go test "$$pkg" || exit 1; \
done;
vet: vet:
@for pkg in $(_TESTPKGS); do \ # for each supported platform to cover conditional compilation
echo "Vetting $$pkg"; \ GOOS=linux go vet ./...
go vet "$$pkg" || exit 1; \ GOOS=darwin go vet ./...
done; GOOS=freebsd go vet ./...
cover: artifacts
@for pkg in $(_TESTPKGS); do \
profile="$(ARTIFACTDIR)/cover-$$(basename $$pkg).out"; \
go test -coverprofile "$$profile" $$pkg || exit 1; \
if [ -f "$$profile" ]; then \
go tool cover -html="$$profile" -o "$${profile}.html" || exit 2; \
fi; \
done;
$(ARTIFACTDIR): $(ARTIFACTDIR):
mkdir -p "$@" mkdir -p "$@"
@ -132,7 +86,7 @@ release: $(RELEASE_BINS) $(RELEASE_NOARCH)
cp $^ "$(ARTIFACTDIR)/release" cp $^ "$(ARTIFACTDIR)/release"
cd "$(ARTIFACTDIR)/release" && sha512sum $$(ls | sort) > sha512sum.txt cd "$(ARTIFACTDIR)/release" && sha512sum $$(ls | sort) > sha512sum.txt
@# note that we use ZREPL_VERSION and not _ZREPL_VERSION because we want to detect the override @# note that we use ZREPL_VERSION and not _ZREPL_VERSION because we want to detect the override
@if git describe --dirty 2>/dev/null | grep dirty >/dev/null; then \ @if git describe --always --dirty 2>/dev/null | grep dirty >/dev/null; then \
echo '[INFO] either git reports checkout is dirty or git is not installed or this is not a git checkout'; \ echo '[INFO] either git reports checkout is dirty or git is not installed or this is not a git checkout'; \
if [ "$(ZREPL_VERSION)" = "" ]; then \ if [ "$(ZREPL_VERSION)" = "" ]; then \
echo '[WARN] git checkout is dirty and make variable ZREPL_VERSION was not used to override'; \ echo '[WARN] git checkout is dirty and make variable ZREPL_VERSION was not used to override'; \

View File

@ -130,7 +130,6 @@ type Global struct {
Monitoring []MonitoringEnum `yaml:"monitoring,optional"` Monitoring []MonitoringEnum `yaml:"monitoring,optional"`
Control *GlobalControl `yaml:"control,optional,fromdefaults"` Control *GlobalControl `yaml:"control,optional,fromdefaults"`
Serve *GlobalServe `yaml:"serve,optional,fromdefaults"` Serve *GlobalServe `yaml:"serve,optional,fromdefaults"`
RPC *RPCConfig `yaml:"rpc,optional,fromdefaults"`
} }
func Default(i interface{}) { func Default(i interface{}) {
@ -145,29 +144,18 @@ func Default(i interface{}) {
} }
} }
type RPCConfig struct {
Timeout time.Duration `yaml:"timeout,optional,positive,default=10s"`
TxChunkSize uint32 `yaml:"tx_chunk_size,optional,default=32768"`
RxStructuredMaxLen uint32 `yaml:"rx_structured_max,optional,default=16777216"`
RxStreamChunkMaxLen uint32 `yaml:"rx_stream_chunk_max,optional,default=16777216"`
RxHeaderMaxLen uint32 `yaml:"rx_header_max,optional,default=40960"`
SendHeartbeatInterval time.Duration `yaml:"send_heartbeat_interval,optional,positive,default=5s"`
}
type ConnectEnum struct { type ConnectEnum struct {
Ret interface{} Ret interface{}
} }
type ConnectCommon struct { type ConnectCommon struct {
Type string `yaml:"type"` Type string `yaml:"type"`
RPC *RPCConfig `yaml:"rpc,optional"`
} }
type TCPConnect struct { type TCPConnect struct {
ConnectCommon `yaml:",inline"` ConnectCommon `yaml:",inline"`
Address string `yaml:"address"` Address string `yaml:"address"`
DialTimeout time.Duration `yaml:"dial_timeout,positive,default=10s"` DialTimeout time.Duration `yaml:"dial_timeout,zeropositive,default=10s"`
} }
type TLSConnect struct { type TLSConnect struct {
@ -177,7 +165,7 @@ type TLSConnect struct {
Cert string `yaml:"cert"` Cert string `yaml:"cert"`
Key string `yaml:"key"` Key string `yaml:"key"`
ServerCN string `yaml:"server_cn"` ServerCN string `yaml:"server_cn"`
DialTimeout time.Duration `yaml:"dial_timeout,positive,default=10s"` DialTimeout time.Duration `yaml:"dial_timeout,zeropositive,default=10s"`
} }
type SSHStdinserverConnect struct { type SSHStdinserverConnect struct {
@ -189,7 +177,7 @@ type SSHStdinserverConnect struct {
TransportOpenCommand []string `yaml:"transport_open_command,optional"` //TODO unused TransportOpenCommand []string `yaml:"transport_open_command,optional"` //TODO unused
SSHCommand string `yaml:"ssh_command,optional"` //TODO unused SSHCommand string `yaml:"ssh_command,optional"` //TODO unused
Options []string `yaml:"options,optional"` Options []string `yaml:"options,optional"`
DialTimeout time.Duration `yaml:"dial_timeout,positive,default=10s"` DialTimeout time.Duration `yaml:"dial_timeout,zeropositive,default=10s"`
} }
type LocalConnect struct { type LocalConnect struct {
@ -204,7 +192,6 @@ type ServeEnum struct {
type ServeCommon struct { type ServeCommon struct {
Type string `yaml:"type"` Type string `yaml:"type"`
RPC *RPCConfig `yaml:"rpc,optional"`
} }
type TCPServe struct { type TCPServe struct {
@ -220,7 +207,7 @@ type TLSServe struct {
Cert string `yaml:"cert"` Cert string `yaml:"cert"`
Key string `yaml:"key"` Key string `yaml:"key"`
ClientCNs []string `yaml:"client_cns"` ClientCNs []string `yaml:"client_cns"`
HandshakeTimeout time.Duration `yaml:"handshake_timeout,positive,default=10s"` HandshakeTimeout time.Duration `yaml:"handshake_timeout,zeropositive,default=10s"`
} }
type StdinserverServer struct { type StdinserverServer struct {

View File

@ -1,86 +0,0 @@
package config
import (
"github.com/stretchr/testify/assert"
"testing"
"time"
)
func TestRPC(t *testing.T) {
conf := testValidConfig(t, `
jobs:
- name: pull_servers
type: pull
connect:
type: tcp
address: "server1.foo.bar:8888"
rpc:
timeout: 20s # different form default, should merge
root_fs: "pool2/backup_servers"
interval: 10m
pruning:
keep_sender:
- type: not_replicated
keep_receiver:
- type: last_n
count: 100
- name: pull_servers2
type: pull
connect:
type: tcp
address: "server1.foo.bar:8888"
rpc:
tx_chunk_size: 0xabcd # different from default, should merge
root_fs: "pool2/backup_servers"
interval: 10m
pruning:
keep_sender:
- type: not_replicated
keep_receiver:
- type: last_n
count: 100
- type: sink
name: "laptop_sink"
root_fs: "pool2/backup_laptops"
serve:
type: tcp
listen: "192.168.122.189:8888"
clients: {
"10.23.42.23":"client1"
}
rpc:
rx_structured_max: 0x2342
- type: sink
name: "other_sink"
root_fs: "pool2/backup_laptops"
serve:
type: tcp
listen: "192.168.122.189:8888"
clients: {
"10.23.42.23":"client1"
}
rpc:
send_heartbeat_interval: 10s
`)
assert.Equal(t, 20*time.Second, conf.Jobs[0].Ret.(*PullJob).Connect.Ret.(*TCPConnect).RPC.Timeout)
assert.Equal(t, uint32(0xabcd), conf.Jobs[1].Ret.(*PullJob).Connect.Ret.(*TCPConnect).RPC.TxChunkSize)
assert.Equal(t, uint32(0x2342), conf.Jobs[2].Ret.(*SinkJob).Serve.Ret.(*TCPServe).RPC.RxStructuredMaxLen)
assert.Equal(t, 10*time.Second, conf.Jobs[3].Ret.(*SinkJob).Serve.Ret.(*TCPServe).RPC.SendHeartbeatInterval)
defConf := RPCConfig{}
Default(&defConf)
assert.Equal(t, defConf.Timeout, conf.Global.RPC.Timeout)
}
func TestGlobal_DefaultRPCConfig(t *testing.T) {
assert.NotPanics(t, func() {
var c RPCConfig
Default(&c)
assert.NotNil(t, c)
assert.Equal(t, c.TxChunkSize, uint32(1)<<15)
})
}

View File

@ -1,11 +1,14 @@
package config package config
import ( import (
"github.com/kr/pretty" "bytes"
"github.com/stretchr/testify/require"
"path" "path"
"path/filepath" "path/filepath"
"testing" "testing"
"text/template"
"github.com/kr/pretty"
"github.com/stretchr/testify/require"
) )
func TestSampleConfigsAreParsedWithoutErrors(t *testing.T) { func TestSampleConfigsAreParsedWithoutErrors(t *testing.T) {
@ -35,8 +38,21 @@ func TestSampleConfigsAreParsedWithoutErrors(t *testing.T) {
} }
// template must be a template/text template with a single '{{ . }}' as placehodler for val
func testValidConfigTemplate(t *testing.T, tmpl string, val string) *Config {
tmp, err := template.New("master").Parse(tmpl)
if err != nil {
panic(err)
}
var buf bytes.Buffer
err = tmp.Execute(&buf, val)
if err != nil {
panic(err)
}
return testValidConfig(t, buf.String())
}
func testValidConfig(t *testing.T, input string) (*Config) { func testValidConfig(t *testing.T, input string) *Config {
t.Helper() t.Helper()
conf, err := testConfig(t, input) conf, err := testConfig(t, input)
require.NoError(t, err) require.NoError(t, err)

View File

@ -2,9 +2,12 @@ package job
import ( import (
"context" "context"
"sync"
"time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/problame/go-streamrpc"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/zrepl/zrepl/config" "github.com/zrepl/zrepl/config"
"github.com/zrepl/zrepl/daemon/filters" "github.com/zrepl/zrepl/daemon/filters"
"github.com/zrepl/zrepl/daemon/job/reset" "github.com/zrepl/zrepl/daemon/job/reset"
@ -12,23 +15,22 @@ import (
"github.com/zrepl/zrepl/daemon/logging" "github.com/zrepl/zrepl/daemon/logging"
"github.com/zrepl/zrepl/daemon/pruner" "github.com/zrepl/zrepl/daemon/pruner"
"github.com/zrepl/zrepl/daemon/snapper" "github.com/zrepl/zrepl/daemon/snapper"
"github.com/zrepl/zrepl/daemon/transport/connecter"
"github.com/zrepl/zrepl/endpoint" "github.com/zrepl/zrepl/endpoint"
"github.com/zrepl/zrepl/replication" "github.com/zrepl/zrepl/replication"
"github.com/zrepl/zrepl/rpc"
"github.com/zrepl/zrepl/transport"
"github.com/zrepl/zrepl/transport/fromconfig"
"github.com/zrepl/zrepl/util/envconst" "github.com/zrepl/zrepl/util/envconst"
"github.com/zrepl/zrepl/zfs" "github.com/zrepl/zrepl/zfs"
"sync"
"time"
) )
type ActiveSide struct { type ActiveSide struct {
mode activeMode mode activeMode
name string name string
clientFactory *connecter.ClientFactory connecter transport.Connecter
prunerFactory *pruner.PrunerFactory prunerFactory *pruner.PrunerFactory
promRepStateSecs *prometheus.HistogramVec // labels: state promRepStateSecs *prometheus.HistogramVec // labels: state
promPruneSecs *prometheus.HistogramVec // labels: prune_side promPruneSecs *prometheus.HistogramVec // labels: prune_side
promBytesReplicated *prometheus.CounterVec // labels: filesystem promBytesReplicated *prometheus.CounterVec // labels: filesystem
@ -37,7 +39,6 @@ type ActiveSide struct {
tasks activeSideTasks tasks activeSideTasks
} }
//go:generate enumer -type=ActiveSideState //go:generate enumer -type=ActiveSideState
type ActiveSideState int type ActiveSideState int
@ -48,7 +49,6 @@ const (
ActiveSideDone // also errors ActiveSideDone // also errors
) )
type activeSideTasks struct { type activeSideTasks struct {
state ActiveSideState state ActiveSideState
@ -77,20 +77,44 @@ func (a *ActiveSide) updateTasks(u func(*activeSideTasks)) activeSideTasks {
} }
type activeMode interface { type activeMode interface {
SenderReceiver(client *streamrpc.Client) (replication.Sender, replication.Receiver, error) ConnectEndpoints(rpcLoggers rpc.Loggers, connecter transport.Connecter)
DisconnectEndpoints()
SenderReceiver() (replication.Sender, replication.Receiver)
Type() Type Type() Type
RunPeriodic(ctx context.Context, wakeUpCommon chan<- struct{}) RunPeriodic(ctx context.Context, wakeUpCommon chan<- struct{})
ResetConnectBackoff()
} }
type modePush struct { type modePush struct {
setupMtx sync.Mutex
sender *endpoint.Sender
receiver *rpc.Client
fsfilter endpoint.FSFilter fsfilter endpoint.FSFilter
snapper *snapper.PeriodicOrManual snapper *snapper.PeriodicOrManual
} }
func (m *modePush) SenderReceiver(client *streamrpc.Client) (replication.Sender, replication.Receiver, error) { func (m *modePush) ConnectEndpoints(loggers rpc.Loggers, connecter transport.Connecter) {
sender := endpoint.NewSender(m.fsfilter) m.setupMtx.Lock()
receiver := endpoint.NewRemote(client) defer m.setupMtx.Unlock()
return sender, receiver, nil if m.receiver != nil || m.sender != nil {
panic("inconsistent use of ConnectEndpoints and DisconnectEndpoints")
}
m.sender = endpoint.NewSender(m.fsfilter)
m.receiver = rpc.NewClient(connecter, loggers)
}
func (m *modePush) DisconnectEndpoints() {
m.setupMtx.Lock()
defer m.setupMtx.Unlock()
m.receiver.Close()
m.sender = nil
m.receiver = nil
}
func (m *modePush) SenderReceiver() (replication.Sender, replication.Receiver) {
m.setupMtx.Lock()
defer m.setupMtx.Unlock()
return m.sender, m.receiver
} }
func (m *modePush) Type() Type { return TypePush } func (m *modePush) Type() Type { return TypePush }
@ -99,6 +123,13 @@ func (m *modePush) RunPeriodic(ctx context.Context, wakeUpCommon chan <- struct{
m.snapper.Run(ctx, wakeUpCommon) m.snapper.Run(ctx, wakeUpCommon)
} }
func (m *modePush) ResetConnectBackoff() {
m.setupMtx.Lock()
defer m.setupMtx.Unlock()
if m.receiver != nil {
m.receiver.ResetConnectBackoff()
}
}
func modePushFromConfig(g *config.Global, in *config.PushJob) (*modePush, error) { func modePushFromConfig(g *config.Global, in *config.PushJob) (*modePush, error) {
m := &modePush{} m := &modePush{}
@ -116,14 +147,35 @@ func modePushFromConfig(g *config.Global, in *config.PushJob) (*modePush, error)
} }
type modePull struct { type modePull struct {
setupMtx sync.Mutex
receiver *endpoint.Receiver
sender *rpc.Client
rootFS *zfs.DatasetPath rootFS *zfs.DatasetPath
interval time.Duration interval time.Duration
} }
func (m *modePull) SenderReceiver(client *streamrpc.Client) (replication.Sender, replication.Receiver, error) { func (m *modePull) ConnectEndpoints(loggers rpc.Loggers, connecter transport.Connecter) {
sender := endpoint.NewRemote(client) m.setupMtx.Lock()
receiver, err := endpoint.NewReceiver(m.rootFS) defer m.setupMtx.Unlock()
return sender, receiver, err if m.receiver != nil || m.sender != nil {
panic("inconsistent use of ConnectEndpoints and DisconnectEndpoints")
}
m.receiver = endpoint.NewReceiver(m.rootFS, false)
m.sender = rpc.NewClient(connecter, loggers)
}
func (m *modePull) DisconnectEndpoints() {
m.setupMtx.Lock()
defer m.setupMtx.Unlock()
m.sender.Close()
m.sender = nil
m.receiver = nil
}
func (m *modePull) SenderReceiver() (replication.Sender, replication.Receiver) {
m.setupMtx.Lock()
defer m.setupMtx.Unlock()
return m.sender, m.receiver
} }
func (*modePull) Type() Type { return TypePull } func (*modePull) Type() Type { return TypePull }
@ -148,6 +200,14 @@ func (m *modePull) RunPeriodic(ctx context.Context, wakeUpCommon chan<- struct{}
} }
} }
func (m *modePull) ResetConnectBackoff() {
m.setupMtx.Lock()
defer m.setupMtx.Unlock()
if m.sender != nil {
m.sender.ResetConnectBackoff()
}
}
func modePullFromConfig(g *config.Global, in *config.PullJob) (m *modePull, err error) { func modePullFromConfig(g *config.Global, in *config.PullJob) (m *modePull, err error) {
m = &modePull{} m = &modePull{}
if in.Interval <= 0 { if in.Interval <= 0 {
@ -185,7 +245,7 @@ func activeSide(g *config.Global, in *config.ActiveJob, mode activeMode) (j *Act
ConstLabels: prometheus.Labels{"zrepl_job": j.name}, ConstLabels: prometheus.Labels{"zrepl_job": j.name},
}, []string{"filesystem"}) }, []string{"filesystem"})
j.clientFactory, err = connecter.FromConfig(g, in.Connect) j.connecter, err = fromconfig.ConnecterFromConfig(g, in.Connect)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "cannot build client") return nil, errors.Wrap(err, "cannot build client")
} }
@ -256,6 +316,7 @@ outer:
break outer break outer
case <-wakeup.Wait(ctx): case <-wakeup.Wait(ctx):
j.mode.ResetConnectBackoff()
case <-periodicDone: case <-periodicDone:
} }
invocationCount++ invocationCount++
@ -268,6 +329,9 @@ func (j *ActiveSide) do(ctx context.Context) {
log := GetLogger(ctx) log := GetLogger(ctx)
ctx = logging.WithSubsystemLoggers(ctx, log) ctx = logging.WithSubsystemLoggers(ctx, log)
loggers := rpc.GetLoggersOrPanic(ctx) // filled by WithSubsystemLoggers
j.mode.ConnectEndpoints(loggers, j.connecter)
defer j.mode.DisconnectEndpoints()
// allow cancellation of an invocation (this function) // allow cancellation of an invocation (this function)
ctx, cancelThisRun := context.WithCancel(ctx) ctx, cancelThisRun := context.WithCancel(ctx)
@ -353,13 +417,7 @@ func (j *ActiveSide) do(ctx context.Context) {
} }
}() }()
client, err := j.clientFactory.NewClient() sender, receiver := j.mode.SenderReceiver()
if err != nil {
log.WithError(err).Error("factory cannot instantiate streamrpc client")
}
defer client.Close(ctx)
sender, receiver, err := j.mode.SenderReceiver(client)
{ {
select { select {

View File

@ -2,28 +2,30 @@ package job
import ( import (
"context" "context"
"fmt"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/problame/go-streamrpc"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/zrepl/zrepl/config" "github.com/zrepl/zrepl/config"
"github.com/zrepl/zrepl/daemon/filters" "github.com/zrepl/zrepl/daemon/filters"
"github.com/zrepl/zrepl/daemon/logging" "github.com/zrepl/zrepl/daemon/logging"
"github.com/zrepl/zrepl/daemon/transport/serve"
"github.com/zrepl/zrepl/daemon/snapper" "github.com/zrepl/zrepl/daemon/snapper"
"github.com/zrepl/zrepl/endpoint" "github.com/zrepl/zrepl/endpoint"
"github.com/zrepl/zrepl/rpc"
"github.com/zrepl/zrepl/transport"
"github.com/zrepl/zrepl/transport/fromconfig"
"github.com/zrepl/zrepl/zfs" "github.com/zrepl/zrepl/zfs"
"path"
) )
type PassiveSide struct { type PassiveSide struct {
mode passiveMode mode passiveMode
name string name string
l serve.ListenerFactory listen transport.AuthenticatedListenerFactory
rpcConf *streamrpc.ConnConfig
} }
type passiveMode interface { type passiveMode interface {
ConnHandleFunc(ctx context.Context, conn serve.AuthenticatedConn) streamrpc.HandlerFunc Handler() rpc.Handler
RunPeriodic(ctx context.Context) RunPeriodic(ctx context.Context)
Type() Type Type() Type
} }
@ -34,26 +36,8 @@ type modeSink struct {
func (m *modeSink) Type() Type { return TypeSink } func (m *modeSink) Type() Type { return TypeSink }
func (m *modeSink) ConnHandleFunc(ctx context.Context, conn serve.AuthenticatedConn) streamrpc.HandlerFunc { func (m *modeSink) Handler() rpc.Handler {
log := GetLogger(ctx) return endpoint.NewReceiver(m.rootDataset, true)
clientRootStr := path.Join(m.rootDataset.ToString(), conn.ClientIdentity())
clientRoot, err := zfs.NewDatasetPath(clientRootStr)
if err != nil {
log.WithError(err).
WithField("client_identity", conn.ClientIdentity()).
Error("cannot build client filesystem map (client identity must be a valid ZFS FS name")
}
log.WithField("client_root", clientRoot).Debug("client root")
local, err := endpoint.NewReceiver(clientRoot)
if err != nil {
log.WithError(err).Error("unexpected error: cannot convert mapping to filter")
return nil
}
h := endpoint.NewHandler(local)
return h.Handle
} }
func (m *modeSink) RunPeriodic(_ context.Context) {} func (m *modeSink) RunPeriodic(_ context.Context) {}
@ -93,10 +77,8 @@ func modeSourceFromConfig(g *config.Global, in *config.SourceJob) (m *modeSource
func (m *modeSource) Type() Type { return TypeSource } func (m *modeSource) Type() Type { return TypeSource }
func (m *modeSource) ConnHandleFunc(ctx context.Context, conn serve.AuthenticatedConn) streamrpc.HandlerFunc { func (m *modeSource) Handler() rpc.Handler {
sender := endpoint.NewSender(m.fsfilter) return endpoint.NewSender(m.fsfilter)
h := endpoint.NewHandler(sender)
return h.Handle
} }
func (m *modeSource) RunPeriodic(ctx context.Context) { func (m *modeSource) RunPeriodic(ctx context.Context) {
@ -106,8 +88,8 @@ func (m *modeSource) RunPeriodic(ctx context.Context) {
func passiveSideFromConfig(g *config.Global, in *config.PassiveJob, mode passiveMode) (s *PassiveSide, err error) { func passiveSideFromConfig(g *config.Global, in *config.PassiveJob, mode passiveMode) (s *PassiveSide, err error) {
s = &PassiveSide{mode: mode, name: in.Name} s = &PassiveSide{mode: mode, name: in.Name}
if s.l, s.rpcConf, err = serve.FromConfig(g, in.Serve); err != nil { if s.listen, err = fromconfig.ListenerFactoryFromConfig(g, in.Serve); err != nil {
return nil, errors.Wrap(err, "cannot build server") return nil, errors.Wrap(err, "cannot build listener factory")
} }
return s, nil return s, nil
@ -127,70 +109,30 @@ func (j *PassiveSide) Run(ctx context.Context) {
log := GetLogger(ctx) log := GetLogger(ctx)
defer log.Info("job exiting") defer log.Info("job exiting")
ctx = logging.WithSubsystemLoggers(ctx, log)
l, err := j.l.Listen()
if err != nil {
log.WithError(err).Error("cannot listen")
return
}
defer l.Close()
{ {
ctx, cancel := context.WithCancel(logging.WithSubsystemLoggers(ctx, log)) // shadowing ctx, cancel := context.WithCancel(ctx) // shadowing
defer cancel() defer cancel()
go j.mode.RunPeriodic(ctx) go j.mode.RunPeriodic(ctx)
} }
log.WithField("addr", l.Addr()).Debug("accepting connections") handler := j.mode.Handler()
var connId int if handler == nil {
outer: panic(fmt.Sprintf("implementation error: j.mode.Handler() returned nil: %#v", j))
for {
select {
case res := <-accept(ctx, l):
if res.err != nil {
log.WithError(res.err).Info("accept error")
continue
} }
conn := res.conn
connId++ ctxInterceptor := func(handlerCtx context.Context) context.Context {
connLog := log. return logging.WithSubsystemLoggers(handlerCtx, log)
WithField("connID", connId) }
connLog.
WithField("addr", conn.RemoteAddr()). rpcLoggers := rpc.GetLoggersOrPanic(ctx) // WithSubsystemLoggers above
WithField("client_identity", conn.ClientIdentity()). server := rpc.NewServer(handler, rpcLoggers, ctxInterceptor)
Info("handling connection")
go func() { listener, err := j.listen()
defer connLog.Info("finished handling connection") if err != nil {
defer conn.Close() log.WithError(err).Error("cannot listen")
ctx := logging.WithSubsystemLoggers(ctx, connLog)
handleFunc := j.mode.ConnHandleFunc(ctx, conn)
if handleFunc == nil {
return return
} }
if err := streamrpc.ServeConn(ctx, conn, j.rpcConf, handleFunc); err != nil {
log.WithError(err).Error("error serving client")
}
}()
case <-ctx.Done(): server.Serve(ctx, listener)
break outer
}
}
}
type acceptResult struct {
conn serve.AuthenticatedConn
err error
}
func accept(ctx context.Context, listener serve.AuthenticatedListener) <-chan acceptResult {
c := make(chan acceptResult, 1)
go func() {
conn, err := listener.Accept(ctx)
c <- acceptResult{conn, err}
}()
return c
} }

View File

@ -1,32 +0,0 @@
package logging
import (
"fmt"
"github.com/problame/go-streamrpc"
"github.com/zrepl/zrepl/logger"
"strings"
)
type streamrpcLogAdaptor = twoClassLogAdaptor
type twoClassLogAdaptor struct {
logger.Logger
}
var _ streamrpc.Logger = twoClassLogAdaptor{}
func (a twoClassLogAdaptor) Errorf(fmtStr string, args ...interface{}) {
const errorSuffix = ": %s"
if len(args) == 1 {
if err, ok := args[0].(error); ok && strings.HasSuffix(fmtStr, errorSuffix) {
msg := strings.TrimSuffix(fmtStr, errorSuffix)
a.WithError(err).Error(msg)
return
}
}
a.Logger.Error(fmt.Sprintf(fmtStr, args...))
}
func (a twoClassLogAdaptor) Infof(fmtStr string, args ...interface{}) {
a.Logger.Debug(fmt.Sprintf(fmtStr, args...))
}

View File

@ -4,18 +4,21 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"os"
"github.com/mattn/go-isatty" "github.com/mattn/go-isatty"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/problame/go-streamrpc"
"github.com/zrepl/zrepl/config" "github.com/zrepl/zrepl/config"
"github.com/zrepl/zrepl/daemon/pruner" "github.com/zrepl/zrepl/daemon/pruner"
"github.com/zrepl/zrepl/daemon/snapper"
"github.com/zrepl/zrepl/endpoint" "github.com/zrepl/zrepl/endpoint"
"github.com/zrepl/zrepl/logger" "github.com/zrepl/zrepl/logger"
"github.com/zrepl/zrepl/replication" "github.com/zrepl/zrepl/replication"
"github.com/zrepl/zrepl/rpc"
"github.com/zrepl/zrepl/rpc/transportmux"
"github.com/zrepl/zrepl/tlsconf" "github.com/zrepl/zrepl/tlsconf"
"os" "github.com/zrepl/zrepl/transport"
"github.com/zrepl/zrepl/daemon/snapper"
"github.com/zrepl/zrepl/daemon/transport/serve"
) )
func OutletsFromConfig(in config.LoggingOutletEnumList) (*logger.Outlets, error) { func OutletsFromConfig(in config.LoggingOutletEnumList) (*logger.Outlets, error) {
@ -60,22 +63,41 @@ func OutletsFromConfig(in config.LoggingOutletEnumList) (*logger.Outlets, error)
} }
type Subsystem string
const ( const (
SubsysReplication = "repl" SubsysReplication Subsystem = "repl"
SubsysStreamrpc = "rpc" SubsyEndpoint Subsystem = "endpoint"
SubsyEndpoint = "endpoint" SubsysPruning Subsystem = "pruning"
SubsysSnapshot Subsystem = "snapshot"
SubsysTransport Subsystem = "transport"
SubsysTransportMux Subsystem = "transportmux"
SubsysRPC Subsystem = "rpc"
SubsysRPCControl Subsystem = "rpc.ctrl"
SubsysRPCData Subsystem = "rpc.data"
) )
func WithSubsystemLoggers(ctx context.Context, log logger.Logger) context.Context { func WithSubsystemLoggers(ctx context.Context, log logger.Logger) context.Context {
ctx = replication.WithLogger(ctx, log.WithField(SubsysField, "repl")) ctx = replication.WithLogger(ctx, log.WithField(SubsysField, SubsysReplication))
ctx = streamrpc.ContextWithLogger(ctx, streamrpcLogAdaptor{log.WithField(SubsysField, "rpc")}) ctx = endpoint.WithLogger(ctx, log.WithField(SubsysField, SubsyEndpoint))
ctx = endpoint.WithLogger(ctx, log.WithField(SubsysField, "endpoint")) ctx = pruner.WithLogger(ctx, log.WithField(SubsysField, SubsysPruning))
ctx = pruner.WithLogger(ctx, log.WithField(SubsysField, "pruning")) ctx = snapper.WithLogger(ctx, log.WithField(SubsysField, SubsysSnapshot))
ctx = snapper.WithLogger(ctx, log.WithField(SubsysField, "snapshot")) ctx = transport.WithLogger(ctx, log.WithField(SubsysField, SubsysTransport))
ctx = serve.WithLogger(ctx, log.WithField(SubsysField, "serve")) ctx = transportmux.WithLogger(ctx, log.WithField(SubsysField, SubsysTransportMux))
ctx = rpc.WithLoggers(ctx,
rpc.Loggers{
General: log.WithField(SubsysField, SubsysRPC),
Control: log.WithField(SubsysField, SubsysRPCControl),
Data: log.WithField(SubsysField, SubsysRPCData),
},
)
return ctx return ctx
} }
func LogSubsystem(log logger.Logger, subsys Subsystem) logger.Logger {
return log.ReplaceField(SubsysField, subsys)
}
func parseLogFormat(i interface{}) (f EntryFormatter, err error) { func parseLogFormat(i interface{}) (f EntryFormatter, err error) {
var is string var is string
switch j := i.(type) { switch j := i.(type) {

View File

@ -7,6 +7,7 @@ import (
"github.com/zrepl/zrepl/config" "github.com/zrepl/zrepl/config"
"github.com/zrepl/zrepl/daemon/job" "github.com/zrepl/zrepl/daemon/job"
"github.com/zrepl/zrepl/logger" "github.com/zrepl/zrepl/logger"
"github.com/zrepl/zrepl/rpc/dataconn/frameconn"
"github.com/zrepl/zrepl/zfs" "github.com/zrepl/zrepl/zfs"
"net" "net"
"net/http" "net/http"
@ -49,6 +50,10 @@ func (j *prometheusJob) Run(ctx context.Context) {
panic(err) panic(err)
} }
if err := frameconn.PrometheusRegister(prometheus.DefaultRegisterer); err != nil {
panic(err)
}
log := job.GetLogger(ctx) log := job.GetLogger(ctx)
l, err := net.Listen("tcp", j.listen) l, err := net.Listen("tcp", j.listen)

View File

@ -11,7 +11,6 @@ import (
"github.com/zrepl/zrepl/replication/pdu" "github.com/zrepl/zrepl/replication/pdu"
"github.com/zrepl/zrepl/util/envconst" "github.com/zrepl/zrepl/util/envconst"
"github.com/zrepl/zrepl/util/watchdog" "github.com/zrepl/zrepl/util/watchdog"
"github.com/problame/go-streamrpc"
"net" "net"
"sort" "sort"
"strings" "strings"
@ -19,14 +18,15 @@ import (
"time" "time"
) )
// Try to keep it compatible with gitub.com/zrepl/zrepl/replication.Endpoint // Try to keep it compatible with gitub.com/zrepl/zrepl/endpoint.Endpoint
type History interface { type History interface {
ReplicationCursor(ctx context.Context, req *pdu.ReplicationCursorReq) (*pdu.ReplicationCursorRes, error) ReplicationCursor(ctx context.Context, req *pdu.ReplicationCursorReq) (*pdu.ReplicationCursorRes, error)
} }
// Try to keep it compatible with gitub.com/zrepl/zrepl/endpoint.Endpoint
type Target interface { type Target interface {
ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, error) ListFilesystems(ctx context.Context, req *pdu.ListFilesystemReq) (*pdu.ListFilesystemRes, error)
ListFilesystemVersions(ctx context.Context, fs string) ([]*pdu.FilesystemVersion, error) // fix depS ListFilesystemVersions(ctx context.Context, req *pdu.ListFilesystemVersionsReq) (*pdu.ListFilesystemVersionsRes, error)
DestroySnapshots(ctx context.Context, req *pdu.DestroySnapshotsReq) (*pdu.DestroySnapshotsRes, error) DestroySnapshots(ctx context.Context, req *pdu.DestroySnapshotsReq) (*pdu.DestroySnapshotsRes, error)
} }
@ -346,7 +346,6 @@ type Error interface {
} }
var _ Error = net.Error(nil) var _ Error = net.Error(nil)
var _ Error = streamrpc.Error(nil)
func shouldRetry(e error) bool { func shouldRetry(e error) bool {
if neterr, ok := e.(net.Error); ok { if neterr, ok := e.(net.Error); ok {
@ -381,10 +380,11 @@ func statePlan(a *args, u updater) state {
ka = &pruner.Progress ka = &pruner.Progress
}) })
tfss, err := target.ListFilesystems(ctx) tfssres, err := target.ListFilesystems(ctx, &pdu.ListFilesystemReq{})
if err != nil { if err != nil {
return onErr(u, err) return onErr(u, err)
} }
tfss := tfssres.GetFilesystems()
pfss := make([]*fs, len(tfss)) pfss := make([]*fs, len(tfss))
for i, tfs := range tfss { for i, tfs := range tfss {
@ -398,11 +398,12 @@ func statePlan(a *args, u updater) state {
} }
pfss[i] = pfs pfss[i] = pfs
tfsvs, err := target.ListFilesystemVersions(ctx, tfs.Path) tfsvsres, err := target.ListFilesystemVersions(ctx, &pdu.ListFilesystemVersionsReq{Filesystem: tfs.Path})
if err != nil { if err != nil {
l.WithError(err).Error("cannot list filesystem versions") l.WithError(err).Error("cannot list filesystem versions")
return onErr(u, err) return onErr(u, err)
} }
tfsvs := tfsvsres.GetVersions()
// no progress here since we could run in a live-lock (must have used target AND receiver before progress) // no progress here since we could run in a live-lock (must have used target AND receiver before progress)
pfs.snaps = make([]pruning.Snapshot, 0, len(tfsvs)) pfs.snaps = make([]pruning.Snapshot, 0, len(tfsvs))

View File

@ -44,7 +44,7 @@ type mockTarget struct {
destroyErrs map[string][]error destroyErrs map[string][]error
} }
func (t *mockTarget) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, error) { func (t *mockTarget) ListFilesystems(ctx context.Context, req *pdu.ListFilesystemReq) (*pdu.ListFilesystemRes, error) {
if len(t.listFilesystemsErr) > 0 { if len(t.listFilesystemsErr) > 0 {
e := t.listFilesystemsErr[0] e := t.listFilesystemsErr[0]
t.listFilesystemsErr = t.listFilesystemsErr[1:] t.listFilesystemsErr = t.listFilesystemsErr[1:]
@ -54,10 +54,11 @@ func (t *mockTarget) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, er
for i := range fss { for i := range fss {
fss[i] = t.fss[i].Filesystem() fss[i] = t.fss[i].Filesystem()
} }
return fss, nil return &pdu.ListFilesystemRes{Filesystems: fss}, nil
} }
func (t *mockTarget) ListFilesystemVersions(ctx context.Context, fs string) ([]*pdu.FilesystemVersion, error) { func (t *mockTarget) ListFilesystemVersions(ctx context.Context, req *pdu.ListFilesystemVersionsReq) (*pdu.ListFilesystemVersionsRes, error) {
fs := req.Filesystem
if len(t.listVersionsErrs[fs]) != 0 { if len(t.listVersionsErrs[fs]) != 0 {
e := t.listVersionsErrs[fs][0] e := t.listVersionsErrs[fs][0]
t.listVersionsErrs[fs] = t.listVersionsErrs[fs][1:] t.listVersionsErrs[fs] = t.listVersionsErrs[fs][1:]
@ -68,7 +69,7 @@ func (t *mockTarget) ListFilesystemVersions(ctx context.Context, fs string) ([]*
if mfs.path != fs { if mfs.path != fs {
continue continue
} }
return mfs.FilesystemVersions(), nil return &pdu.ListFilesystemVersionsRes{Versions: mfs.FilesystemVersions()}, nil
} }
return nil, fmt.Errorf("filesystem %s does not exist", fs) return nil, fmt.Errorf("filesystem %s does not exist", fs)
} }

View File

@ -177,7 +177,7 @@ func onMainCtxDone(ctx context.Context, u updater) state {
} }
func syncUp(a args, u updater) state { func syncUp(a args, u updater) state {
fss, err := listFSes(a.fsf) fss, err := listFSes(a.ctx, a.fsf)
if err != nil { if err != nil {
return onErr(err, u) return onErr(err, u)
} }
@ -204,7 +204,7 @@ func plan(a args, u updater) state {
u(func(snapper *Snapper) { u(func(snapper *Snapper) {
snapper.lastInvocation = time.Now() snapper.lastInvocation = time.Now()
}) })
fss, err := listFSes(a.fsf) fss, err := listFSes(a.ctx, a.fsf)
if err != nil { if err != nil {
return onErr(err, u) return onErr(err, u)
} }
@ -299,8 +299,8 @@ func wait(a args, u updater) state {
} }
} }
func listFSes(mf *filters.DatasetMapFilter) (fss []*zfs.DatasetPath, err error) { func listFSes(ctx context.Context, mf *filters.DatasetMapFilter) (fss []*zfs.DatasetPath, err error) {
return zfs.ZFSListMapping(mf) return zfs.ZFSListMapping(ctx, mf)
} }
func findSyncPoint(log Logger, fss []*zfs.DatasetPath, prefix string, interval time.Duration) (syncPoint time.Time, err error) { func findSyncPoint(log Logger, fss []*zfs.DatasetPath, prefix string, interval time.Duration) (syncPoint time.Time, err error) {

View File

@ -1,25 +0,0 @@
package streamrpcconfig
import (
"github.com/problame/go-streamrpc"
"github.com/zrepl/zrepl/config"
)
func FromDaemonConfig(g *config.Global, in *config.RPCConfig) (*streamrpc.ConnConfig, error) {
conf := in
if conf == nil {
conf = g.RPC
}
srpcConf := &streamrpc.ConnConfig{
RxHeaderMaxLen: conf.RxHeaderMaxLen,
RxStructuredMaxLen: conf.RxStructuredMaxLen,
RxStreamMaxChunkSize: conf.RxStreamChunkMaxLen,
TxChunkSize: conf.TxChunkSize,
Timeout: conf.Timeout,
SendHeartbeatInterval: conf.SendHeartbeatInterval,
}
if err := srpcConf.Validate(); err != nil {
return nil, err
}
return srpcConf, nil
}

View File

@ -1,84 +0,0 @@
package connecter
import (
"context"
"fmt"
"github.com/problame/go-streamrpc"
"github.com/zrepl/zrepl/config"
"github.com/zrepl/zrepl/daemon/streamrpcconfig"
"github.com/zrepl/zrepl/daemon/transport"
"net"
"time"
)
type HandshakeConnecter struct {
connecter streamrpc.Connecter
}
func (c HandshakeConnecter) Connect(ctx context.Context) (net.Conn, error) {
conn, err := c.connecter.Connect(ctx)
if err != nil {
return nil, err
}
dl, ok := ctx.Deadline()
if !ok {
dl = time.Now().Add(10 * time.Second) // FIXME constant
}
if err := transport.DoHandshakeCurrentVersion(conn, dl); err != nil {
conn.Close()
return nil, err
}
return conn, nil
}
func FromConfig(g *config.Global, in config.ConnectEnum) (*ClientFactory, error) {
var (
connecter streamrpc.Connecter
errConnecter, errRPC error
connConf *streamrpc.ConnConfig
)
switch v := in.Ret.(type) {
case *config.SSHStdinserverConnect:
connecter, errConnecter = SSHStdinserverConnecterFromConfig(v)
connConf, errRPC = streamrpcconfig.FromDaemonConfig(g, v.RPC)
case *config.TCPConnect:
connecter, errConnecter = TCPConnecterFromConfig(v)
connConf, errRPC = streamrpcconfig.FromDaemonConfig(g, v.RPC)
case *config.TLSConnect:
connecter, errConnecter = TLSConnecterFromConfig(v)
connConf, errRPC = streamrpcconfig.FromDaemonConfig(g, v.RPC)
case *config.LocalConnect:
connecter, errConnecter = LocalConnecterFromConfig(v)
connConf, errRPC = streamrpcconfig.FromDaemonConfig(g, v.RPC)
default:
panic(fmt.Sprintf("implementation error: unknown connecter type %T", v))
}
if errConnecter != nil {
return nil, errConnecter
}
if errRPC != nil {
return nil, errRPC
}
config := streamrpc.ClientConfig{ConnConfig: connConf}
if err := config.Validate(); err != nil {
return nil, err
}
connecter = HandshakeConnecter{connecter}
return &ClientFactory{connecter: connecter, config: &config}, nil
}
type ClientFactory struct {
connecter streamrpc.Connecter
config *streamrpc.ClientConfig
}
func (f ClientFactory) NewClient() (*streamrpc.Client, error) {
return streamrpc.NewClient(f.connecter, f.config)
}

View File

@ -1,136 +0,0 @@
package transport
import (
"bytes"
"fmt"
"io"
"net"
"strings"
"time"
"unicode/utf8"
)
type HandshakeMessage struct {
ProtocolVersion int
Extensions []string
}
func (m *HandshakeMessage) Encode() ([]byte, error) {
if m.ProtocolVersion <= 0 || m.ProtocolVersion > 9999 {
return nil, fmt.Errorf("protocol version must be in [1, 9999]")
}
if len(m.Extensions) >= 9999 {
return nil, fmt.Errorf("protocol only supports [0, 9999] extensions")
}
// EXTENSIONS is a count of subsequent \n separated lines that contain protocol extensions
var extensions strings.Builder
for i, ext := range m.Extensions {
if strings.ContainsAny(ext, "\n") {
return nil, fmt.Errorf("Extension #%d contains forbidden newline character", i)
}
if !utf8.ValidString(ext) {
return nil, fmt.Errorf("Extension #%d is not valid UTF-8", i)
}
extensions.WriteString(ext)
extensions.WriteString("\n")
}
withoutLen := fmt.Sprintf("ZREPL_ZFS_REPLICATION PROTOVERSION=%04d EXTENSIONS=%04d\n%s",
m.ProtocolVersion, len(m.Extensions), extensions.String())
withLen := fmt.Sprintf("%010d %s", len(withoutLen), withoutLen)
return []byte(withLen), nil
}
func (m *HandshakeMessage) DecodeReader(r io.Reader, maxLen int) error {
var lenAndSpace [11]byte
if _, err := io.ReadFull(r, lenAndSpace[:]); err != nil {
return err
}
if !utf8.Valid(lenAndSpace[:]) {
return fmt.Errorf("invalid start of handshake message: not valid UTF-8")
}
var followLen int
n, err := fmt.Sscanf(string(lenAndSpace[:]), "%010d ", &followLen)
if n != 1 || err != nil {
return fmt.Errorf("could not parse handshake message length")
}
if followLen > maxLen {
return fmt.Errorf("handshake message length exceeds max length (%d vs %d)",
followLen, maxLen)
}
var buf bytes.Buffer
_, err = io.Copy(&buf, io.LimitReader(r, int64(followLen)))
if err != nil {
return err
}
var (
protoVersion, extensionCount int
)
n, err = fmt.Fscanf(&buf, "ZREPL_ZFS_REPLICATION PROTOVERSION=%04d EXTENSIONS=%4d\n",
&protoVersion, &extensionCount)
if n != 2 || err != nil {
return fmt.Errorf("could not parse handshake message: %s", err)
}
if protoVersion < 1 {
return fmt.Errorf("invalid protocol version %q", protoVersion)
}
m.ProtocolVersion = protoVersion
if extensionCount < 0 {
return fmt.Errorf("invalid extension count %q", extensionCount)
}
if extensionCount == 0 {
if buf.Len() != 0 {
return fmt.Errorf("unexpected data trailing after header")
}
m.Extensions = nil
return nil
}
s := buf.String()
if strings.Count(s, "\n") != extensionCount {
return fmt.Errorf("inconsistent extension count: found %d, header says %d", len(m.Extensions), extensionCount)
}
exts := strings.Split(s, "\n")
if exts[len(exts)-1] != "" {
return fmt.Errorf("unexpected data trailing after last extension newline")
}
m.Extensions = exts[0:len(exts)-1]
return nil
}
func DoHandshakeCurrentVersion(conn net.Conn, deadline time.Time) error {
// current protocol version is hardcoded here
return DoHandshakeVersion(conn, deadline, 1)
}
func DoHandshakeVersion(conn net.Conn, deadline time.Time, version int) error {
ours := HandshakeMessage{
ProtocolVersion: version,
Extensions: nil,
}
hsb, err := ours.Encode()
if err != nil {
return fmt.Errorf("could not encode protocol banner: %s", err)
}
conn.SetDeadline(deadline)
_, err = io.Copy(conn, bytes.NewBuffer(hsb))
if err != nil {
return fmt.Errorf("could not send protocol banner: %s", err)
}
theirs := HandshakeMessage{}
if err := theirs.DecodeReader(conn, 16 * 4096); err != nil { // FIXME constant
return fmt.Errorf("could not decode protocol banner: %s", err)
}
if theirs.ProtocolVersion != ours.ProtocolVersion {
return fmt.Errorf("protocol versions do not match: ours is %d, theirs is %d",
ours.ProtocolVersion, theirs.ProtocolVersion)
}
// ignore extensions, we don't use them
return nil
}

View File

@ -1,147 +0,0 @@
package serve
import (
"github.com/pkg/errors"
"github.com/zrepl/zrepl/config"
"github.com/zrepl/zrepl/daemon/transport"
"net"
"github.com/zrepl/zrepl/daemon/streamrpcconfig"
"github.com/problame/go-streamrpc"
"context"
"github.com/zrepl/zrepl/logger"
"github.com/zrepl/zrepl/zfs"
"time"
)
type contextKey int
const contextKeyLog contextKey = 0
type Logger = logger.Logger
func WithLogger(ctx context.Context, log Logger) context.Context {
return context.WithValue(ctx, contextKeyLog, log)
}
func getLogger(ctx context.Context) Logger {
if log, ok := ctx.Value(contextKeyLog).(Logger); ok {
return log
}
return logger.NewNullLogger()
}
type AuthenticatedConn interface {
net.Conn
// ClientIdentity must be a string that satisfies ValidateClientIdentity
ClientIdentity() string
}
// A client identity must be a single component in a ZFS filesystem path
func ValidateClientIdentity(in string) (err error) {
path, err := zfs.NewDatasetPath(in)
if err != nil {
return err
}
if path.Length() != 1 {
return errors.New("client identity must be a single path comonent (not empty, no '/')")
}
return nil
}
type authConn struct {
net.Conn
clientIdentity string
}
var _ AuthenticatedConn = authConn{}
func (c authConn) ClientIdentity() string {
if err := ValidateClientIdentity(c.clientIdentity); err != nil {
panic(err)
}
return c.clientIdentity
}
// like net.Listener, but with an AuthenticatedConn instead of net.Conn
type AuthenticatedListener interface {
Addr() (net.Addr)
Accept(ctx context.Context) (AuthenticatedConn, error)
Close() error
}
type ListenerFactory interface {
Listen() (AuthenticatedListener, error)
}
type HandshakeListenerFactory struct {
lf ListenerFactory
}
func (lf HandshakeListenerFactory) Listen() (AuthenticatedListener, error) {
l, err := lf.lf.Listen()
if err != nil {
return nil, err
}
return HandshakeListener{l}, nil
}
type HandshakeListener struct {
l AuthenticatedListener
}
func (l HandshakeListener) Addr() (net.Addr) { return l.l.Addr() }
func (l HandshakeListener) Close() error { return l.l.Close() }
func (l HandshakeListener) Accept(ctx context.Context) (AuthenticatedConn, error) {
conn, err := l.l.Accept(ctx)
if err != nil {
return nil, err
}
dl, ok := ctx.Deadline()
if !ok {
dl = time.Now().Add(10*time.Second) // FIXME constant
}
if err := transport.DoHandshakeCurrentVersion(conn, dl); err != nil {
conn.Close()
return nil, err
}
return conn, nil
}
func FromConfig(g *config.Global, in config.ServeEnum) (lf ListenerFactory, conf *streamrpc.ConnConfig, _ error) {
var (
lfError, rpcErr error
)
switch v := in.Ret.(type) {
case *config.TCPServe:
lf, lfError = TCPListenerFactoryFromConfig(g, v)
conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC)
case *config.TLSServe:
lf, lfError = TLSListenerFactoryFromConfig(g, v)
conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC)
case *config.StdinserverServer:
lf, lfError = MultiStdinserverListenerFactoryFromConfig(g, v)
conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC)
case *config.LocalServe:
lf, lfError = LocalListenerFactoryFromConfig(g, v)
conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC)
default:
return nil, nil, errors.Errorf("internal error: unknown serve type %T", v)
}
if lfError != nil {
return nil, nil, lfError
}
if rpcErr != nil {
return nil, nil, rpcErr
}
lf = HandshakeListenerFactory{lf}
return lf, conf, nil
}

View File

@ -1,83 +0,0 @@
package serve
import (
"crypto/tls"
"crypto/x509"
"fmt"
"github.com/pkg/errors"
"github.com/zrepl/zrepl/config"
"github.com/zrepl/zrepl/tlsconf"
"net"
"time"
"context"
)
type TLSListenerFactory struct {
address string
clientCA *x509.CertPool
serverCert tls.Certificate
handshakeTimeout time.Duration
clientCNs map[string]struct{}
}
func TLSListenerFactoryFromConfig(c *config.Global, in *config.TLSServe) (lf *TLSListenerFactory, err error) {
lf = &TLSListenerFactory{
address: in.Listen,
handshakeTimeout: in.HandshakeTimeout,
}
if in.Ca == "" || in.Cert == "" || in.Key == "" {
return nil, errors.New("fields 'ca', 'cert' and 'key'must be specified")
}
lf.clientCA, err = tlsconf.ParseCAFile(in.Ca)
if err != nil {
return nil, errors.Wrap(err, "cannot parse ca file")
}
lf.serverCert, err = tls.LoadX509KeyPair(in.Cert, in.Key)
if err != nil {
return nil, errors.Wrap(err, "cannot parse cer/key pair")
}
lf.clientCNs = make(map[string]struct{}, len(in.ClientCNs))
for i, cn := range in.ClientCNs {
if err := ValidateClientIdentity(cn); err != nil {
return nil, errors.Wrapf(err, "unsuitable client_cn #%d %q", i, cn)
}
// dupes are ok fr now
lf.clientCNs[cn] = struct{}{}
}
return lf, nil
}
func (f *TLSListenerFactory) Listen() (AuthenticatedListener, error) {
l, err := net.Listen("tcp", f.address)
if err != nil {
return nil, err
}
tl := tlsconf.NewClientAuthListener(l, f.clientCA, f.serverCert, f.handshakeTimeout)
return tlsAuthListener{tl, f.clientCNs}, nil
}
type tlsAuthListener struct {
*tlsconf.ClientAuthListener
clientCNs map[string]struct{}
}
func (l tlsAuthListener) Accept(ctx context.Context) (AuthenticatedConn, error) {
c, cn, err := l.ClientAuthListener.Accept()
if err != nil {
return nil, err
}
if _, ok := l.clientCNs[cn]; !ok {
if err := c.Close(); err != nil {
getLogger(ctx).WithError(err).Error("error closing connection with unauthorized common name")
}
return nil, fmt.Errorf("unauthorized client common name %q from %s", cn, c.RemoteAddr())
}
return authConn{c, cn}, nil
}

View File

@ -203,13 +203,17 @@ The serve & connect configuration will thus look like the following:
``ssh+stdinserver`` Transport ``ssh+stdinserver`` Transport
----------------------------- -----------------------------
``ssh+stdinserver`` is inspired by `git shell <https://git-scm.com/docs/git-shell>`_ and `Borg Backup <https://borgbackup.readthedocs.io/en/stable/deployment.html>`_. ``ssh+stdinserver`` uses the ``ssh`` command and some features of the server-side SSH ``authorized_keys`` file.
It is provided by the Go package ``github.com/problame/go-netssh``. It is less efficient than other transports because the data passes through two more pipes.
However, it is fairly convenient to set up and allows the zrepl daemon to not be directly exposed to the internet, because all traffic passes through the system's SSH server.
.. ATTENTION:: The concept is inspired by `git shell <https://git-scm.com/docs/git-shell>`_ and `Borg Backup <https://borgbackup.readthedocs.io/en/stable/deployment.html>`_.
The implementation is provided by the Go package ``github.com/problame/go-netssh``.
``ssh+stdinserver`` has inferior error detection and handling compared to the ``tcp`` and ``tls`` transports. .. NOTE::
If you require tested timeout & retry handling, use ``tcp`` or ``tls`` transports, or help improve package go-netssh.
``ssh+stdinserver`` generally provides inferior error detection and handling compared to the ``tcp`` and ``tls`` transports.
When encountering such problems, consider using ``tcp`` or ``tls`` transports, or help improve package go-netssh.
.. _transport-ssh+stdinserver-serve: .. _transport-ssh+stdinserver-serve:

View File

@ -9,6 +9,7 @@ type contextKey int
const ( const (
contextKeyLogger contextKey = iota contextKeyLogger contextKey = iota
ClientIdentityKey
) )
type Logger = logger.Logger type Logger = logger.Logger

View File

@ -2,16 +2,14 @@
package endpoint package endpoint
import ( import (
"bytes"
"context" "context"
"fmt" "fmt"
"github.com/golang/protobuf/proto" "path"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/problame/go-streamrpc"
"github.com/zrepl/zrepl/replication" "github.com/zrepl/zrepl/replication"
"github.com/zrepl/zrepl/replication/pdu" "github.com/zrepl/zrepl/replication/pdu"
"github.com/zrepl/zrepl/zfs" "github.com/zrepl/zrepl/zfs"
"io"
) )
// Sender implements replication.ReplicationEndpoint for a sending side // Sender implements replication.ReplicationEndpoint for a sending side
@ -41,8 +39,8 @@ func (s *Sender) filterCheckFS(fs string) (*zfs.DatasetPath, error) {
return dp, nil return dp, nil
} }
func (p *Sender) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, error) { func (s *Sender) ListFilesystems(ctx context.Context, r *pdu.ListFilesystemReq) (*pdu.ListFilesystemRes, error) {
fss, err := zfs.ZFSListMapping(p.FSFilter) fss, err := zfs.ZFSListMapping(ctx, s.FSFilter)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -53,11 +51,12 @@ func (p *Sender) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, error)
// FIXME: not supporting ResumeToken yet // FIXME: not supporting ResumeToken yet
} }
} }
return rfss, nil res := &pdu.ListFilesystemRes{Filesystems: rfss, Empty: len(rfss) == 0}
return res, nil
} }
func (p *Sender) ListFilesystemVersions(ctx context.Context, fs string) ([]*pdu.FilesystemVersion, error) { func (s *Sender) ListFilesystemVersions(ctx context.Context, r *pdu.ListFilesystemVersionsReq) (*pdu.ListFilesystemVersionsRes, error) {
lp, err := p.filterCheckFS(fs) lp, err := s.filterCheckFS(r.GetFilesystem())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -69,16 +68,17 @@ func (p *Sender) ListFilesystemVersions(ctx context.Context, fs string) ([]*pdu.
for i := range fsvs { for i := range fsvs {
rfsvs[i] = pdu.FilesystemVersionFromZFS(&fsvs[i]) rfsvs[i] = pdu.FilesystemVersionFromZFS(&fsvs[i])
} }
return rfsvs, nil res := &pdu.ListFilesystemVersionsRes{Versions: rfsvs}
return res, nil
} }
func (p *Sender) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error) { func (s *Sender) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) {
_, err := p.filterCheckFS(r.Filesystem) _, err := s.filterCheckFS(r.Filesystem)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if r.DryRun {
si, err := zfs.ZFSSendDry(r.Filesystem, r.From, r.To, "") si, err := zfs.ZFSSendDry(r.Filesystem, r.From, r.To, "")
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -87,14 +87,17 @@ func (p *Sender) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.Rea
if si.SizeEstimate != -1 { // but si returns -1 for no size estimate if si.SizeEstimate != -1 { // but si returns -1 for no size estimate
expSize = si.SizeEstimate expSize = si.SizeEstimate
} }
return &pdu.SendRes{ExpectedSize: expSize}, nil, nil res := &pdu.SendRes{ExpectedSize: expSize}
} else {
stream, err := zfs.ZFSSend(ctx, r.Filesystem, r.From, r.To, "") if r.DryRun {
return res, nil, nil
}
streamCopier, err := zfs.ZFSSend(ctx, r.Filesystem, r.From, r.To, "")
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
return &pdu.SendRes{}, stream, nil return res, streamCopier, nil
}
} }
func (p *Sender) DestroySnapshots(ctx context.Context, req *pdu.DestroySnapshotsReq) (*pdu.DestroySnapshotsRes, error) { func (p *Sender) DestroySnapshots(ctx context.Context, req *pdu.DestroySnapshotsReq) (*pdu.DestroySnapshotsRes, error) {
@ -132,6 +135,10 @@ func (p *Sender) ReplicationCursor(ctx context.Context, req *pdu.ReplicationCurs
} }
} }
func (p *Sender) Receive(ctx context.Context, r *pdu.ReceiveReq, receive zfs.StreamCopier) (*pdu.ReceiveRes, error) {
return nil, fmt.Errorf("sender does not implement Receive()")
}
type FSFilter interface { // FIXME unused type FSFilter interface { // FIXME unused
Filter(path *zfs.DatasetPath) (pass bool, err error) Filter(path *zfs.DatasetPath) (pass bool, err error)
} }
@ -146,14 +153,50 @@ type FSMap interface { // FIXME unused
// Receiver implements replication.ReplicationEndpoint for a receiving side // Receiver implements replication.ReplicationEndpoint for a receiving side
type Receiver struct { type Receiver struct {
root *zfs.DatasetPath rootWithoutClientComponent *zfs.DatasetPath
appendClientIdentity bool
} }
func NewReceiver(rootDataset *zfs.DatasetPath) (*Receiver, error) { func NewReceiver(rootDataset *zfs.DatasetPath, appendClientIdentity bool) *Receiver {
if rootDataset.Length() <= 0 { if rootDataset.Length() <= 0 {
return nil, errors.New("root dataset must not be an empty path") panic(fmt.Sprintf("root dataset must not be an empty path: %v", rootDataset))
} }
return &Receiver{root: rootDataset.Copy()}, nil return &Receiver{rootWithoutClientComponent: rootDataset.Copy(), appendClientIdentity: appendClientIdentity}
}
func TestClientIdentity(rootFS *zfs.DatasetPath, clientIdentity string) error {
_, err := clientRoot(rootFS, clientIdentity)
return err
}
func clientRoot(rootFS *zfs.DatasetPath, clientIdentity string) (*zfs.DatasetPath, error) {
rootFSLen := rootFS.Length()
clientRootStr := path.Join(rootFS.ToString(), clientIdentity)
clientRoot, err := zfs.NewDatasetPath(clientRootStr)
if err != nil {
return nil, err
}
if rootFSLen+1 != clientRoot.Length() {
return nil, fmt.Errorf("client identity must be a single ZFS filesystem path component")
}
return clientRoot, nil
}
func (s *Receiver) clientRootFromCtx(ctx context.Context) *zfs.DatasetPath {
if !s.appendClientIdentity {
return s.rootWithoutClientComponent.Copy()
}
clientIdentity, ok := ctx.Value(ClientIdentityKey).(string)
if !ok {
panic(fmt.Sprintf("ClientIdentityKey context value must be set"))
}
clientRoot, err := clientRoot(s.rootWithoutClientComponent, clientIdentity)
if err != nil {
panic(fmt.Sprintf("ClientIdentityContextKey must have been validated before invoking Receiver: %s", err))
}
return clientRoot
} }
type subroot struct { type subroot struct {
@ -180,8 +223,9 @@ func (f subroot) MapToLocal(fs string) (*zfs.DatasetPath, error) {
return c, nil return c, nil
} }
func (e *Receiver) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, error) { func (s *Receiver) ListFilesystems(ctx context.Context, req *pdu.ListFilesystemReq) (*pdu.ListFilesystemRes, error) {
filtered, err := zfs.ZFSListMapping(subroot{e.root}) root := s.clientRootFromCtx(ctx)
filtered, err := zfs.ZFSListMapping(ctx, subroot{root})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -194,19 +238,30 @@ func (e *Receiver) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, erro
WithError(err). WithError(err).
WithField("fs", a). WithField("fs", a).
Error("inconsistent placeholder property") Error("inconsistent placeholder property")
return nil, errors.New("server error, see logs") // don't leak path return nil, errors.New("server error: inconsistent placeholder property") // don't leak path
} }
if ph { if ph {
getLogger(ctx).
WithField("fs", a.ToString()).
Debug("ignoring placeholder filesystem")
continue continue
} }
a.TrimPrefix(e.root) getLogger(ctx).
WithField("fs", a.ToString()).
Debug("non-placeholder filesystem")
a.TrimPrefix(root)
fss = append(fss, &pdu.Filesystem{Path: a.ToString()}) fss = append(fss, &pdu.Filesystem{Path: a.ToString()})
} }
return fss, nil if len(fss) == 0 {
getLogger(ctx).Debug("no non-placeholder filesystems")
return &pdu.ListFilesystemRes{Empty: true}, nil
}
return &pdu.ListFilesystemRes{Filesystems: fss}, nil
} }
func (e *Receiver) ListFilesystemVersions(ctx context.Context, fs string) ([]*pdu.FilesystemVersion, error) { func (s *Receiver) ListFilesystemVersions(ctx context.Context, req *pdu.ListFilesystemVersionsReq) (*pdu.ListFilesystemVersionsRes, error) {
lp, err := subroot{e.root}.MapToLocal(fs) root := s.clientRootFromCtx(ctx)
lp, err := subroot{root}.MapToLocal(req.GetFilesystem())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -221,18 +276,26 @@ func (e *Receiver) ListFilesystemVersions(ctx context.Context, fs string) ([]*pd
rfsvs[i] = pdu.FilesystemVersionFromZFS(&fsvs[i]) rfsvs[i] = pdu.FilesystemVersionFromZFS(&fsvs[i])
} }
return rfsvs, nil return &pdu.ListFilesystemVersionsRes{Versions: rfsvs}, nil
} }
func (e *Receiver) Receive(ctx context.Context, req *pdu.ReceiveReq, sendStream io.ReadCloser) error { func (s *Receiver) ReplicationCursor(context.Context, *pdu.ReplicationCursorReq) (*pdu.ReplicationCursorRes, error) {
defer sendStream.Close() return nil, fmt.Errorf("ReplicationCursor not implemented for Receiver")
lp, err := subroot{e.root}.MapToLocal(req.Filesystem)
if err != nil {
return err
} }
func (s *Receiver) Send(ctx context.Context, req *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) {
return nil, nil, fmt.Errorf("receiver does not implement Send()")
}
func (s *Receiver) Receive(ctx context.Context, req *pdu.ReceiveReq, receive zfs.StreamCopier) (*pdu.ReceiveRes, error) {
getLogger(ctx).Debug("incoming Receive") getLogger(ctx).Debug("incoming Receive")
defer receive.Close()
root := s.clientRootFromCtx(ctx)
lp, err := subroot{root}.MapToLocal(req.Filesystem)
if err != nil {
return nil, err
}
// create placeholder parent filesystems as appropriate // create placeholder parent filesystems as appropriate
var visitErr error var visitErr error
@ -261,7 +324,7 @@ func (e *Receiver) Receive(ctx context.Context, req *pdu.ReceiveReq, sendStream
getLogger(ctx).WithField("visitErr", visitErr).Debug("complete tree-walk") getLogger(ctx).WithField("visitErr", visitErr).Debug("complete tree-walk")
if visitErr != nil { if visitErr != nil {
return visitErr return nil, err
} }
needForceRecv := false needForceRecv := false
@ -279,19 +342,19 @@ func (e *Receiver) Receive(ctx context.Context, req *pdu.ReceiveReq, sendStream
getLogger(ctx).Debug("start receive command") getLogger(ctx).Debug("start receive command")
if err := zfs.ZFSRecv(ctx, lp.ToString(), sendStream, args...); err != nil { if err := zfs.ZFSRecv(ctx, lp.ToString(), receive, args...); err != nil {
getLogger(ctx). getLogger(ctx).
WithError(err). WithError(err).
WithField("args", args). WithField("args", args).
Error("zfs receive failed") Error("zfs receive failed")
sendStream.Close() return nil, err
return err
} }
return nil return &pdu.ReceiveRes{}, nil
} }
func (e *Receiver) DestroySnapshots(ctx context.Context, req *pdu.DestroySnapshotsReq) (*pdu.DestroySnapshotsRes, error) { func (s *Receiver) DestroySnapshots(ctx context.Context, req *pdu.DestroySnapshotsReq) (*pdu.DestroySnapshotsRes, error) {
lp, err := subroot{e.root}.MapToLocal(req.Filesystem) root := s.clientRootFromCtx(ctx)
lp, err := subroot{root}.MapToLocal(req.Filesystem)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -326,289 +389,3 @@ func doDestroySnapshots(ctx context.Context, lp *zfs.DatasetPath, snaps []*pdu.F
} }
return res, nil return res, nil
} }
// =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
// RPC STUBS
// =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
const (
RPCListFilesystems = "ListFilesystems"
RPCListFilesystemVersions = "ListFilesystemVersions"
RPCReceive = "Receive"
RPCSend = "Send"
RPCSDestroySnapshots = "DestroySnapshots"
RPCReplicationCursor = "ReplicationCursor"
)
// Remote implements an endpoint stub that uses streamrpc as a transport.
type Remote struct {
c *streamrpc.Client
}
func NewRemote(c *streamrpc.Client) Remote {
return Remote{c}
}
func (s Remote) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, error) {
req := pdu.ListFilesystemReq{}
b, err := proto.Marshal(&req)
if err != nil {
return nil, err
}
rb, rs, err := s.c.RequestReply(ctx, RPCListFilesystems, bytes.NewBuffer(b), nil)
if err != nil {
return nil, err
}
if rs != nil {
rs.Close()
return nil, errors.New("response contains unexpected stream")
}
var res pdu.ListFilesystemRes
if err := proto.Unmarshal(rb.Bytes(), &res); err != nil {
return nil, err
}
return res.Filesystems, nil
}
func (s Remote) ListFilesystemVersions(ctx context.Context, fs string) ([]*pdu.FilesystemVersion, error) {
req := pdu.ListFilesystemVersionsReq{
Filesystem: fs,
}
b, err := proto.Marshal(&req)
if err != nil {
return nil, err
}
rb, rs, err := s.c.RequestReply(ctx, RPCListFilesystemVersions, bytes.NewBuffer(b), nil)
if err != nil {
return nil, err
}
if rs != nil {
rs.Close()
return nil, errors.New("response contains unexpected stream")
}
var res pdu.ListFilesystemVersionsRes
if err := proto.Unmarshal(rb.Bytes(), &res); err != nil {
return nil, err
}
return res.Versions, nil
}
func (s Remote) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error) {
b, err := proto.Marshal(r)
if err != nil {
return nil, nil, err
}
rb, rs, err := s.c.RequestReply(ctx, RPCSend, bytes.NewBuffer(b), nil)
if err != nil {
return nil, nil, err
}
if !r.DryRun && rs == nil {
return nil, nil, errors.New("response does not contain a stream")
}
if r.DryRun && rs != nil {
rs.Close()
return nil, nil, errors.New("response contains unexpected stream (was dry run)")
}
var res pdu.SendRes
if err := proto.Unmarshal(rb.Bytes(), &res); err != nil {
rs.Close()
return nil, nil, err
}
return &res, rs, nil
}
func (s Remote) Receive(ctx context.Context, r *pdu.ReceiveReq, sendStream io.ReadCloser) error {
defer sendStream.Close()
b, err := proto.Marshal(r)
if err != nil {
return err
}
rb, rs, err := s.c.RequestReply(ctx, RPCReceive, bytes.NewBuffer(b), sendStream)
getLogger(ctx).WithField("err", err).Debug("Remote.Receive RequestReplyReturned")
if err != nil {
return err
}
if rs != nil {
rs.Close()
return errors.New("response contains unexpected stream")
}
var res pdu.ReceiveRes
if err := proto.Unmarshal(rb.Bytes(), &res); err != nil {
return err
}
return nil
}
func (s Remote) DestroySnapshots(ctx context.Context, r *pdu.DestroySnapshotsReq) (*pdu.DestroySnapshotsRes, error) {
b, err := proto.Marshal(r)
if err != nil {
return nil, err
}
rb, rs, err := s.c.RequestReply(ctx, RPCSDestroySnapshots, bytes.NewBuffer(b), nil)
if err != nil {
return nil, err
}
if rs != nil {
rs.Close()
return nil, errors.New("response contains unexpected stream")
}
var res pdu.DestroySnapshotsRes
if err := proto.Unmarshal(rb.Bytes(), &res); err != nil {
return nil, err
}
return &res, nil
}
func (s Remote) ReplicationCursor(ctx context.Context, req *pdu.ReplicationCursorReq) (*pdu.ReplicationCursorRes, error) {
b, err := proto.Marshal(req)
if err != nil {
return nil, err
}
rb, rs, err := s.c.RequestReply(ctx, RPCReplicationCursor, bytes.NewBuffer(b), nil)
if err != nil {
return nil, err
}
if rs != nil {
rs.Close()
return nil, errors.New("response contains unexpected stream")
}
var res pdu.ReplicationCursorRes
if err := proto.Unmarshal(rb.Bytes(), &res); err != nil {
return nil, err
}
return &res, nil
}
// Handler implements the server-side streamrpc.HandlerFunc for a Remote endpoint stub.
type Handler struct {
ep replication.Endpoint
}
func NewHandler(ep replication.Endpoint) Handler {
return Handler{ep}
}
func (a *Handler) Handle(ctx context.Context, endpoint string, reqStructured *bytes.Buffer, reqStream io.ReadCloser) (resStructured *bytes.Buffer, resStream io.ReadCloser, err error) {
switch endpoint {
case RPCListFilesystems:
var req pdu.ListFilesystemReq
if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil {
return nil, nil, err
}
fsses, err := a.ep.ListFilesystems(ctx)
if err != nil {
return nil, nil, err
}
res := &pdu.ListFilesystemRes{
Filesystems: fsses,
}
b, err := proto.Marshal(res)
if err != nil {
return nil, nil, err
}
return bytes.NewBuffer(b), nil, nil
case RPCListFilesystemVersions:
var req pdu.ListFilesystemVersionsReq
if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil {
return nil, nil, err
}
fsvs, err := a.ep.ListFilesystemVersions(ctx, req.Filesystem)
if err != nil {
return nil, nil, err
}
res := &pdu.ListFilesystemVersionsRes{
Versions: fsvs,
}
b, err := proto.Marshal(res)
if err != nil {
return nil, nil, err
}
return bytes.NewBuffer(b), nil, nil
case RPCSend:
sender, ok := a.ep.(replication.Sender)
if !ok {
goto Err
}
var req pdu.SendReq
if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil {
return nil, nil, err
}
res, sendStream, err := sender.Send(ctx, &req)
if err != nil {
return nil, nil, err
}
b, err := proto.Marshal(res)
if err != nil {
return nil, nil, err
}
return bytes.NewBuffer(b), sendStream, err
case RPCReceive:
receiver, ok := a.ep.(replication.Receiver)
if !ok {
goto Err
}
var req pdu.ReceiveReq
if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil {
return nil, nil, err
}
err := receiver.Receive(ctx, &req, reqStream)
if err != nil {
return nil, nil, err
}
b, err := proto.Marshal(&pdu.ReceiveRes{})
if err != nil {
return nil, nil, err
}
return bytes.NewBuffer(b), nil, err
case RPCSDestroySnapshots:
var req pdu.DestroySnapshotsReq
if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil {
return nil, nil, err
}
res, err := a.ep.DestroySnapshots(ctx, &req)
if err != nil {
return nil, nil, err
}
b, err := proto.Marshal(res)
if err != nil {
return nil, nil, err
}
return bytes.NewBuffer(b), nil, nil
case RPCReplicationCursor:
sender, ok := a.ep.(replication.Sender)
if !ok {
goto Err
}
var req pdu.ReplicationCursorReq
if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil {
return nil, nil, err
}
res, err := sender.ReplicationCursor(ctx, &req)
if err != nil {
return nil, nil, err
}
b, err := proto.Marshal(res)
if err != nil {
return nil, nil, err
}
return bytes.NewBuffer(b), nil, nil
}
Err:
return nil, nil, errors.New("no handler for given endpoint")
}

27
logger/stderrlogger.go Normal file
View File

@ -0,0 +1,27 @@
package logger
import (
"fmt"
"os"
)
type stderrLogger struct {
Logger
}
type stderrLoggerOutlet struct {}
func (stderrLoggerOutlet) WriteEntry(entry Entry) error {
fmt.Fprintf(os.Stderr, "%#v\n", entry)
return nil
}
var _ Logger = testLogger{}
func NewStderrDebugLogger() Logger {
outlets := NewOutlets()
outlets.Add(&stderrLoggerOutlet{}, Debug)
return &testLogger{
Logger: NewLogger(outlets, 0),
}
}

View File

@ -6,16 +6,16 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/prometheus/client_golang/prometheus"
"github.com/zrepl/zrepl/util/watchdog"
"io"
"net" "net"
"sync" "sync"
"time" "time"
"github.com/prometheus/client_golang/prometheus"
"github.com/zrepl/zrepl/logger" "github.com/zrepl/zrepl/logger"
"github.com/zrepl/zrepl/replication/pdu" "github.com/zrepl/zrepl/replication/pdu"
"github.com/zrepl/zrepl/util" "github.com/zrepl/zrepl/util/bytecounter"
"github.com/zrepl/zrepl/util/watchdog"
"github.com/zrepl/zrepl/zfs"
) )
type contextKey int type contextKey int
@ -43,7 +43,7 @@ type Sender interface {
// If a non-nil io.ReadCloser is returned, it is guaranteed to be closed before // If a non-nil io.ReadCloser is returned, it is guaranteed to be closed before
// any next call to the parent github.com/zrepl/zrepl/replication.Endpoint. // any next call to the parent github.com/zrepl/zrepl/replication.Endpoint.
// If the send request is for dry run the io.ReadCloser will be nil // If the send request is for dry run the io.ReadCloser will be nil
Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error)
ReplicationCursor(ctx context.Context, req *pdu.ReplicationCursorReq) (*pdu.ReplicationCursorRes, error) ReplicationCursor(ctx context.Context, req *pdu.ReplicationCursorReq) (*pdu.ReplicationCursorRes, error)
} }
@ -51,9 +51,7 @@ type Sender interface {
type Receiver interface { type Receiver interface {
// Receive sends r and sendStream (the latter containing a ZFS send stream) // Receive sends r and sendStream (the latter containing a ZFS send stream)
// to the parent github.com/zrepl/zrepl/replication.Endpoint. // to the parent github.com/zrepl/zrepl/replication.Endpoint.
// Implementors must guarantee that Close was called on sendStream before Receive(ctx context.Context, req *pdu.ReceiveReq, receive zfs.StreamCopier) (*pdu.ReceiveRes, error)
// the call to Receive returns.
Receive(ctx context.Context, r *pdu.ReceiveReq, sendStream io.ReadCloser) error
} }
type StepReport struct { type StepReport struct {
@ -227,7 +225,7 @@ type ReplicationStep struct {
// both retry and permanent error // both retry and permanent error
err error err error
byteCounter *util.ByteCounterReader byteCounter bytecounter.StreamCopier
expectedSize int64 // 0 means no size estimate present / possible expectedSize int64 // 0 means no size estimate present / possible
} }
@ -401,37 +399,54 @@ func (s *ReplicationStep) doReplication(ctx context.Context, ka *watchdog.KeepAl
sr := s.buildSendRequest(false) sr := s.buildSendRequest(false)
log.Debug("initiate send request") log.Debug("initiate send request")
sres, sstream, err := sender.Send(ctx, sr) sres, sstreamCopier, err := sender.Send(ctx, sr)
if err != nil { if err != nil {
log.WithError(err).Error("send request failed") log.WithError(err).Error("send request failed")
return err return err
} }
if sstream == nil { if sstreamCopier == nil {
err := errors.New("send request did not return a stream, broken endpoint implementation") err := errors.New("send request did not return a stream, broken endpoint implementation")
return err return err
} }
defer sstreamCopier.Close()
s.byteCounter = util.NewByteCounterReader(sstream) // Install a byte counter to track progress + for status report
s.byteCounter.SetCallback(1*time.Second, func(i int64) { s.byteCounter = bytecounter.NewStreamCopier(sstreamCopier)
byteCounterStopProgress := make(chan struct{})
defer close(byteCounterStopProgress)
go func() {
var lastCount int64
t := time.NewTicker(1 * time.Second)
defer t.Stop()
for {
select {
case <-byteCounterStopProgress:
return
case <-t.C:
newCount := s.byteCounter.Count()
if lastCount != newCount {
ka.MadeProgress() ka.MadeProgress()
}) } else {
defer func() { lastCount = newCount
s.parent.promBytesReplicated.Add(float64(s.byteCounter.Bytes())) }
}
}
}()
defer func() {
s.parent.promBytesReplicated.Add(float64(s.byteCounter.Count()))
}() }()
sstream = s.byteCounter
rr := &pdu.ReceiveReq{ rr := &pdu.ReceiveReq{
Filesystem: fs, Filesystem: fs,
ClearResumeToken: !sres.UsedResumeToken, ClearResumeToken: !sres.UsedResumeToken,
} }
log.Debug("initiate receive request") log.Debug("initiate receive request")
err = receiver.Receive(ctx, rr, sstream) _, err = receiver.Receive(ctx, rr, s.byteCounter)
if err != nil { if err != nil {
log. log.
WithError(err). WithError(err).
WithField("errType", fmt.Sprintf("%T", err)). WithField("errType", fmt.Sprintf("%T", err)).
Error("receive request failed (might also be error on sender)") Error("receive request failed (might also be error on sender)")
sstream.Close()
// This failure could be due to // This failure could be due to
// - an unexpected exit of ZFS on the sending side // - an unexpected exit of ZFS on the sending side
// - an unexpected exit of ZFS on the receiving side // - an unexpected exit of ZFS on the receiving side
@ -524,7 +539,7 @@ func (s *ReplicationStep) Report() *StepReport {
} }
bytes := int64(0) bytes := int64(0)
if s.byteCounter != nil { if s.byteCounter != nil {
bytes = s.byteCounter.Bytes() bytes = s.byteCounter.Count()
} }
problem := "" problem := ""
if s.err != nil { if s.err != nil {

View File

@ -10,7 +10,6 @@ import (
"github.com/zrepl/zrepl/daemon/job/wakeup" "github.com/zrepl/zrepl/daemon/job/wakeup"
"github.com/zrepl/zrepl/util/envconst" "github.com/zrepl/zrepl/util/envconst"
"github.com/zrepl/zrepl/util/watchdog" "github.com/zrepl/zrepl/util/watchdog"
"github.com/problame/go-streamrpc"
"math/bits" "math/bits"
"net" "net"
"sort" "sort"
@ -106,9 +105,8 @@ func NewReplication(secsPerState *prometheus.HistogramVec, bytesReplicated *prom
// named interfaces defined in this package. // named interfaces defined in this package.
type Endpoint interface { type Endpoint interface {
// Does not include placeholder filesystems // Does not include placeholder filesystems
ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, error) ListFilesystems(ctx context.Context, req *pdu.ListFilesystemReq) (*pdu.ListFilesystemRes, error)
// FIXME document FilteredError handling ListFilesystemVersions(ctx context.Context, req *pdu.ListFilesystemVersionsReq) (*pdu.ListFilesystemVersionsRes, error)
ListFilesystemVersions(ctx context.Context, fs string) ([]*pdu.FilesystemVersion, error) // fix depS
DestroySnapshots(ctx context.Context, req *pdu.DestroySnapshotsReq) (*pdu.DestroySnapshotsRes, error) DestroySnapshots(ctx context.Context, req *pdu.DestroySnapshotsReq) (*pdu.DestroySnapshotsRes, error)
} }
@ -203,7 +201,6 @@ type Error interface {
var _ Error = fsrep.Error(nil) var _ Error = fsrep.Error(nil)
var _ Error = net.Error(nil) var _ Error = net.Error(nil)
var _ Error = streamrpc.Error(nil)
func isPermanent(err error) bool { func isPermanent(err error) bool {
if e, ok := err.(Error); ok { if e, ok := err.(Error); ok {
@ -232,19 +229,20 @@ func statePlanning(ctx context.Context, ka *watchdog.KeepAlive, sender Sender, r
}).rsf() }).rsf()
} }
sfss, err := sender.ListFilesystems(ctx) slfssres, err := sender.ListFilesystems(ctx, &pdu.ListFilesystemReq{})
if err != nil { if err != nil {
log.WithError(err).Error("error listing sender filesystems") log.WithError(err).WithField("errType", fmt.Sprintf("%T", err)).Error("error listing sender filesystems")
return handlePlanningError(err) return handlePlanningError(err)
} }
sfss := slfssres.GetFilesystems()
// no progress here since we could run in a live-lock on connectivity issues // no progress here since we could run in a live-lock on connectivity issues
rfss, err := receiver.ListFilesystems(ctx) rlfssres, err := receiver.ListFilesystems(ctx, &pdu.ListFilesystemReq{})
if err != nil { if err != nil {
log.WithError(err).Error("error listing receiver filesystems") log.WithError(err).WithField("errType", fmt.Sprintf("%T", err)).Error("error listing receiver filesystems")
return handlePlanningError(err) return handlePlanningError(err)
} }
rfss := rlfssres.GetFilesystems()
ka.MadeProgress() // for both sender and receiver ka.MadeProgress() // for both sender and receiver
q := make([]*fsrep.Replication, 0, len(sfss)) q := make([]*fsrep.Replication, 0, len(sfss))
@ -255,11 +253,12 @@ func statePlanning(ctx context.Context, ka *watchdog.KeepAlive, sender Sender, r
log.Debug("assessing filesystem") log.Debug("assessing filesystem")
sfsvs, err := sender.ListFilesystemVersions(ctx, fs.Path) sfsvsres, err := sender.ListFilesystemVersions(ctx, &pdu.ListFilesystemVersionsReq{Filesystem: fs.Path})
if err != nil { if err != nil {
log.WithError(err).Error("cannot get remote filesystem versions") log.WithError(err).Error("cannot get remote filesystem versions")
return handlePlanningError(err) return handlePlanningError(err)
} }
sfsvs := sfsvsres.GetVersions()
ka.MadeProgress() ka.MadeProgress()
if len(sfsvs) < 1 { if len(sfsvs) < 1 {
@ -278,7 +277,7 @@ func statePlanning(ctx context.Context, ka *watchdog.KeepAlive, sender Sender, r
var rfsvs []*pdu.FilesystemVersion var rfsvs []*pdu.FilesystemVersion
if receiverFSExists { if receiverFSExists {
rfsvs, err = receiver.ListFilesystemVersions(ctx, fs.Path) rfsvsres, err := receiver.ListFilesystemVersions(ctx, &pdu.ListFilesystemVersionsReq{Filesystem: fs.Path})
if err != nil { if err != nil {
if _, ok := err.(*FilteredError); ok { if _, ok := err.(*FilteredError); ok {
log.Info("receiver ignores filesystem") log.Info("receiver ignores filesystem")
@ -287,6 +286,7 @@ func statePlanning(ctx context.Context, ka *watchdog.KeepAlive, sender Sender, r
log.WithError(err).Error("receiver error") log.WithError(err).Error("receiver error")
return handlePlanningError(err) return handlePlanningError(err)
} }
rfsvs = rfsvsres.GetVersions()
} else { } else {
rfsvs = []*pdu.FilesystemVersion{} rfsvs = []*pdu.FilesystemVersion{}
} }

View File

@ -7,6 +7,11 @@ import proto "github.com/golang/protobuf/proto"
import fmt "fmt" import fmt "fmt"
import math "math" import math "math"
import (
context "golang.org/x/net/context"
grpc "google.golang.org/grpc"
)
// Reference imports to suppress errors if they are not otherwise used. // Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal var _ = proto.Marshal
var _ = fmt.Errorf var _ = fmt.Errorf
@ -38,7 +43,7 @@ func (x FilesystemVersion_VersionType) String() string {
return proto.EnumName(FilesystemVersion_VersionType_name, int32(x)) return proto.EnumName(FilesystemVersion_VersionType_name, int32(x))
} }
func (FilesystemVersion_VersionType) EnumDescriptor() ([]byte, []int) { func (FilesystemVersion_VersionType) EnumDescriptor() ([]byte, []int) {
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{5, 0} return fileDescriptor_pdu_89315d819a6e0938, []int{5, 0}
} }
type ListFilesystemReq struct { type ListFilesystemReq struct {
@ -51,7 +56,7 @@ func (m *ListFilesystemReq) Reset() { *m = ListFilesystemReq{} }
func (m *ListFilesystemReq) String() string { return proto.CompactTextString(m) } func (m *ListFilesystemReq) String() string { return proto.CompactTextString(m) }
func (*ListFilesystemReq) ProtoMessage() {} func (*ListFilesystemReq) ProtoMessage() {}
func (*ListFilesystemReq) Descriptor() ([]byte, []int) { func (*ListFilesystemReq) Descriptor() ([]byte, []int) {
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{0} return fileDescriptor_pdu_89315d819a6e0938, []int{0}
} }
func (m *ListFilesystemReq) XXX_Unmarshal(b []byte) error { func (m *ListFilesystemReq) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ListFilesystemReq.Unmarshal(m, b) return xxx_messageInfo_ListFilesystemReq.Unmarshal(m, b)
@ -73,6 +78,7 @@ var xxx_messageInfo_ListFilesystemReq proto.InternalMessageInfo
type ListFilesystemRes struct { type ListFilesystemRes struct {
Filesystems []*Filesystem `protobuf:"bytes,1,rep,name=Filesystems,proto3" json:"Filesystems,omitempty"` Filesystems []*Filesystem `protobuf:"bytes,1,rep,name=Filesystems,proto3" json:"Filesystems,omitempty"`
Empty bool `protobuf:"varint,2,opt,name=Empty,proto3" json:"Empty,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"` XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"` XXX_sizecache int32 `json:"-"`
@ -82,7 +88,7 @@ func (m *ListFilesystemRes) Reset() { *m = ListFilesystemRes{} }
func (m *ListFilesystemRes) String() string { return proto.CompactTextString(m) } func (m *ListFilesystemRes) String() string { return proto.CompactTextString(m) }
func (*ListFilesystemRes) ProtoMessage() {} func (*ListFilesystemRes) ProtoMessage() {}
func (*ListFilesystemRes) Descriptor() ([]byte, []int) { func (*ListFilesystemRes) Descriptor() ([]byte, []int) {
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{1} return fileDescriptor_pdu_89315d819a6e0938, []int{1}
} }
func (m *ListFilesystemRes) XXX_Unmarshal(b []byte) error { func (m *ListFilesystemRes) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ListFilesystemRes.Unmarshal(m, b) return xxx_messageInfo_ListFilesystemRes.Unmarshal(m, b)
@ -109,6 +115,13 @@ func (m *ListFilesystemRes) GetFilesystems() []*Filesystem {
return nil return nil
} }
func (m *ListFilesystemRes) GetEmpty() bool {
if m != nil {
return m.Empty
}
return false
}
type Filesystem struct { type Filesystem struct {
Path string `protobuf:"bytes,1,opt,name=Path,proto3" json:"Path,omitempty"` Path string `protobuf:"bytes,1,opt,name=Path,proto3" json:"Path,omitempty"`
ResumeToken string `protobuf:"bytes,2,opt,name=ResumeToken,proto3" json:"ResumeToken,omitempty"` ResumeToken string `protobuf:"bytes,2,opt,name=ResumeToken,proto3" json:"ResumeToken,omitempty"`
@ -121,7 +134,7 @@ func (m *Filesystem) Reset() { *m = Filesystem{} }
func (m *Filesystem) String() string { return proto.CompactTextString(m) } func (m *Filesystem) String() string { return proto.CompactTextString(m) }
func (*Filesystem) ProtoMessage() {} func (*Filesystem) ProtoMessage() {}
func (*Filesystem) Descriptor() ([]byte, []int) { func (*Filesystem) Descriptor() ([]byte, []int) {
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{2} return fileDescriptor_pdu_89315d819a6e0938, []int{2}
} }
func (m *Filesystem) XXX_Unmarshal(b []byte) error { func (m *Filesystem) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_Filesystem.Unmarshal(m, b) return xxx_messageInfo_Filesystem.Unmarshal(m, b)
@ -166,7 +179,7 @@ func (m *ListFilesystemVersionsReq) Reset() { *m = ListFilesystemVersion
func (m *ListFilesystemVersionsReq) String() string { return proto.CompactTextString(m) } func (m *ListFilesystemVersionsReq) String() string { return proto.CompactTextString(m) }
func (*ListFilesystemVersionsReq) ProtoMessage() {} func (*ListFilesystemVersionsReq) ProtoMessage() {}
func (*ListFilesystemVersionsReq) Descriptor() ([]byte, []int) { func (*ListFilesystemVersionsReq) Descriptor() ([]byte, []int) {
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{3} return fileDescriptor_pdu_89315d819a6e0938, []int{3}
} }
func (m *ListFilesystemVersionsReq) XXX_Unmarshal(b []byte) error { func (m *ListFilesystemVersionsReq) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ListFilesystemVersionsReq.Unmarshal(m, b) return xxx_messageInfo_ListFilesystemVersionsReq.Unmarshal(m, b)
@ -204,7 +217,7 @@ func (m *ListFilesystemVersionsRes) Reset() { *m = ListFilesystemVersion
func (m *ListFilesystemVersionsRes) String() string { return proto.CompactTextString(m) } func (m *ListFilesystemVersionsRes) String() string { return proto.CompactTextString(m) }
func (*ListFilesystemVersionsRes) ProtoMessage() {} func (*ListFilesystemVersionsRes) ProtoMessage() {}
func (*ListFilesystemVersionsRes) Descriptor() ([]byte, []int) { func (*ListFilesystemVersionsRes) Descriptor() ([]byte, []int) {
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{4} return fileDescriptor_pdu_89315d819a6e0938, []int{4}
} }
func (m *ListFilesystemVersionsRes) XXX_Unmarshal(b []byte) error { func (m *ListFilesystemVersionsRes) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ListFilesystemVersionsRes.Unmarshal(m, b) return xxx_messageInfo_ListFilesystemVersionsRes.Unmarshal(m, b)
@ -232,7 +245,7 @@ func (m *ListFilesystemVersionsRes) GetVersions() []*FilesystemVersion {
} }
type FilesystemVersion struct { type FilesystemVersion struct {
Type FilesystemVersion_VersionType `protobuf:"varint,1,opt,name=Type,proto3,enum=pdu.FilesystemVersion_VersionType" json:"Type,omitempty"` Type FilesystemVersion_VersionType `protobuf:"varint,1,opt,name=Type,proto3,enum=FilesystemVersion_VersionType" json:"Type,omitempty"`
Name string `protobuf:"bytes,2,opt,name=Name,proto3" json:"Name,omitempty"` Name string `protobuf:"bytes,2,opt,name=Name,proto3" json:"Name,omitempty"`
Guid uint64 `protobuf:"varint,3,opt,name=Guid,proto3" json:"Guid,omitempty"` Guid uint64 `protobuf:"varint,3,opt,name=Guid,proto3" json:"Guid,omitempty"`
CreateTXG uint64 `protobuf:"varint,4,opt,name=CreateTXG,proto3" json:"CreateTXG,omitempty"` CreateTXG uint64 `protobuf:"varint,4,opt,name=CreateTXG,proto3" json:"CreateTXG,omitempty"`
@ -246,7 +259,7 @@ func (m *FilesystemVersion) Reset() { *m = FilesystemVersion{} }
func (m *FilesystemVersion) String() string { return proto.CompactTextString(m) } func (m *FilesystemVersion) String() string { return proto.CompactTextString(m) }
func (*FilesystemVersion) ProtoMessage() {} func (*FilesystemVersion) ProtoMessage() {}
func (*FilesystemVersion) Descriptor() ([]byte, []int) { func (*FilesystemVersion) Descriptor() ([]byte, []int) {
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{5} return fileDescriptor_pdu_89315d819a6e0938, []int{5}
} }
func (m *FilesystemVersion) XXX_Unmarshal(b []byte) error { func (m *FilesystemVersion) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_FilesystemVersion.Unmarshal(m, b) return xxx_messageInfo_FilesystemVersion.Unmarshal(m, b)
@ -326,7 +339,7 @@ func (m *SendReq) Reset() { *m = SendReq{} }
func (m *SendReq) String() string { return proto.CompactTextString(m) } func (m *SendReq) String() string { return proto.CompactTextString(m) }
func (*SendReq) ProtoMessage() {} func (*SendReq) ProtoMessage() {}
func (*SendReq) Descriptor() ([]byte, []int) { func (*SendReq) Descriptor() ([]byte, []int) {
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{6} return fileDescriptor_pdu_89315d819a6e0938, []int{6}
} }
func (m *SendReq) XXX_Unmarshal(b []byte) error { func (m *SendReq) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_SendReq.Unmarshal(m, b) return xxx_messageInfo_SendReq.Unmarshal(m, b)
@ -407,7 +420,7 @@ func (m *Property) Reset() { *m = Property{} }
func (m *Property) String() string { return proto.CompactTextString(m) } func (m *Property) String() string { return proto.CompactTextString(m) }
func (*Property) ProtoMessage() {} func (*Property) ProtoMessage() {}
func (*Property) Descriptor() ([]byte, []int) { func (*Property) Descriptor() ([]byte, []int) {
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{7} return fileDescriptor_pdu_89315d819a6e0938, []int{7}
} }
func (m *Property) XXX_Unmarshal(b []byte) error { func (m *Property) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_Property.Unmarshal(m, b) return xxx_messageInfo_Property.Unmarshal(m, b)
@ -443,11 +456,11 @@ func (m *Property) GetValue() string {
type SendRes struct { type SendRes struct {
// Whether the resume token provided in the request has been used or not. // Whether the resume token provided in the request has been used or not.
UsedResumeToken bool `protobuf:"varint,1,opt,name=UsedResumeToken,proto3" json:"UsedResumeToken,omitempty"` UsedResumeToken bool `protobuf:"varint,2,opt,name=UsedResumeToken,proto3" json:"UsedResumeToken,omitempty"`
// Expected stream size determined by dry run, not exact. // Expected stream size determined by dry run, not exact.
// 0 indicates that for the given SendReq, no size estimate could be made. // 0 indicates that for the given SendReq, no size estimate could be made.
ExpectedSize int64 `protobuf:"varint,2,opt,name=ExpectedSize,proto3" json:"ExpectedSize,omitempty"` ExpectedSize int64 `protobuf:"varint,3,opt,name=ExpectedSize,proto3" json:"ExpectedSize,omitempty"`
Properties []*Property `protobuf:"bytes,3,rep,name=Properties,proto3" json:"Properties,omitempty"` Properties []*Property `protobuf:"bytes,4,rep,name=Properties,proto3" json:"Properties,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"` XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"` XXX_sizecache int32 `json:"-"`
@ -457,7 +470,7 @@ func (m *SendRes) Reset() { *m = SendRes{} }
func (m *SendRes) String() string { return proto.CompactTextString(m) } func (m *SendRes) String() string { return proto.CompactTextString(m) }
func (*SendRes) ProtoMessage() {} func (*SendRes) ProtoMessage() {}
func (*SendRes) Descriptor() ([]byte, []int) { func (*SendRes) Descriptor() ([]byte, []int) {
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{8} return fileDescriptor_pdu_89315d819a6e0938, []int{8}
} }
func (m *SendRes) XXX_Unmarshal(b []byte) error { func (m *SendRes) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_SendRes.Unmarshal(m, b) return xxx_messageInfo_SendRes.Unmarshal(m, b)
@ -511,7 +524,7 @@ func (m *ReceiveReq) Reset() { *m = ReceiveReq{} }
func (m *ReceiveReq) String() string { return proto.CompactTextString(m) } func (m *ReceiveReq) String() string { return proto.CompactTextString(m) }
func (*ReceiveReq) ProtoMessage() {} func (*ReceiveReq) ProtoMessage() {}
func (*ReceiveReq) Descriptor() ([]byte, []int) { func (*ReceiveReq) Descriptor() ([]byte, []int) {
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{9} return fileDescriptor_pdu_89315d819a6e0938, []int{9}
} }
func (m *ReceiveReq) XXX_Unmarshal(b []byte) error { func (m *ReceiveReq) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ReceiveReq.Unmarshal(m, b) return xxx_messageInfo_ReceiveReq.Unmarshal(m, b)
@ -555,7 +568,7 @@ func (m *ReceiveRes) Reset() { *m = ReceiveRes{} }
func (m *ReceiveRes) String() string { return proto.CompactTextString(m) } func (m *ReceiveRes) String() string { return proto.CompactTextString(m) }
func (*ReceiveRes) ProtoMessage() {} func (*ReceiveRes) ProtoMessage() {}
func (*ReceiveRes) Descriptor() ([]byte, []int) { func (*ReceiveRes) Descriptor() ([]byte, []int) {
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{10} return fileDescriptor_pdu_89315d819a6e0938, []int{10}
} }
func (m *ReceiveRes) XXX_Unmarshal(b []byte) error { func (m *ReceiveRes) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ReceiveRes.Unmarshal(m, b) return xxx_messageInfo_ReceiveRes.Unmarshal(m, b)
@ -588,7 +601,7 @@ func (m *DestroySnapshotsReq) Reset() { *m = DestroySnapshotsReq{} }
func (m *DestroySnapshotsReq) String() string { return proto.CompactTextString(m) } func (m *DestroySnapshotsReq) String() string { return proto.CompactTextString(m) }
func (*DestroySnapshotsReq) ProtoMessage() {} func (*DestroySnapshotsReq) ProtoMessage() {}
func (*DestroySnapshotsReq) Descriptor() ([]byte, []int) { func (*DestroySnapshotsReq) Descriptor() ([]byte, []int) {
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{11} return fileDescriptor_pdu_89315d819a6e0938, []int{11}
} }
func (m *DestroySnapshotsReq) XXX_Unmarshal(b []byte) error { func (m *DestroySnapshotsReq) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_DestroySnapshotsReq.Unmarshal(m, b) return xxx_messageInfo_DestroySnapshotsReq.Unmarshal(m, b)
@ -634,7 +647,7 @@ func (m *DestroySnapshotRes) Reset() { *m = DestroySnapshotRes{} }
func (m *DestroySnapshotRes) String() string { return proto.CompactTextString(m) } func (m *DestroySnapshotRes) String() string { return proto.CompactTextString(m) }
func (*DestroySnapshotRes) ProtoMessage() {} func (*DestroySnapshotRes) ProtoMessage() {}
func (*DestroySnapshotRes) Descriptor() ([]byte, []int) { func (*DestroySnapshotRes) Descriptor() ([]byte, []int) {
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{12} return fileDescriptor_pdu_89315d819a6e0938, []int{12}
} }
func (m *DestroySnapshotRes) XXX_Unmarshal(b []byte) error { func (m *DestroySnapshotRes) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_DestroySnapshotRes.Unmarshal(m, b) return xxx_messageInfo_DestroySnapshotRes.Unmarshal(m, b)
@ -679,7 +692,7 @@ func (m *DestroySnapshotsRes) Reset() { *m = DestroySnapshotsRes{} }
func (m *DestroySnapshotsRes) String() string { return proto.CompactTextString(m) } func (m *DestroySnapshotsRes) String() string { return proto.CompactTextString(m) }
func (*DestroySnapshotsRes) ProtoMessage() {} func (*DestroySnapshotsRes) ProtoMessage() {}
func (*DestroySnapshotsRes) Descriptor() ([]byte, []int) { func (*DestroySnapshotsRes) Descriptor() ([]byte, []int) {
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{13} return fileDescriptor_pdu_89315d819a6e0938, []int{13}
} }
func (m *DestroySnapshotsRes) XXX_Unmarshal(b []byte) error { func (m *DestroySnapshotsRes) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_DestroySnapshotsRes.Unmarshal(m, b) return xxx_messageInfo_DestroySnapshotsRes.Unmarshal(m, b)
@ -721,7 +734,7 @@ func (m *ReplicationCursorReq) Reset() { *m = ReplicationCursorReq{} }
func (m *ReplicationCursorReq) String() string { return proto.CompactTextString(m) } func (m *ReplicationCursorReq) String() string { return proto.CompactTextString(m) }
func (*ReplicationCursorReq) ProtoMessage() {} func (*ReplicationCursorReq) ProtoMessage() {}
func (*ReplicationCursorReq) Descriptor() ([]byte, []int) { func (*ReplicationCursorReq) Descriptor() ([]byte, []int) {
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{14} return fileDescriptor_pdu_89315d819a6e0938, []int{14}
} }
func (m *ReplicationCursorReq) XXX_Unmarshal(b []byte) error { func (m *ReplicationCursorReq) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ReplicationCursorReq.Unmarshal(m, b) return xxx_messageInfo_ReplicationCursorReq.Unmarshal(m, b)
@ -869,7 +882,7 @@ func (m *ReplicationCursorReq_GetOp) Reset() { *m = ReplicationCursorReq
func (m *ReplicationCursorReq_GetOp) String() string { return proto.CompactTextString(m) } func (m *ReplicationCursorReq_GetOp) String() string { return proto.CompactTextString(m) }
func (*ReplicationCursorReq_GetOp) ProtoMessage() {} func (*ReplicationCursorReq_GetOp) ProtoMessage() {}
func (*ReplicationCursorReq_GetOp) Descriptor() ([]byte, []int) { func (*ReplicationCursorReq_GetOp) Descriptor() ([]byte, []int) {
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{14, 0} return fileDescriptor_pdu_89315d819a6e0938, []int{14, 0}
} }
func (m *ReplicationCursorReq_GetOp) XXX_Unmarshal(b []byte) error { func (m *ReplicationCursorReq_GetOp) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ReplicationCursorReq_GetOp.Unmarshal(m, b) return xxx_messageInfo_ReplicationCursorReq_GetOp.Unmarshal(m, b)
@ -900,7 +913,7 @@ func (m *ReplicationCursorReq_SetOp) Reset() { *m = ReplicationCursorReq
func (m *ReplicationCursorReq_SetOp) String() string { return proto.CompactTextString(m) } func (m *ReplicationCursorReq_SetOp) String() string { return proto.CompactTextString(m) }
func (*ReplicationCursorReq_SetOp) ProtoMessage() {} func (*ReplicationCursorReq_SetOp) ProtoMessage() {}
func (*ReplicationCursorReq_SetOp) Descriptor() ([]byte, []int) { func (*ReplicationCursorReq_SetOp) Descriptor() ([]byte, []int) {
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{14, 1} return fileDescriptor_pdu_89315d819a6e0938, []int{14, 1}
} }
func (m *ReplicationCursorReq_SetOp) XXX_Unmarshal(b []byte) error { func (m *ReplicationCursorReq_SetOp) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ReplicationCursorReq_SetOp.Unmarshal(m, b) return xxx_messageInfo_ReplicationCursorReq_SetOp.Unmarshal(m, b)
@ -941,7 +954,7 @@ func (m *ReplicationCursorRes) Reset() { *m = ReplicationCursorRes{} }
func (m *ReplicationCursorRes) String() string { return proto.CompactTextString(m) } func (m *ReplicationCursorRes) String() string { return proto.CompactTextString(m) }
func (*ReplicationCursorRes) ProtoMessage() {} func (*ReplicationCursorRes) ProtoMessage() {}
func (*ReplicationCursorRes) Descriptor() ([]byte, []int) { func (*ReplicationCursorRes) Descriptor() ([]byte, []int) {
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{15} return fileDescriptor_pdu_89315d819a6e0938, []int{15}
} }
func (m *ReplicationCursorRes) XXX_Unmarshal(b []byte) error { func (m *ReplicationCursorRes) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ReplicationCursorRes.Unmarshal(m, b) return xxx_messageInfo_ReplicationCursorRes.Unmarshal(m, b)
@ -1067,71 +1080,246 @@ func _ReplicationCursorRes_OneofSizer(msg proto.Message) (n int) {
} }
func init() { func init() {
proto.RegisterType((*ListFilesystemReq)(nil), "pdu.ListFilesystemReq") proto.RegisterType((*ListFilesystemReq)(nil), "ListFilesystemReq")
proto.RegisterType((*ListFilesystemRes)(nil), "pdu.ListFilesystemRes") proto.RegisterType((*ListFilesystemRes)(nil), "ListFilesystemRes")
proto.RegisterType((*Filesystem)(nil), "pdu.Filesystem") proto.RegisterType((*Filesystem)(nil), "Filesystem")
proto.RegisterType((*ListFilesystemVersionsReq)(nil), "pdu.ListFilesystemVersionsReq") proto.RegisterType((*ListFilesystemVersionsReq)(nil), "ListFilesystemVersionsReq")
proto.RegisterType((*ListFilesystemVersionsRes)(nil), "pdu.ListFilesystemVersionsRes") proto.RegisterType((*ListFilesystemVersionsRes)(nil), "ListFilesystemVersionsRes")
proto.RegisterType((*FilesystemVersion)(nil), "pdu.FilesystemVersion") proto.RegisterType((*FilesystemVersion)(nil), "FilesystemVersion")
proto.RegisterType((*SendReq)(nil), "pdu.SendReq") proto.RegisterType((*SendReq)(nil), "SendReq")
proto.RegisterType((*Property)(nil), "pdu.Property") proto.RegisterType((*Property)(nil), "Property")
proto.RegisterType((*SendRes)(nil), "pdu.SendRes") proto.RegisterType((*SendRes)(nil), "SendRes")
proto.RegisterType((*ReceiveReq)(nil), "pdu.ReceiveReq") proto.RegisterType((*ReceiveReq)(nil), "ReceiveReq")
proto.RegisterType((*ReceiveRes)(nil), "pdu.ReceiveRes") proto.RegisterType((*ReceiveRes)(nil), "ReceiveRes")
proto.RegisterType((*DestroySnapshotsReq)(nil), "pdu.DestroySnapshotsReq") proto.RegisterType((*DestroySnapshotsReq)(nil), "DestroySnapshotsReq")
proto.RegisterType((*DestroySnapshotRes)(nil), "pdu.DestroySnapshotRes") proto.RegisterType((*DestroySnapshotRes)(nil), "DestroySnapshotRes")
proto.RegisterType((*DestroySnapshotsRes)(nil), "pdu.DestroySnapshotsRes") proto.RegisterType((*DestroySnapshotsRes)(nil), "DestroySnapshotsRes")
proto.RegisterType((*ReplicationCursorReq)(nil), "pdu.ReplicationCursorReq") proto.RegisterType((*ReplicationCursorReq)(nil), "ReplicationCursorReq")
proto.RegisterType((*ReplicationCursorReq_GetOp)(nil), "pdu.ReplicationCursorReq.GetOp") proto.RegisterType((*ReplicationCursorReq_GetOp)(nil), "ReplicationCursorReq.GetOp")
proto.RegisterType((*ReplicationCursorReq_SetOp)(nil), "pdu.ReplicationCursorReq.SetOp") proto.RegisterType((*ReplicationCursorReq_SetOp)(nil), "ReplicationCursorReq.SetOp")
proto.RegisterType((*ReplicationCursorRes)(nil), "pdu.ReplicationCursorRes") proto.RegisterType((*ReplicationCursorRes)(nil), "ReplicationCursorRes")
proto.RegisterEnum("pdu.FilesystemVersion_VersionType", FilesystemVersion_VersionType_name, FilesystemVersion_VersionType_value) proto.RegisterEnum("FilesystemVersion_VersionType", FilesystemVersion_VersionType_name, FilesystemVersion_VersionType_value)
} }
func init() { proto.RegisterFile("pdu.proto", fileDescriptor_pdu_fe566e6b212fcf8d) } // Reference imports to suppress errors if they are not otherwise used.
var _ context.Context
var _ grpc.ClientConn
var fileDescriptor_pdu_fe566e6b212fcf8d = []byte{ // This is a compile-time assertion to ensure that this generated file
// 659 bytes of a gzipped FileDescriptorProto // is compatible with the grpc package it is being compiled against.
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x55, 0xdb, 0x6e, 0x13, 0x31, const _ = grpc.SupportPackageIsVersion4
0x10, 0xcd, 0xe6, 0xba, 0x99, 0x94, 0x5e, 0xdc, 0xaa, 0x2c, 0x15, 0x82, 0xc8, 0xbc, 0x04, 0x24,
0x22, 0x91, 0x56, 0xbc, 0xf0, 0x96, 0xde, 0xf2, 0x80, 0xda, 0xca, 0x09, 0x55, 0x9f, 0x90, 0x42, // ReplicationClient is the client API for Replication service.
0x77, 0x44, 0x57, 0xb9, 0x78, 0x6b, 0x7b, 0x51, 0xc3, 0x07, 0xf0, 0x4f, 0xfc, 0x07, 0x0f, 0x7c, //
0x0e, 0xf2, 0xec, 0x25, 0xdb, 0x24, 0x54, 0x79, 0x8a, 0xcf, 0xf8, 0x78, 0xe6, 0xcc, 0xf1, 0x8e, // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.
0x03, 0xf5, 0xd0, 0x8f, 0xda, 0xa1, 0x92, 0x46, 0xb2, 0x52, 0xe8, 0x47, 0x7c, 0x17, 0x76, 0x3e, type ReplicationClient interface {
0x07, 0xda, 0x9c, 0x05, 0x63, 0xd4, 0x33, 0x6d, 0x70, 0x22, 0xf0, 0x9e, 0x9f, 0x2d, 0x07, 0x35, ListFilesystems(ctx context.Context, in *ListFilesystemReq, opts ...grpc.CallOption) (*ListFilesystemRes, error)
0xfb, 0x00, 0x8d, 0x79, 0x40, 0x7b, 0x4e, 0xb3, 0xd4, 0x6a, 0x74, 0xb6, 0xda, 0x36, 0x5f, 0x8e, ListFilesystemVersions(ctx context.Context, in *ListFilesystemVersionsReq, opts ...grpc.CallOption) (*ListFilesystemVersionsRes, error)
0x98, 0xe7, 0xf0, 0x2e, 0xc0, 0x1c, 0x32, 0x06, 0xe5, 0xab, 0xa1, 0xb9, 0xf3, 0x9c, 0xa6, 0xd3, DestroySnapshots(ctx context.Context, in *DestroySnapshotsReq, opts ...grpc.CallOption) (*DestroySnapshotsRes, error)
0xaa, 0x0b, 0x5a, 0xb3, 0x26, 0x34, 0x04, 0xea, 0x68, 0x82, 0x03, 0x39, 0xc2, 0xa9, 0x57, 0xa4, ReplicationCursor(ctx context.Context, in *ReplicationCursorReq, opts ...grpc.CallOption) (*ReplicationCursorRes, error)
0xad, 0x7c, 0x88, 0x7f, 0x82, 0x17, 0x8f, 0xb5, 0x5c, 0xa3, 0xd2, 0x81, 0x9c, 0x6a, 0x81, 0xf7, }
0xec, 0x55, 0xbe, 0x40, 0x92, 0x38, 0x17, 0xe1, 0x97, 0xff, 0x3f, 0xac, 0x59, 0x07, 0xdc, 0x14,
0x26, 0xdd, 0xec, 0x2f, 0x74, 0x93, 0x6c, 0x8b, 0x8c, 0xc7, 0xff, 0x3a, 0xb0, 0xb3, 0xb4, 0xcf, type replicationClient struct {
0x3e, 0x42, 0x79, 0x30, 0x0b, 0x91, 0x04, 0x6c, 0x76, 0xf8, 0xea, 0x2c, 0xed, 0xe4, 0xd7, 0x32, cc *grpc.ClientConn
0x05, 0xf1, 0xad, 0x23, 0x17, 0xc3, 0x09, 0x26, 0x6d, 0xd3, 0xda, 0xc6, 0xce, 0xa3, 0xc0, 0xf7, }
0x4a, 0x4d, 0xa7, 0x55, 0x16, 0xb4, 0x66, 0x2f, 0xa1, 0x7e, 0xac, 0x70, 0x68, 0x70, 0x70, 0x73,
0xee, 0x95, 0x69, 0x63, 0x1e, 0x60, 0x07, 0xe0, 0x12, 0x08, 0xe4, 0xd4, 0xab, 0x50, 0xa6, 0x0c, func NewReplicationClient(cc *grpc.ClientConn) ReplicationClient {
0xf3, 0xb7, 0xd0, 0xc8, 0x95, 0x65, 0x1b, 0xe0, 0xf6, 0xa7, 0xc3, 0x50, 0xdf, 0x49, 0xb3, 0x5d, return &replicationClient{cc}
0xb0, 0xa8, 0x2b, 0xe5, 0x68, 0x32, 0x54, 0xa3, 0x6d, 0x87, 0xff, 0x76, 0xa0, 0xd6, 0xc7, 0xa9, }
0xbf, 0x86, 0xaf, 0x56, 0xe4, 0x99, 0x92, 0x93, 0x54, 0xb8, 0x5d, 0xb3, 0x4d, 0x28, 0x0e, 0x24,
0xc9, 0xae, 0x8b, 0xe2, 0x40, 0x2e, 0x5e, 0x6d, 0x79, 0xe9, 0x6a, 0x49, 0xb8, 0x9c, 0x84, 0x0a, func (c *replicationClient) ListFilesystems(ctx context.Context, in *ListFilesystemReq, opts ...grpc.CallOption) (*ListFilesystemRes, error) {
0xb5, 0x26, 0xe1, 0xae, 0xc8, 0x30, 0xdb, 0x83, 0xca, 0x09, 0xfa, 0x51, 0xe8, 0x55, 0x69, 0x23, out := new(ListFilesystemRes)
0x06, 0x6c, 0x1f, 0xaa, 0x27, 0x6a, 0x26, 0xa2, 0xa9, 0x57, 0xa3, 0x70, 0x82, 0xf8, 0x11, 0xb8, err := c.cc.Invoke(ctx, "/Replication/ListFilesystems", in, out, opts...)
0x57, 0x4a, 0x86, 0xa8, 0xcc, 0x2c, 0x33, 0xd5, 0xc9, 0x99, 0xba, 0x07, 0x95, 0xeb, 0xe1, 0x38, if err != nil {
0x4a, 0x9d, 0x8e, 0x01, 0xff, 0x95, 0x75, 0xac, 0x59, 0x0b, 0xb6, 0xbe, 0x68, 0xf4, 0xf3, 0x8a, return nil, err
0x1d, 0x2a, 0xb1, 0x18, 0x66, 0x1c, 0x36, 0x4e, 0x1f, 0x42, 0xbc, 0x35, 0xe8, 0xf7, 0x83, 0x9f, }
0x71, 0xca, 0x92, 0x78, 0x14, 0x63, 0xef, 0x01, 0x12, 0x3d, 0x01, 0x6a, 0xaf, 0x44, 0x1f, 0xd7, return out, nil
0x33, 0xfa, 0x2c, 0x52, 0x99, 0x22, 0x47, 0xe0, 0x37, 0x00, 0x02, 0x6f, 0x31, 0xf8, 0x81, 0xeb, }
0x98, 0xff, 0x0e, 0xb6, 0x8f, 0xc7, 0x38, 0x54, 0x8b, 0x83, 0xe3, 0x8a, 0xa5, 0x38, 0xdf, 0xc8,
0x65, 0xd6, 0x7c, 0x04, 0xbb, 0x27, 0xa8, 0x8d, 0x92, 0xb3, 0xf4, 0x2b, 0x58, 0x67, 0x8a, 0xd8, func (c *replicationClient) ListFilesystemVersions(ctx context.Context, in *ListFilesystemVersionsReq, opts ...grpc.CallOption) (*ListFilesystemVersionsRes, error) {
0x11, 0xd4, 0x33, 0xbe, 0x57, 0x7c, 0x72, 0x52, 0xe6, 0x44, 0xfe, 0x15, 0xd8, 0x42, 0xb1, 0x64, out := new(ListFilesystemVersionsRes)
0xe8, 0x52, 0x48, 0x95, 0x9e, 0x18, 0xba, 0x94, 0x67, 0x6f, 0xef, 0x54, 0x29, 0xa9, 0xd2, 0xdb, err := c.cc.Invoke(ctx, "/Replication/ListFilesystemVersions", in, out, opts...)
0x23, 0xc0, 0x7b, 0xab, 0x9a, 0xb1, 0xcf, 0x54, 0xcd, 0x1a, 0x30, 0x36, 0xe9, 0x50, 0x3f, 0xa7, if err != nil {
0xfc, 0xcb, 0x52, 0x44, 0xca, 0xe3, 0x7f, 0x1c, 0xd8, 0x13, 0x18, 0x8e, 0x83, 0x5b, 0x1a, 0x9a, return nil, err
0xe3, 0x48, 0x69, 0xa9, 0xd6, 0x31, 0xe6, 0x10, 0x4a, 0xdf, 0xd1, 0x90, 0xac, 0x46, 0xe7, 0x35, }
0xd5, 0x59, 0x95, 0xa7, 0x7d, 0x8e, 0xe6, 0x32, 0xec, 0x15, 0x84, 0x65, 0xdb, 0x43, 0x1a, 0x0d, return out, nil
0x0d, 0xca, 0x93, 0x87, 0xfa, 0xe9, 0x21, 0x8d, 0xe6, 0xa0, 0x06, 0x15, 0x4a, 0x72, 0xf0, 0x06, }
0x2a, 0xb4, 0x61, 0x87, 0x27, 0x33, 0x32, 0xf6, 0x25, 0xc3, 0xdd, 0x32, 0x14, 0x65, 0xc8, 0x07,
0x2b, 0xbb, 0xb2, 0xa3, 0x15, 0xbf, 0x30, 0xb6, 0x9f, 0x72, 0xaf, 0x90, 0xbd, 0x31, 0xee, 0x85, func (c *replicationClient) DestroySnapshots(ctx context.Context, in *DestroySnapshotsReq, opts ...grpc.CallOption) (*DestroySnapshotsRes, error) {
0x34, 0xf8, 0x10, 0xe8, 0x38, 0x9f, 0xdb, 0x2b, 0x88, 0x2c, 0xd2, 0x75, 0xa1, 0x1a, 0xbb, 0xf5, out := new(DestroySnapshotsRes)
0xad, 0x4a, 0x7f, 0x1e, 0x87, 0xff, 0x02, 0x00, 0x00, 0xff, 0xff, 0x66, 0x74, 0x36, 0x3a, 0x49, err := c.cc.Invoke(ctx, "/Replication/DestroySnapshots", in, out, opts...)
0x06, 0x00, 0x00, if err != nil {
return nil, err
}
return out, nil
}
func (c *replicationClient) ReplicationCursor(ctx context.Context, in *ReplicationCursorReq, opts ...grpc.CallOption) (*ReplicationCursorRes, error) {
out := new(ReplicationCursorRes)
err := c.cc.Invoke(ctx, "/Replication/ReplicationCursor", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// ReplicationServer is the server API for Replication service.
type ReplicationServer interface {
ListFilesystems(context.Context, *ListFilesystemReq) (*ListFilesystemRes, error)
ListFilesystemVersions(context.Context, *ListFilesystemVersionsReq) (*ListFilesystemVersionsRes, error)
DestroySnapshots(context.Context, *DestroySnapshotsReq) (*DestroySnapshotsRes, error)
ReplicationCursor(context.Context, *ReplicationCursorReq) (*ReplicationCursorRes, error)
}
func RegisterReplicationServer(s *grpc.Server, srv ReplicationServer) {
s.RegisterService(&_Replication_serviceDesc, srv)
}
func _Replication_ListFilesystems_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(ListFilesystemReq)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ReplicationServer).ListFilesystems(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/Replication/ListFilesystems",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ReplicationServer).ListFilesystems(ctx, req.(*ListFilesystemReq))
}
return interceptor(ctx, in, info, handler)
}
func _Replication_ListFilesystemVersions_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(ListFilesystemVersionsReq)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ReplicationServer).ListFilesystemVersions(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/Replication/ListFilesystemVersions",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ReplicationServer).ListFilesystemVersions(ctx, req.(*ListFilesystemVersionsReq))
}
return interceptor(ctx, in, info, handler)
}
func _Replication_DestroySnapshots_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(DestroySnapshotsReq)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ReplicationServer).DestroySnapshots(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/Replication/DestroySnapshots",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ReplicationServer).DestroySnapshots(ctx, req.(*DestroySnapshotsReq))
}
return interceptor(ctx, in, info, handler)
}
func _Replication_ReplicationCursor_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(ReplicationCursorReq)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ReplicationServer).ReplicationCursor(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/Replication/ReplicationCursor",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ReplicationServer).ReplicationCursor(ctx, req.(*ReplicationCursorReq))
}
return interceptor(ctx, in, info, handler)
}
var _Replication_serviceDesc = grpc.ServiceDesc{
ServiceName: "Replication",
HandlerType: (*ReplicationServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "ListFilesystems",
Handler: _Replication_ListFilesystems_Handler,
},
{
MethodName: "ListFilesystemVersions",
Handler: _Replication_ListFilesystemVersions_Handler,
},
{
MethodName: "DestroySnapshots",
Handler: _Replication_DestroySnapshots_Handler,
},
{
MethodName: "ReplicationCursor",
Handler: _Replication_ReplicationCursor_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "pdu.proto",
}
func init() { proto.RegisterFile("pdu.proto", fileDescriptor_pdu_89315d819a6e0938) }
var fileDescriptor_pdu_89315d819a6e0938 = []byte{
// 735 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x55, 0xdd, 0x6e, 0xda, 0x4a,
0x10, 0xc6, 0x60, 0xc0, 0x0c, 0x51, 0x42, 0x36, 0x9c, 0xc8, 0xc7, 0xe7, 0x28, 0x42, 0xdb, 0x1b,
0x52, 0xa9, 0x6e, 0x45, 0x7b, 0x53, 0x55, 0xaa, 0x54, 0x42, 0x7e, 0xa4, 0x56, 0x69, 0xb4, 0xd0,
0x28, 0xca, 0x1d, 0x0d, 0xa3, 0xc4, 0x0a, 0xb0, 0xce, 0xee, 0xba, 0x0a, 0xbd, 0xec, 0x7b, 0xf4,
0x41, 0xfa, 0x0e, 0xbd, 0xec, 0x03, 0x55, 0xbb, 0x60, 0xe3, 0x60, 0x23, 0x71, 0xe5, 0xfd, 0xbe,
0x9d, 0x9d, 0x9d, 0xf9, 0x76, 0x66, 0x0c, 0xb5, 0x70, 0x14, 0xf9, 0xa1, 0xe0, 0x8a, 0xd3, 0x3d,
0xd8, 0xfd, 0x14, 0x48, 0x75, 0x12, 0x8c, 0x51, 0xce, 0xa4, 0xc2, 0x09, 0xc3, 0x07, 0x7a, 0x95,
0x25, 0x25, 0x79, 0x01, 0xf5, 0x25, 0x21, 0x5d, 0xab, 0x55, 0x6a, 0xd7, 0x3b, 0x75, 0x3f, 0x65,
0x94, 0xde, 0x27, 0x4d, 0x28, 0x1f, 0x4f, 0x42, 0x35, 0x73, 0x8b, 0x2d, 0xab, 0xed, 0xb0, 0x39,
0xa0, 0x5d, 0x80, 0xa5, 0x11, 0x21, 0x60, 0x5f, 0x0c, 0xd5, 0x9d, 0x6b, 0xb5, 0xac, 0x76, 0x8d,
0x99, 0x35, 0x69, 0x41, 0x9d, 0xa1, 0x8c, 0x26, 0x38, 0xe0, 0xf7, 0x38, 0x35, 0xa7, 0x6b, 0x2c,
0x4d, 0xd1, 0x77, 0xf0, 0xef, 0xd3, 0xe8, 0x2e, 0x51, 0xc8, 0x80, 0x4f, 0x25, 0xc3, 0x07, 0x72,
0x90, 0xbe, 0x60, 0xe1, 0x38, 0xc5, 0xd0, 0x8f, 0xeb, 0x0f, 0x4b, 0xe2, 0x83, 0x13, 0xc3, 0x45,
0x7e, 0xc4, 0xcf, 0x58, 0xb2, 0xc4, 0x86, 0xfe, 0xb1, 0x60, 0x37, 0xb3, 0x4f, 0x3a, 0x60, 0x0f,
0x66, 0x21, 0x9a, 0xcb, 0xb7, 0x3b, 0x07, 0x59, 0x0f, 0xfe, 0xe2, 0xab, 0xad, 0x98, 0xb1, 0xd5,
0x4a, 0x9c, 0x0f, 0x27, 0xb8, 0x48, 0xd7, 0xac, 0x35, 0x77, 0x1a, 0x05, 0x23, 0xb7, 0xd4, 0xb2,
0xda, 0x36, 0x33, 0x6b, 0xf2, 0x3f, 0xd4, 0x8e, 0x04, 0x0e, 0x15, 0x0e, 0xae, 0x4e, 0x5d, 0xdb,
0x6c, 0x2c, 0x09, 0xe2, 0x81, 0x63, 0x40, 0xc0, 0xa7, 0x6e, 0xd9, 0x78, 0x4a, 0x30, 0x3d, 0x84,
0x7a, 0xea, 0x5a, 0xb2, 0x05, 0x4e, 0x7f, 0x3a, 0x0c, 0xe5, 0x1d, 0x57, 0x8d, 0x82, 0x46, 0x5d,
0xce, 0xef, 0x27, 0x43, 0x71, 0xdf, 0xb0, 0xe8, 0x2f, 0x0b, 0xaa, 0x7d, 0x9c, 0x8e, 0x36, 0xd0,
0x53, 0x07, 0x79, 0x22, 0xf8, 0x24, 0x0e, 0x5c, 0xaf, 0xc9, 0x36, 0x14, 0x07, 0xdc, 0x84, 0x5d,
0x63, 0xc5, 0x01, 0x5f, 0x7d, 0x52, 0x3b, 0xf3, 0xa4, 0x26, 0x70, 0x3e, 0x09, 0x05, 0x4a, 0x69,
0x02, 0x77, 0x58, 0x82, 0x75, 0x21, 0xf5, 0x70, 0x14, 0x85, 0x6e, 0x65, 0x5e, 0x48, 0x06, 0x90,
0x7d, 0xa8, 0xf4, 0xc4, 0x8c, 0x45, 0x53, 0xb7, 0x6a, 0xe8, 0x05, 0xa2, 0x6f, 0xc0, 0xb9, 0x10,
0x3c, 0x44, 0xa1, 0x66, 0x89, 0xa8, 0x56, 0x4a, 0xd4, 0x26, 0x94, 0x2f, 0x87, 0xe3, 0x28, 0x56,
0x7a, 0x0e, 0xe8, 0x8f, 0x24, 0x63, 0x49, 0xda, 0xb0, 0xf3, 0x45, 0xe2, 0x68, 0xb5, 0x08, 0x1d,
0xb6, 0x4a, 0x13, 0x0a, 0x5b, 0xc7, 0x8f, 0x21, 0xde, 0x28, 0x1c, 0xf5, 0x83, 0xef, 0x68, 0x32,
0x2e, 0xb1, 0x27, 0x1c, 0x39, 0x04, 0x58, 0xc4, 0x13, 0xa0, 0x74, 0x6d, 0x53, 0x54, 0x35, 0x3f,
0x0e, 0x91, 0xa5, 0x36, 0xe9, 0x15, 0x00, 0xc3, 0x1b, 0x0c, 0xbe, 0xe1, 0x26, 0xc2, 0x3f, 0x87,
0xc6, 0xd1, 0x18, 0x87, 0x22, 0x1b, 0x67, 0x86, 0xa7, 0x5b, 0x29, 0xcf, 0x92, 0xde, 0xc2, 0x5e,
0x0f, 0xa5, 0x12, 0x7c, 0x16, 0x57, 0xc0, 0x26, 0x9d, 0x43, 0x5e, 0x41, 0x2d, 0xb1, 0x77, 0x8b,
0x6b, 0xbb, 0x63, 0x69, 0x44, 0xaf, 0x81, 0xac, 0x5c, 0xb4, 0x68, 0xb2, 0x18, 0x9a, 0x5b, 0xd6,
0x34, 0x59, 0x6c, 0x63, 0x06, 0x89, 0x10, 0x5c, 0xc4, 0x2f, 0x66, 0x00, 0xed, 0xe5, 0x25, 0xa1,
0x87, 0x54, 0x55, 0x27, 0x3e, 0x56, 0x71, 0x03, 0xef, 0xf9, 0xd9, 0x10, 0x58, 0x6c, 0x43, 0x7f,
0x5b, 0xd0, 0x64, 0x18, 0x8e, 0x83, 0x1b, 0xd3, 0x24, 0x47, 0x91, 0x90, 0x5c, 0x6c, 0x22, 0xc6,
0x4b, 0x28, 0xdd, 0xa2, 0x32, 0x21, 0xd5, 0x3b, 0xff, 0xf9, 0x79, 0x3e, 0xfc, 0x53, 0x54, 0x9f,
0xc3, 0xb3, 0x02, 0xd3, 0x96, 0xfa, 0x80, 0x44, 0x65, 0x4a, 0x64, 0xed, 0x81, 0x7e, 0x7c, 0x40,
0xa2, 0xf2, 0xaa, 0x50, 0x36, 0x0e, 0xbc, 0x67, 0x50, 0x36, 0x1b, 0xba, 0x49, 0x12, 0xe1, 0xe6,
0x5a, 0x24, 0xb8, 0x6b, 0x43, 0x91, 0x87, 0x74, 0x90, 0x9b, 0x8d, 0x6e, 0xa1, 0xf9, 0x24, 0xd1,
0x79, 0xd8, 0x67, 0x85, 0x64, 0x96, 0x38, 0xe7, 0x5c, 0xe1, 0x63, 0x20, 0xe7, 0xfe, 0x9c, 0xb3,
0x02, 0x4b, 0x98, 0xae, 0x03, 0x95, 0xb9, 0x4a, 0x9d, 0x9f, 0x45, 0xdd, 0xbf, 0x89, 0x5b, 0xf2,
0x16, 0x76, 0x9e, 0x8e, 0x50, 0x49, 0x88, 0x9f, 0xf9, 0x89, 0x78, 0x59, 0x4e, 0x92, 0x0b, 0xd8,
0xcf, 0x9f, 0xbe, 0xc4, 0xf3, 0xd7, 0xce, 0x74, 0x6f, 0xfd, 0x9e, 0x24, 0xef, 0xa1, 0xb1, 0x5a,
0x07, 0xa4, 0xe9, 0xe7, 0xd4, 0xb7, 0x97, 0xc7, 0x4a, 0xf2, 0x01, 0x76, 0x33, 0x92, 0x91, 0x7f,
0x72, 0xdf, 0xc7, 0xcb, 0xa5, 0x65, 0xb7, 0x7c, 0x5d, 0x0a, 0x47, 0xd1, 0xd7, 0x8a, 0xf9, 0xa1,
0xbe, 0xfe, 0x1b, 0x00, 0x00, 0xff, 0xff, 0xa3, 0xba, 0x8e, 0x63, 0x5d, 0x07, 0x00, 0x00,
} }

View File

@ -1,11 +1,19 @@
syntax = "proto3"; syntax = "proto3";
option go_package = "pdu";
package pdu; service Replication {
rpc ListFilesystems (ListFilesystemReq) returns (ListFilesystemRes);
rpc ListFilesystemVersions (ListFilesystemVersionsReq) returns (ListFilesystemVersionsRes);
rpc DestroySnapshots (DestroySnapshotsReq) returns (DestroySnapshotsRes);
rpc ReplicationCursor (ReplicationCursorReq) returns (ReplicationCursorRes);
// for Send and Recv, see package rpc
}
message ListFilesystemReq {} message ListFilesystemReq {}
message ListFilesystemRes { message ListFilesystemRes {
repeated Filesystem Filesystems = 1; repeated Filesystem Filesystems = 1;
bool Empty = 2;
} }
message Filesystem { message Filesystem {
@ -60,22 +68,18 @@ message Property {
} }
message SendRes { message SendRes {
// The actual stream is in the stream part of the streamrpc response
// Whether the resume token provided in the request has been used or not. // Whether the resume token provided in the request has been used or not.
bool UsedResumeToken = 1; bool UsedResumeToken = 2;
// Expected stream size determined by dry run, not exact. // Expected stream size determined by dry run, not exact.
// 0 indicates that for the given SendReq, no size estimate could be made. // 0 indicates that for the given SendReq, no size estimate could be made.
int64 ExpectedSize = 2; int64 ExpectedSize = 3;
repeated Property Properties = 3; repeated Property Properties = 4;
} }
message ReceiveReq { message ReceiveReq {
// The stream part of the streamrpc request contains the zfs send stream string Filesystem = 1; // FIXME should be snapshot name, we can enforce that on recv
string Filesystem = 1;
// If true, the receiver should clear the resume token before perfoming the zfs recv of the stream in the request // If true, the receiver should clear the resume token before perfoming the zfs recv of the stream in the request
bool ClearResumeToken = 2; bool ClearResumeToken = 2;

View File

@ -0,0 +1,169 @@
package base2bufpool
import (
"fmt"
"math/bits"
"sync"
)
type pool struct {
mtx sync.Mutex
bufs [][]byte
shift uint
}
func (p *pool) Put(buf []byte) {
p.mtx.Lock()
defer p.mtx.Unlock()
if len(buf) != 1<<p.shift {
panic(fmt.Sprintf("implementation error: %v %v", len(buf), 1<<p.shift))
}
if len(p.bufs) > 10 { // FIXME constant
return
}
p.bufs = append(p.bufs, buf)
}
func (p *pool) Get() []byte {
p.mtx.Lock()
defer p.mtx.Unlock()
if len(p.bufs) > 0 {
ret := p.bufs[len(p.bufs)-1]
p.bufs = p.bufs[0 : len(p.bufs)-1]
return ret
}
return make([]byte, 1<<p.shift)
}
type Pool struct {
minShift uint
maxShift uint
pools []pool
onNoFit NoFitBehavior
}
type Buffer struct {
// always power of 2, from Pool.pools
shiftBuf []byte
// presentedLen
payloadLen uint
// backref too pool for Free
pool *Pool
}
func (b Buffer) Bytes() []byte {
return b.shiftBuf[0:b.payloadLen]
}
func (b *Buffer) Shrink(newPayloadLen uint) {
if newPayloadLen > b.payloadLen {
panic(fmt.Sprintf("shrink is actually an expand, invalid: %v %v", newPayloadLen, b.payloadLen))
}
b.payloadLen = newPayloadLen
}
func (b Buffer) Free() {
if b.pool != nil {
b.pool.put(b)
}
}
//go:generate enumer -type NoFitBehavior
type NoFitBehavior uint
const (
AllocateSmaller NoFitBehavior = 1 << iota
AllocateLarger
Allocate NoFitBehavior = AllocateSmaller | AllocateLarger
Panic NoFitBehavior = 0
)
func New(minShift, maxShift uint, noFitBehavior NoFitBehavior) *Pool {
if minShift > 63 || maxShift > 63 {
panic(fmt.Sprintf("{min|max}Shift are the _exponent_, got minShift=%v maxShift=%v and limit of 63, which amounts to %v bits", minShift, maxShift, uint64(1)<<63))
}
pools := make([]pool, maxShift-minShift+1)
for i := uint(0); i < uint(len(pools)); i++ {
i := i // the closure below must copy i
pools[i] = pool{
shift: minShift + i,
bufs: make([][]byte, 0, 10),
}
}
return &Pool{
minShift: minShift,
maxShift: maxShift,
pools: pools,
onNoFit: noFitBehavior,
}
}
func fittingShift(x uint) uint {
if x == 0 {
return 0
}
blen := uint(bits.Len(x))
if 1<<(blen-1) == x {
return blen - 1
}
return blen
}
func (p *Pool) handlePotentialNoFit(reqShift uint) (buf Buffer, didHandle bool) {
if reqShift == 0 {
if p.onNoFit&AllocateSmaller != 0 {
return Buffer{[]byte{}, 0, nil}, true
} else {
goto doPanic
}
}
if reqShift < p.minShift {
if p.onNoFit&AllocateSmaller != 0 {
goto alloc
} else {
goto doPanic
}
}
if reqShift > p.maxShift {
if p.onNoFit&AllocateLarger != 0 {
goto alloc
} else {
goto doPanic
}
}
return Buffer{}, false
alloc:
return Buffer{make([]byte, 1<<reqShift), 1 << reqShift, nil}, true
doPanic:
panic(fmt.Sprintf("base2bufpool: configured to panic on shift=%v (minShift=%v maxShift=%v)", reqShift, p.minShift, p.maxShift))
}
func (p *Pool) Get(minSize uint) Buffer {
shift := fittingShift(minSize)
buf, didHandle := p.handlePotentialNoFit(shift)
if didHandle {
buf.Shrink(minSize)
return buf
}
idx := int64(shift) - int64(p.minShift)
return Buffer{p.pools[idx].Get(), minSize, p}
}
func (p *Pool) put(buffer Buffer) {
if buffer.pool != p {
panic("putting buffer to pool where it didn't originate from")
}
buf := buffer.shiftBuf
if bits.OnesCount(uint(len(buf))) > 1 {
panic(fmt.Sprintf("putting buffer that is not power of two len: %v", len(buf)))
}
if len(buf) == 0 {
return
}
shift := fittingShift(uint(len(buf)))
if shift < p.minShift || shift > p.maxShift {
return // drop it
}
p.pools[shift-p.minShift].Put(buf)
}

View File

@ -0,0 +1,98 @@
package base2bufpool
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
func TestPoolAllocBehavior(t *testing.T) {
type testcase struct {
poolMinShift, poolMaxShift uint
behavior NoFitBehavior
get uint
expShiftBufLen int64 // -1 if panic expected
}
tcs := []testcase{
{
15, 20, Allocate,
1 << 14, 1 << 14,
},
{
15, 20, Allocate,
1 << 22, 1 << 22,
},
{
15, 20, Panic,
1 << 16, 1 << 16,
},
{
15, 20, Panic,
1 << 14, -1,
},
{
15, 20, Panic,
1 << 22, -1,
},
{
15, 20, Panic,
(1 << 15) + 23, 1 << 16,
},
{
15, 20, Panic,
0, -1, // yep, 0 always works, even
},
{
15, 20, Allocate,
0, 0,
},
{
15, 20, AllocateSmaller,
1 << 14, 1 << 14,
},
{
15, 20, AllocateSmaller,
1 << 22, -1,
},
}
for i := range tcs {
tc := tcs[i]
t.Run(fmt.Sprintf("[%d,%d] behav=%s Get(%d) exp=%d", tc.poolMinShift, tc.poolMaxShift, tc.behavior, tc.get, tc.expShiftBufLen), func(t *testing.T) {
pool := New(tc.poolMinShift, tc.poolMaxShift, tc.behavior)
if tc.expShiftBufLen == -1 {
assert.Panics(t, func() {
pool.Get(tc.get)
})
return
}
buf := pool.Get(tc.get)
assert.True(t, uint(len(buf.Bytes())) == tc.get)
assert.True(t, int64(len(buf.shiftBuf)) == tc.expShiftBufLen)
})
}
}
func TestFittingShift(t *testing.T) {
assert.Equal(t, uint(16), fittingShift(1+1<<15))
assert.Equal(t, uint(15), fittingShift(1<<15))
}
func TestFreeFromPoolRangeDoesNotPanic(t *testing.T) {
pool := New(15, 20, Allocate)
buf := pool.Get(1 << 16)
assert.NotPanics(t, func() {
buf.Free()
})
}
func TestFreeFromOutOfPoolRangeDoesNotPanic(t *testing.T) {
pool := New(15, 20, Allocate)
buf := pool.Get(1 << 23)
assert.NotPanics(t, func() {
buf.Free()
})
}

View File

@ -0,0 +1,51 @@
// Code generated by "enumer -type NoFitBehavior"; DO NOT EDIT.
package base2bufpool
import (
"fmt"
)
const _NoFitBehaviorName = "PanicAllocateSmallerAllocateLargerAllocate"
var _NoFitBehaviorIndex = [...]uint8{0, 5, 20, 34, 42}
func (i NoFitBehavior) String() string {
if i >= NoFitBehavior(len(_NoFitBehaviorIndex)-1) {
return fmt.Sprintf("NoFitBehavior(%d)", i)
}
return _NoFitBehaviorName[_NoFitBehaviorIndex[i]:_NoFitBehaviorIndex[i+1]]
}
var _NoFitBehaviorValues = []NoFitBehavior{0, 1, 2, 3}
var _NoFitBehaviorNameToValueMap = map[string]NoFitBehavior{
_NoFitBehaviorName[0:5]: 0,
_NoFitBehaviorName[5:20]: 1,
_NoFitBehaviorName[20:34]: 2,
_NoFitBehaviorName[34:42]: 3,
}
// NoFitBehaviorString retrieves an enum value from the enum constants string name.
// Throws an error if the param is not part of the enum.
func NoFitBehaviorString(s string) (NoFitBehavior, error) {
if val, ok := _NoFitBehaviorNameToValueMap[s]; ok {
return val, nil
}
return 0, fmt.Errorf("%s does not belong to NoFitBehavior values", s)
}
// NoFitBehaviorValues returns all values of the enum
func NoFitBehaviorValues() []NoFitBehavior {
return _NoFitBehaviorValues
}
// IsANoFitBehavior returns "true" if the value is listed in the enum definition. "false" otherwise
func (i NoFitBehavior) IsANoFitBehavior() bool {
for _, v := range _NoFitBehaviorValues {
if i == v {
return true
}
}
return false
}

View File

@ -0,0 +1,215 @@
package dataconn
import (
"bytes"
"context"
"fmt"
"strings"
"github.com/golang/protobuf/proto"
"github.com/zrepl/zrepl/replication/pdu"
"github.com/zrepl/zrepl/rpc/dataconn/stream"
"github.com/zrepl/zrepl/transport"
"github.com/zrepl/zrepl/zfs"
)
type Client struct {
log Logger
cn transport.Connecter
}
func NewClient(connecter transport.Connecter, log Logger) *Client {
return &Client{
log: log,
cn: connecter,
}
}
func (c *Client) send(ctx context.Context, conn *stream.Conn, endpoint string, req proto.Message, streamCopier zfs.StreamCopier) error {
var buf bytes.Buffer
_, memErr := buf.WriteString(endpoint)
if memErr != nil {
panic(memErr)
}
if err := conn.WriteStreamedMessage(ctx, &buf, ReqHeader); err != nil {
return err
}
protobufBytes, err := proto.Marshal(req)
if err != nil {
return err
}
protobuf := bytes.NewBuffer(protobufBytes)
if err := conn.WriteStreamedMessage(ctx, protobuf, ReqStructured); err != nil {
return err
}
if streamCopier != nil {
return conn.SendStream(ctx, streamCopier, ZFSStream)
} else {
return nil
}
}
type RemoteHandlerError struct {
msg string
}
func (e *RemoteHandlerError) Error() string {
return fmt.Sprintf("server error: %s", e.msg)
}
type ProtocolError struct {
cause error
}
func (e *ProtocolError) Error() string {
return fmt.Sprintf("protocol error: %s", e)
}
func (c *Client) recv(ctx context.Context, conn *stream.Conn, res proto.Message) error {
headerBuf, err := conn.ReadStreamedMessage(ctx, ResponseHeaderMaxSize, ResHeader)
if err != nil {
return err
}
header := string(headerBuf)
if strings.HasPrefix(header, responseHeaderHandlerErrorPrefix) {
// FIXME distinguishable error type
return &RemoteHandlerError{strings.TrimPrefix(header, responseHeaderHandlerErrorPrefix)}
}
if !strings.HasPrefix(header, responseHeaderHandlerOk) {
return &ProtocolError{fmt.Errorf("invalid header: %q", header)}
}
protobuf, err := conn.ReadStreamedMessage(ctx, ResponseStructuredMaxSize, ResStructured)
if err != nil {
return err
}
if err := proto.Unmarshal(protobuf, res); err != nil {
return &ProtocolError{fmt.Errorf("cannot unmarshal structured part of response: %s", err)}
}
return nil
}
func (c *Client) getWire(ctx context.Context) (*stream.Conn, error) {
nc, err := c.cn.Connect(ctx)
if err != nil {
return nil, err
}
conn := stream.Wrap(nc, HeartbeatInterval, HeartbeatPeerTimeout)
return conn, nil
}
func (c *Client) putWire(conn *stream.Conn) {
if err := conn.Close(); err != nil {
c.log.WithError(err).Error("error closing connection")
}
}
func (c *Client) ReqSend(ctx context.Context, req *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) {
conn, err := c.getWire(ctx)
if err != nil {
return nil, nil, err
}
putWireOnReturn := true
defer func() {
if putWireOnReturn {
c.putWire(conn)
}
}()
if err := c.send(ctx, conn, EndpointSend, req, nil); err != nil {
return nil, nil, err
}
var res pdu.SendRes
if err := c.recv(ctx, conn, &res); err != nil {
return nil, nil, err
}
var copier zfs.StreamCopier = nil
if !req.DryRun {
putWireOnReturn = false
copier = &streamCopier{streamConn: conn, closeStreamOnClose: true}
}
return &res, copier, nil
}
func (c *Client) ReqRecv(ctx context.Context, req *pdu.ReceiveReq, streamCopier zfs.StreamCopier) (*pdu.ReceiveRes, error) {
defer c.log.Info("ReqRecv returns")
conn, err := c.getWire(ctx)
if err != nil {
return nil, err
}
// send and recv response concurrently to catch early exists of remote handler
// (e.g. disk full, permission error, etc)
type recvRes struct {
res *pdu.ReceiveRes
err error
}
recvErrChan := make(chan recvRes)
go func() {
res := &pdu.ReceiveRes{}
if err := c.recv(ctx, conn, res); err != nil {
recvErrChan <- recvRes{res, err}
} else {
recvErrChan <- recvRes{res, nil}
}
}()
sendErrChan := make(chan error)
go func() {
if err := c.send(ctx, conn, EndpointRecv, req, streamCopier); err != nil {
sendErrChan <- err
} else {
sendErrChan <- nil
}
}()
var res recvRes
var sendErr error
var cause error // one of the above
didTryClose := false
for i := 0; i < 2; i++ {
select {
case res = <-recvErrChan:
c.log.WithField("errType", fmt.Sprintf("%T", res.err)).WithError(res.err).Debug("recv goroutine returned")
if res.err != nil && cause == nil {
cause = res.err
}
case sendErr = <-sendErrChan:
c.log.WithField("errType", fmt.Sprintf("%T", sendErr)).WithError(sendErr).Debug("send goroutine returned")
if sendErr != nil && cause == nil {
cause = sendErr
}
}
if !didTryClose && (res.err != nil || sendErr != nil) {
didTryClose = true
if err := conn.Close(); err != nil {
c.log.WithError(err).Error("ReqRecv: cannot close connection, will likely block indefinitely")
}
c.log.WithError(err).Debug("ReqRecv: closed connection, should trigger other goroutine error")
}
}
if !didTryClose {
// didn't close it in above loop, so we can give it back
c.putWire(conn)
}
// if receive failed with a RemoteHandlerError, we know the transport was not broken
// => take the remote error as cause for the operation to fail
// TODO combine errors if send also failed
// (after all, send could have crashed on our side, rendering res.err a mere symptom of the cause)
if _, ok := res.err.(*RemoteHandlerError); ok {
cause = res.err
}
return res.res, cause
}

View File

@ -0,0 +1,20 @@
package dataconn
import (
"fmt"
"os"
)
var debugEnabled bool = false
func init() {
if os.Getenv("ZREPL_RPC_DATACONN_DEBUG") != "" {
debugEnabled = true
}
}
func debug(format string, args ...interface{}) {
if debugEnabled {
fmt.Fprintf(os.Stderr, "rpc/dataconn: %s\n", fmt.Sprintf(format, args...))
}
}

View File

@ -0,0 +1,178 @@
package dataconn
import (
"bytes"
"context"
"fmt"
"github.com/golang/protobuf/proto"
"github.com/zrepl/zrepl/logger"
"github.com/zrepl/zrepl/replication/pdu"
"github.com/zrepl/zrepl/rpc/dataconn/stream"
"github.com/zrepl/zrepl/transport"
"github.com/zrepl/zrepl/zfs"
)
// WireInterceptor has a chance to exchange the context and connection on each client connection.
type WireInterceptor func(ctx context.Context, rawConn *transport.AuthConn) (context.Context, *transport.AuthConn)
// Handler implements the functionality that is exposed by Server to the Client.
type Handler interface {
// Send handles a SendRequest.
// The returned io.ReadCloser is allowed to be nil, for example if the requested Send is a dry-run.
Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error)
// Receive handles a ReceiveRequest.
// It is guaranteed that Server calls Receive with a stream that holds the IdleConnTimeout
// configured in ServerConfig.Shared.IdleConnTimeout.
Receive(ctx context.Context, r *pdu.ReceiveReq, receive zfs.StreamCopier) (*pdu.ReceiveRes, error)
}
type Logger = logger.Logger
type Server struct {
h Handler
wi WireInterceptor
log Logger
}
func NewServer(wi WireInterceptor, logger Logger, handler Handler) *Server {
return &Server{
h: handler,
wi: wi,
log: logger,
}
}
// Serve consumes the listener, closes it as soon as ctx is closed.
// No accept errors are returned: they are logged to the Logger passed
// to the constructor.
func (s *Server) Serve(ctx context.Context, l transport.AuthenticatedListener) {
go func() {
<-ctx.Done()
s.log.Debug("context done")
if err := l.Close(); err != nil {
s.log.WithError(err).Error("cannot close listener")
}
}()
conns := make(chan *transport.AuthConn)
go func() {
for {
conn, err := l.Accept(ctx)
if err != nil {
if ctx.Done() != nil {
s.log.Debug("stop accepting after context is done")
return
}
s.log.WithError(err).Error("accept error")
continue
}
conns <- conn
}
}()
for conn := range conns {
go s.serveConn(conn)
}
}
func (s *Server) serveConn(nc *transport.AuthConn) {
s.log.Debug("serveConn begin")
defer s.log.Debug("serveConn done")
ctx := context.Background()
if s.wi != nil {
ctx, nc = s.wi(ctx, nc)
}
c := stream.Wrap(nc, HeartbeatInterval, HeartbeatPeerTimeout)
defer func() {
s.log.Debug("close client connection")
if err := c.Close(); err != nil {
s.log.WithError(err).Error("cannot close client connection")
}
}()
header, err := c.ReadStreamedMessage(ctx, RequestHeaderMaxSize, ReqHeader)
if err != nil {
s.log.WithError(err).Error("error reading structured part")
return
}
endpoint := string(header)
reqStructured, err := c.ReadStreamedMessage(ctx, RequestStructuredMaxSize, ReqStructured)
if err != nil {
s.log.WithError(err).Error("error reading structured part")
return
}
s.log.WithField("endpoint", endpoint).Debug("calling handler")
var res proto.Message
var sendStream zfs.StreamCopier
var handlerErr error
switch endpoint {
case EndpointSend:
var req pdu.SendReq
if err := proto.Unmarshal(reqStructured, &req); err != nil {
s.log.WithError(err).Error("cannot unmarshal send request")
return
}
res, sendStream, handlerErr = s.h.Send(ctx, &req) // SHADOWING
case EndpointRecv:
var req pdu.ReceiveReq
if err := proto.Unmarshal(reqStructured, &req); err != nil {
s.log.WithError(err).Error("cannot unmarshal receive request")
return
}
res, handlerErr = s.h.Receive(ctx, &req, &streamCopier{streamConn: c, closeStreamOnClose: false}) // SHADOWING
default:
s.log.WithField("endpoint", endpoint).Error("unknown endpoint")
handlerErr = fmt.Errorf("requested endpoint does not exist")
return
}
s.log.WithField("endpoint", endpoint).WithField("errType", fmt.Sprintf("%T", handlerErr)).Debug("handler returned")
// prepare protobuf now to return the protobuf error in the header
// if marshaling fails. We consider failed marshaling a handler error
var protobuf *bytes.Buffer
if handlerErr == nil {
protobufBytes, err := proto.Marshal(res)
if err != nil {
s.log.WithError(err).Error("cannot marshal handler protobuf")
handlerErr = err
}
protobuf = bytes.NewBuffer(protobufBytes) // SHADOWING
}
var resHeaderBuf bytes.Buffer
if handlerErr == nil {
resHeaderBuf.WriteString(responseHeaderHandlerOk)
} else {
resHeaderBuf.WriteString(responseHeaderHandlerErrorPrefix)
resHeaderBuf.WriteString(handlerErr.Error())
}
if err := c.WriteStreamedMessage(ctx, &resHeaderBuf, ResHeader); err != nil {
s.log.WithError(err).Error("cannot write response header")
return
}
if handlerErr != nil {
s.log.Debug("early exit after handler error")
return
}
if err := c.WriteStreamedMessage(ctx, protobuf, ResStructured); err != nil {
s.log.WithError(err).Error("cannot write structured part of response")
return
}
if sendStream != nil {
err := c.SendStream(ctx, sendStream, ZFSStream)
if err != nil {
s.log.WithError(err).Error("cannot write send stream")
}
}
return
}

View File

@ -0,0 +1,70 @@
package dataconn
import (
"io"
"sync"
"time"
"github.com/zrepl/zrepl/rpc/dataconn/stream"
"github.com/zrepl/zrepl/zfs"
)
const (
EndpointSend string = "/v1/send"
EndpointRecv string = "/v1/recv"
)
const (
ReqHeader uint32 = 1 + iota
ReqStructured
ResHeader
ResStructured
ZFSStream
)
// Note that changing theses constants may break interop with other clients
// Aggressive with timing, conservative (future compatible) with buffer sizes
const (
HeartbeatInterval = 5 * time.Second
HeartbeatPeerTimeout = 10 * time.Second
RequestHeaderMaxSize = 1 << 15
RequestStructuredMaxSize = 1 << 22
ResponseHeaderMaxSize = 1 << 15
ResponseStructuredMaxSize = 1 << 23
)
// the following are protocol constants
const (
responseHeaderHandlerOk = "HANDLER OK\n"
responseHeaderHandlerErrorPrefix = "HANDLER ERROR:\n"
)
type streamCopier struct {
mtx sync.Mutex
used bool
streamConn *stream.Conn
closeStreamOnClose bool
}
// WriteStreamTo implements zfs.StreamCopier
func (s *streamCopier) WriteStreamTo(w io.Writer) zfs.StreamCopierError {
s.mtx.Lock()
defer s.mtx.Unlock()
if s.used {
panic("streamCopier used mulitple times")
}
s.used = true
return s.streamConn.ReadStreamInto(w, ZFSStream)
}
// Close implements zfs.StreamCopier
func (s *streamCopier) Close() error {
// only record the close here, what we do actually depends on whether
// the streamCopier is instantiated server-side or client-side
s.mtx.Lock()
defer s.mtx.Unlock()
if s.closeStreamOnClose {
return s.streamConn.Close()
}
return nil
}

View File

@ -0,0 +1 @@
package dataconn

View File

@ -0,0 +1,346 @@
package frameconn
import (
"bufio"
"encoding/binary"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"sync"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/zrepl/zrepl/rpc/dataconn/base2bufpool"
"github.com/zrepl/zrepl/rpc/dataconn/timeoutconn"
)
type FrameHeader struct {
Type uint32
PayloadLen uint32
}
// The 4 MSBs of ft are reserved for frameconn.
func IsPublicFrameType(ft uint32) bool {
return (0xf<<28)&ft == 0
}
const (
rstFrameType uint32 = 1<<28 + iota
)
func assertPublicFrameType(frameType uint32) {
if !IsPublicFrameType(frameType) {
panic(fmt.Sprintf("frameconn: frame type %v cannot be used by consumers of this package", frameType))
}
}
func (f *FrameHeader) Unmarshal(buf []byte) {
if len(buf) != 8 {
panic(fmt.Sprintf("frame header is 8 bytes long"))
}
f.Type = binary.BigEndian.Uint32(buf[0:4])
f.PayloadLen = binary.BigEndian.Uint32(buf[4:8])
}
type Conn struct {
readMtx, writeMtx sync.Mutex
nc timeoutconn.Conn
ncBuf *bufio.ReadWriter
readNextValid bool
readNext FrameHeader
nextReadErr error
bufPool *base2bufpool.Pool // no need for sync around it
shutdown shutdownFSM
}
func Wrap(nc timeoutconn.Conn) *Conn {
return &Conn{
nc: nc,
// ncBuf: bufio.NewReadWriter(bufio.NewReaderSize(nc, 1<<23), bufio.NewWriterSize(nc, 1<<23)),
bufPool: base2bufpool.New(15, 22, base2bufpool.Allocate), // FIXME switch to Panic, but need to enforce the limits in recv for that. => need frameconn config
readNext: FrameHeader{},
readNextValid: false,
}
}
var ErrReadFrameLengthShort = errors.New("read frame length too short")
var ErrFixedFrameLengthMismatch = errors.New("read frame length mismatch")
type Buffer struct {
bufpoolBuffer base2bufpool.Buffer
payloadLen uint32
}
func (b *Buffer) Free() {
b.bufpoolBuffer.Free()
}
func (b *Buffer) Bytes() []byte {
return b.bufpoolBuffer.Bytes()[0:b.payloadLen]
}
type Frame struct {
Header FrameHeader
Buffer Buffer
}
var ErrShutdown = fmt.Errorf("frameconn: shutting down")
// ReadFrame reads a frame from the connection.
//
// Due to an internal optimization (Readv, specifically), it is not guaranteed that a single call to
// WriteFrame unblocks a pending ReadFrame on an otherwise idle (empty) connection.
// The only way to guarantee that all previously written frames can reach the peer's layers on top
// of frameconn is to send an empty frame (no payload) and to ignore empty frames on the receiving side.
func (c *Conn) ReadFrame() (Frame, error) {
if c.shutdown.IsShuttingDown() {
return Frame{}, ErrShutdown
}
// only aquire readMtx now to prioritize the draining in Shutdown()
// over external callers (= drain public callers)
c.readMtx.Lock()
defer c.readMtx.Unlock()
f, err := c.readFrame()
if f.Header.Type == rstFrameType {
c.shutdown.Begin()
return Frame{}, ErrShutdown
}
return f, err
}
// callers must have readMtx locked
func (c *Conn) readFrame() (Frame, error) {
if c.nextReadErr != nil {
ret := c.nextReadErr
c.nextReadErr = nil
return Frame{}, ret
}
if !c.readNextValid {
var buf [8]byte
if _, err := io.ReadFull(c.nc, buf[:]); err != nil {
return Frame{}, err
}
c.readNext.Unmarshal(buf[:])
c.readNextValid = true
}
// read payload + next header
var nextHdrBuf [8]byte
buffer := c.bufPool.Get(uint(c.readNext.PayloadLen))
bufferBytes := buffer.Bytes()
if c.readNext.PayloadLen == 0 {
// This if statement implements the unlock-by-sending-empty-frame behavior
// documented in ReadFrame's public docs.
//
// It is crucial that we return this empty frame now:
// Consider the following plot with x-axis being time,
// P being a frame with payload, E one without, X either of P or E
//
// P P P P P P P E.....................X
// | | | |
// | | | F3
// | | |
// | F2 |signficant time between frames because
// F1 the peer has nothing to say to us
//
// Assume we're at the point were F2's header is in c.readNext.
// That means F2 has not yet been returned.
// But because it is empty (no payload), we're already done reading it.
// If we omitted this if statement, the following would happen:
// Readv below would read [][]byte{[len(0)], [len(8)]).
c.readNextValid = false
frame := Frame{
Header: c.readNext,
Buffer: Buffer{
bufpoolBuffer: buffer,
payloadLen: c.readNext.PayloadLen, // 0
},
}
return frame, nil
}
noNextHeader := false
if n, err := c.nc.ReadvFull([][]byte{bufferBytes, nextHdrBuf[:]}); err != nil {
noNextHeader = true
zeroPayloadAndPeerClosed := n == 0 && c.readNext.PayloadLen == 0 && err == io.EOF
zeroPayloadAndNextFrameHeaderThenPeerClosed := err == io.EOF && c.readNext.PayloadLen == 0 && n == int64(len(nextHdrBuf))
nonzeroPayloadRecvdButNextHeaderMissing := n > 0 && uint32(n) == c.readNext.PayloadLen
if zeroPayloadAndPeerClosed || zeroPayloadAndNextFrameHeaderThenPeerClosed || nonzeroPayloadRecvdButNextHeaderMissing {
// This is the last frame on the conn.
// Store the error to be returned on the next invocation of ReadFrame.
c.nextReadErr = err
// NORETURN, this frame is still valid
} else {
return Frame{}, err
}
}
frame := Frame{
Header: c.readNext,
Buffer: Buffer{
bufpoolBuffer: buffer,
payloadLen: c.readNext.PayloadLen,
},
}
if !noNextHeader {
c.readNext.Unmarshal(nextHdrBuf[:])
c.readNextValid = true
} else {
c.readNextValid = false
}
return frame, nil
}
func (c *Conn) WriteFrame(payload []byte, frameType uint32) error {
assertPublicFrameType(frameType)
if c.shutdown.IsShuttingDown() {
return ErrShutdown
}
c.writeMtx.Lock()
defer c.writeMtx.Unlock()
return c.writeFrame(payload, frameType)
}
func (c *Conn) writeFrame(payload []byte, frameType uint32) error {
var hdrBuf [8]byte
binary.BigEndian.PutUint32(hdrBuf[0:4], frameType)
binary.BigEndian.PutUint32(hdrBuf[4:8], uint32(len(payload)))
bufs := net.Buffers([][]byte{hdrBuf[:], payload})
if _, err := c.nc.WritevFull(bufs); err != nil {
return err
}
return nil
}
func (c *Conn) Shutdown(deadline time.Time) error {
// TCP connection teardown is a bit wonky if we are in a situation
// where there is still data in flight (DIF) to our side:
// If we just close the connection, our kernel will send RSTs
// in response to the DIF, and those RSTs may reach the client's
// kernel faster than the client app is able to pull the
// last bytes from its kernel TCP receive buffer.
//
// Therefore, we send a frame with type rstFrameType to indicate
// that the connection is to be closed immediately, and further
// use CloseWrite instead of Close.
// As per definition of the wire interface, CloseWrite guarantees
// delivery of the data in our kernel TCP send buffer.
// Therefore, the client always receives the RST frame.
//
// Now what are we going to do after that?
//
// 1. Naive Option: We just call Close() right after CloseWrite.
// This yields the same race condition as explained above (DIF, first
// paragraph): The situation just becomae a little more unlikely because
// our rstFrameType + CloseWrite dance gave the client a full RTT worth of
// time to read the data from its TCP recv buffer.
//
// 2. Correct Option: Drain the read side until io.EOF
// We can read from the unclosed read-side of the connection until we get
// the io.EOF caused by the (well behaved) client closing the connection
// in response to it reading the rstFrameType frame we sent.
// However, this wastes resources on our side (we don't care about the
// pending data anymore), and has potential for (D)DoS through CPU-time
// exhaustion if the client just keeps sending data.
// Then again, this option has the advantage with well-behaved clients
// that we do not waste precious kernel-memory on the stale receive buffer
// on our side (which is still full of data that we do not intend to read).
//
// 2.1 DoS Mitigation: Bound the number of bytes to drain, then close
// At the time of writing, this technique is practiced by the Go http server
// implementation, and actually SHOULDed in the HTTP 1.1 RFC. It is
// important to disable the idle timeout of the underlying timeoutconn in
// that case and set an absolute deadline by which the socket must have
// been fully drained. Not too hard, though ;)
//
// 2.2: Client sends RST, not FIN when it receives an rstFrameTyp frame.
// We can use wire.(*net.TCPConn).SetLinger(0) to force an RST to be sent
// on a subsequent close (instead of a FIN + wait for FIN+ACK).
// TODO put this into Wire interface as an abstract method.
//
// 2.3 Only start draining after N*RTT
// We have an RTT approximation from Wire.CloseWrite, which by definition
// must not return before all to-be-sent-data has been acknowledged by the
// client. Give the client a fair chance to react, and only start draining
// after a multiple of the RTT has elapsed.
// We waste the recv buffer memory a little longer than necessary, iff the
// client reacts faster than expected. But we don't wast CPU time.
// If we apply 2.2, we'll also have the benefit that our kernel will have
// dropped the recv buffer memory as soon as it receives the client's RST.
//
// 3. TCP-only: OOB-messaging
// We can use TCP's 'urgent' flag in the client to acknowledge the receipt
// of the rstFrameType to us.
// We can thus wait for that signal while leaving the kernel buffer as is.
// TODO: For now, we just drain the connection (Option 2),
// but we enforce deadlines so the _time_ we drain the connection
// is bounded, although we do _that_ at full speed
defer prometheus.NewTimer(prom.ShutdownSeconds).ObserveDuration()
closeWire := func(step string) error {
// TODO SetLinger(0) or similiar (we want RST frames here, not FINS)
if closeErr := c.nc.Close(); closeErr != nil {
prom.ShutdownCloseErrors.WithLabelValues("close").Inc()
return closeErr
}
return nil
}
hardclose := func(err error, step string) error {
prom.ShutdownHardCloses.WithLabelValues(step).Inc()
return closeWire(step)
}
c.shutdown.Begin()
// new calls to c.ReadFrame and c.WriteFrame will now return ErrShutdown
// Aquiring writeMtx and readMtx ensures that the last calls exit successfully
// disable renewing timeouts now, enforce the requested deadline instead
// we need to do this before aquiring locks to enforce the timeout on slow
// clients / if something hangs (DoS mitigation)
if err := c.nc.DisableTimeouts(); err != nil {
return hardclose(err, "disable_timeouts")
}
if err := c.nc.SetDeadline(deadline); err != nil {
return hardclose(err, "set_deadline")
}
c.writeMtx.Lock()
defer c.writeMtx.Unlock()
if err := c.writeFrame([]byte{}, rstFrameType); err != nil {
return hardclose(err, "write_frame")
}
if err := c.nc.CloseWrite(); err != nil {
return hardclose(err, "close_write")
}
c.readMtx.Lock()
defer c.readMtx.Unlock()
// TODO DoS mitigation: wait for client acknowledgement that they initiated Shutdown,
// then perform abortive close on our side. As explained above, probably requires
// OOB signaling such as TCP's urgent flag => transport-specific?
// TODO DoS mitigation by reading limited number of bytes
// see discussion above why this is non-trivial
defer prometheus.NewTimer(prom.ShutdownDrainSeconds).ObserveDuration()
n, _ := io.Copy(ioutil.Discard, c.nc)
prom.ShutdownDrainBytesRead.Observe(float64(n))
return closeWire("close")
}

View File

@ -0,0 +1,63 @@
package frameconn
import "github.com/prometheus/client_golang/prometheus"
var prom struct {
ShutdownDrainBytesRead prometheus.Summary
ShutdownSeconds prometheus.Summary
ShutdownDrainSeconds prometheus.Summary
ShutdownHardCloses *prometheus.CounterVec
ShutdownCloseErrors *prometheus.CounterVec
}
func init() {
prom.ShutdownDrainBytesRead = prometheus.NewSummary(prometheus.SummaryOpts{
Namespace: "zrepl",
Subsystem: "frameconn",
Name: "shutdown_drain_bytes_read",
Help: "Number of bytes read during the drain phase of connection shutdown",
})
prom.ShutdownSeconds = prometheus.NewSummary(prometheus.SummaryOpts{
Namespace: "zrepl",
Subsystem: "frameconn",
Name: "shutdown_seconds",
Help: "Seconds it took for connection shutdown to complete",
})
prom.ShutdownDrainSeconds = prometheus.NewSummary(prometheus.SummaryOpts{
Namespace: "zrepl",
Subsystem: "frameconn",
Name: "shutdown_drain_seconds",
Help: "Seconds it took from read-side-drain until shutdown completion",
})
prom.ShutdownHardCloses = prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: "zrepl",
Subsystem: "frameconn",
Name: "shutdown_hard_closes",
Help: "Number of hard connection closes during shutdown (abortive close)",
}, []string{"step"})
prom.ShutdownCloseErrors = prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: "zrepl",
Subsystem: "frameconn",
Name: "shutdown_close_errors",
Help: "Number of errors closing the underlying network connection. Should alert on this",
}, []string{"step"})
}
func PrometheusRegister(registry prometheus.Registerer) error {
if err := registry.Register(prom.ShutdownDrainBytesRead); err != nil {
return err
}
if err := registry.Register(prom.ShutdownSeconds); err != nil {
return err
}
if err := registry.Register(prom.ShutdownDrainSeconds); err != nil {
return err
}
if err := registry.Register(prom.ShutdownHardCloses); err != nil {
return err
}
if err := registry.Register(prom.ShutdownCloseErrors); err != nil {
return err
}
return nil
}

View File

@ -0,0 +1,37 @@
package frameconn
import "sync"
type shutdownFSM struct {
mtx sync.Mutex
state shutdownFSMState
}
type shutdownFSMState uint32
const (
shutdownStateOpen shutdownFSMState = iota
shutdownStateBegin
)
func newShutdownFSM() *shutdownFSM {
fsm := &shutdownFSM{
state: shutdownStateOpen,
}
return fsm
}
func (f *shutdownFSM) Begin() (thisCallStartedShutdown bool) {
f.mtx.Lock()
defer f.mtx.Unlock()
thisCallStartedShutdown = f.state != shutdownStateOpen
f.state = shutdownStateBegin
return thisCallStartedShutdown
}
func (f *shutdownFSM) IsShuttingDown() bool {
f.mtx.Lock()
defer f.mtx.Unlock()
return f.state != shutdownStateOpen
}

View File

@ -0,0 +1,22 @@
package frameconn
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
func TestIsPublicFrameType(t *testing.T) {
for i := uint32(0); i < 256; i++ {
i := i
t.Run(fmt.Sprintf("^%d", i), func(t *testing.T) {
assert.False(t, IsPublicFrameType(^i))
})
}
assert.True(t, IsPublicFrameType(0))
assert.True(t, IsPublicFrameType(1))
assert.True(t, IsPublicFrameType(255))
assert.False(t, IsPublicFrameType(rstFrameType))
}

View File

@ -0,0 +1,137 @@
package heartbeatconn
import (
"fmt"
"net"
"sync/atomic"
"time"
"github.com/zrepl/zrepl/rpc/dataconn/frameconn"
"github.com/zrepl/zrepl/rpc/dataconn/timeoutconn"
)
type Conn struct {
state state
// if not nil, opErr is returned for ReadFrame and WriteFrame (not for Close, though)
opErr atomic.Value // error
fc *frameconn.Conn
sendInterval, timeout time.Duration
stopSend chan struct{}
lastFrameSent atomic.Value // time.Time
}
type HeartbeatTimeout struct{}
func (e HeartbeatTimeout) Error() string {
return "heartbeat timeout"
}
func (e HeartbeatTimeout) Temporary() bool { return true }
func (e HeartbeatTimeout) Timeout() bool { return true }
var _ net.Error = HeartbeatTimeout{}
type state = int32
const (
stateInitial state = 0
stateClosed state = 2
)
const (
heartbeat uint32 = 1 << 24
)
// The 4 MSBs of ft are reserved for frameconn, we reserve the next 4 MSB for us.
func IsPublicFrameType(ft uint32) bool {
return frameconn.IsPublicFrameType(ft) && (0xf<<24)&ft == 0
}
func assertPublicFrameType(frameType uint32) {
if !IsPublicFrameType(frameType) {
panic(fmt.Sprintf("heartbeatconn: frame type %v cannot be used by consumers of this package", frameType))
}
}
func Wrap(nc timeoutconn.Wire, sendInterval, timeout time.Duration) *Conn {
c := &Conn{
fc: frameconn.Wrap(timeoutconn.Wrap(nc, timeout)),
stopSend: make(chan struct{}),
sendInterval: sendInterval,
timeout: timeout,
}
c.lastFrameSent.Store(time.Now())
go c.sendHeartbeats()
return c
}
func (c *Conn) Shutdown() error {
normalClose := atomic.CompareAndSwapInt32(&c.state, stateInitial, stateClosed)
if normalClose {
close(c.stopSend)
}
return c.fc.Shutdown(time.Now().Add(c.timeout))
}
// started as a goroutine in constructor
func (c *Conn) sendHeartbeats() {
sleepTime := func(now time.Time) time.Duration {
lastSend := c.lastFrameSent.Load().(time.Time)
return lastSend.Add(c.sendInterval).Sub(now)
}
timer := time.NewTimer(sleepTime(time.Now()))
defer timer.Stop()
for {
select {
case <-c.stopSend:
return
case now := <-timer.C:
func() {
defer func() {
timer.Reset(sleepTime(time.Now()))
}()
if sleepTime(now) > 0 {
return
}
debug("send heartbeat")
// if the connection is in zombie mode (aka iptables DROP inbetween peers)
// this call or one of its successors will block after filling up the kernel tx buffer
c.fc.WriteFrame([]byte{}, heartbeat)
// ignore errors from WriteFrame to rate-limit SendHeartbeat retries
c.lastFrameSent.Store(time.Now())
}()
}
}
}
func (c *Conn) ReadFrame() (frameconn.Frame, error) {
return c.readFrameFiltered()
}
func (c *Conn) readFrameFiltered() (frameconn.Frame, error) {
for {
f, err := c.fc.ReadFrame()
if err != nil {
return frameconn.Frame{}, err
}
if IsPublicFrameType(f.Header.Type) {
return f, nil
}
if f.Header.Type != heartbeat {
return frameconn.Frame{}, fmt.Errorf("unknown frame type %x", f.Header.Type)
}
// drop heartbeat frame
debug("received heartbeat")
continue
}
}
func (c *Conn) WriteFrame(payload []byte, frameType uint32) error {
assertPublicFrameType(frameType)
err := c.fc.WriteFrame(payload, frameType)
if err == nil {
c.lastFrameSent.Store(time.Now())
}
return err
}

View File

@ -0,0 +1,20 @@
package heartbeatconn
import (
"fmt"
"os"
)
var debugEnabled bool = false
func init() {
if os.Getenv("ZREPL_RPC_DATACONN_HEARTBEATCONN_DEBUG") != "" {
debugEnabled = true
}
}
func debug(format string, args ...interface{}) {
if debugEnabled {
fmt.Fprintf(os.Stderr, "rpc/dataconn/heartbeatconn: %s\n", fmt.Sprintf(format, args...))
}
}

View File

@ -0,0 +1,26 @@
package heartbeatconn
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/zrepl/zrepl/rpc/dataconn/frameconn"
)
func TestFrameTypes(t *testing.T) {
assert.True(t, frameconn.IsPublicFrameType(heartbeat))
}
func TestNegativeTimer(t *testing.T) {
timer := time.NewTimer(-1 * time.Second)
defer timer.Stop()
time.Sleep(100 * time.Millisecond)
select {
case <-timer.C:
t.Log("timer with negative time fired, that's what we want")
default:
t.Fail()
}
}

View File

@ -0,0 +1,184 @@
// microbenchmark to manually test rpc/dataconn perforamnce
//
// With stdin / stdout on client and server, simulating zfs send|recv piping
//
// ./microbenchmark -appmode server | pv -r > /dev/null
// ./microbenchmark -appmode client -direction recv < /dev/zero
//
//
// Without the overhead of pipes (just protocol perforamnce, mostly useful with perf bc no bw measurement)
//
// ./microbenchmark -appmode client -direction recv -devnoopWriter -devnoopReader
// ./microbenchmark -appmode server -devnoopReader -devnoopWriter
//
package main
import (
"context"
"flag"
"fmt"
"io"
"net"
"os"
"github.com/pkg/profile"
"github.com/zrepl/zrepl/logger"
"github.com/zrepl/zrepl/replication/pdu"
"github.com/zrepl/zrepl/rpc/dataconn"
"github.com/zrepl/zrepl/rpc/dataconn/timeoutconn"
"github.com/zrepl/zrepl/transport"
"github.com/zrepl/zrepl/util/devnoop"
"github.com/zrepl/zrepl/zfs"
)
func orDie(err error) {
if err != nil {
panic(err)
}
}
type readerStreamCopier struct{ io.Reader }
func (readerStreamCopier) Close() error { return nil }
type readerStreamCopierErr struct {
error
}
func (readerStreamCopierErr) IsReadError() bool { return false }
func (readerStreamCopierErr) IsWriteError() bool { return true }
func (c readerStreamCopier) WriteStreamTo(w io.Writer) zfs.StreamCopierError {
var buf [1 << 21]byte
_, err := io.CopyBuffer(w, c.Reader, buf[:])
// always assume write error
return readerStreamCopierErr{err}
}
type devNullHandler struct{}
func (devNullHandler) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) {
var res pdu.SendRes
if args.devnoopReader {
return &res, readerStreamCopier{devnoop.Get()}, nil
} else {
return &res, readerStreamCopier{os.Stdin}, nil
}
}
func (devNullHandler) Receive(ctx context.Context, r *pdu.ReceiveReq, stream zfs.StreamCopier) (*pdu.ReceiveRes, error) {
var out io.Writer = os.Stdout
if args.devnoopWriter {
out = devnoop.Get()
}
err := stream.WriteStreamTo(out)
var res pdu.ReceiveRes
return &res, err
}
type tcpConnecter struct {
addr string
}
func (c tcpConnecter) Connect(ctx context.Context) (timeoutconn.Wire, error) {
conn, err := net.Dial("tcp", c.addr)
if err != nil {
return nil, err
}
return conn.(*net.TCPConn), nil
}
type tcpListener struct {
nl *net.TCPListener
clientIdent string
}
func (l tcpListener) Accept(ctx context.Context) (*transport.AuthConn, error) {
tcpconn, err := l.nl.AcceptTCP()
orDie(err)
return transport.NewAuthConn(tcpconn, l.clientIdent), nil
}
func (l tcpListener) Addr() net.Addr { return l.nl.Addr() }
func (l tcpListener) Close() error { return l.nl.Close() }
var args struct {
addr string
appmode string
direction string
profile bool
devnoopReader bool
devnoopWriter bool
}
func server() {
log := logger.NewStderrDebugLogger()
log.Debug("starting server")
nl, err := net.Listen("tcp", args.addr)
orDie(err)
l := tcpListener{nl.(*net.TCPListener), "fakeclientidentity"}
srv := dataconn.NewServer(nil, logger.NewStderrDebugLogger(), devNullHandler{})
ctx := context.Background()
srv.Serve(ctx, l)
}
func main() {
flag.BoolVar(&args.profile, "profile", false, "")
flag.BoolVar(&args.devnoopReader, "devnoopReader", false, "")
flag.BoolVar(&args.devnoopWriter, "devnoopWriter", false, "")
flag.StringVar(&args.addr, "address", ":8888", "")
flag.StringVar(&args.appmode, "appmode", "client|server", "")
flag.StringVar(&args.direction, "direction", "", "send|recv")
flag.Parse()
if args.profile {
defer profile.Start(profile.CPUProfile).Stop()
}
switch args.appmode {
case "client":
client()
case "server":
server()
default:
orDie(fmt.Errorf("unknown appmode %q", args.appmode))
}
}
func client() {
logger := logger.NewStderrDebugLogger()
ctx := context.Background()
connecter := tcpConnecter{args.addr}
client := dataconn.NewClient(connecter, logger)
switch args.direction {
case "send":
req := pdu.SendReq{}
_, stream, err := client.ReqSend(ctx, &req)
orDie(err)
err = stream.WriteStreamTo(os.Stdout)
orDie(err)
case "recv":
var r io.Reader = os.Stdin
if args.devnoopReader {
r = devnoop.Get()
}
s := readerStreamCopier{r}
req := pdu.ReceiveReq{}
_, err := client.ReqRecv(ctx, &req, &s)
orDie(err)
default:
orDie(fmt.Errorf("unknown direction%q", args.direction))
}
}

View File

@ -0,0 +1,269 @@
package stream
import (
"bytes"
"context"
"fmt"
"io"
"net"
"strings"
"unicode/utf8"
"github.com/zrepl/zrepl/logger"
"github.com/zrepl/zrepl/rpc/dataconn/base2bufpool"
"github.com/zrepl/zrepl/rpc/dataconn/frameconn"
"github.com/zrepl/zrepl/rpc/dataconn/heartbeatconn"
"github.com/zrepl/zrepl/zfs"
)
type Logger = logger.Logger
type contextKey int
const (
contextKeyLogger contextKey = 1 + iota
)
func WithLogger(ctx context.Context, log Logger) context.Context {
return context.WithValue(ctx, contextKeyLogger, log)
}
func getLog(ctx context.Context) Logger {
log, ok := ctx.Value(contextKeyLogger).(Logger)
if !ok {
log = logger.NewNullLogger()
}
return log
}
// Frame types used by this package.
// 4 MSBs are reserved for frameconn, next 4 MSB for heartbeatconn, next 4 MSB for us.
const (
StreamErrTrailer uint32 = 1 << (16 + iota)
End
// max 16
)
// NOTE: make sure to add a tests for each frame type that checks
// whether it is heartbeatconn.IsPublicFrameType()
// Check whether the given frame type is allowed to be used by
// consumers of this package. Intended for use in unit tests.
func IsPublicFrameType(ft uint32) bool {
return frameconn.IsPublicFrameType(ft) && heartbeatconn.IsPublicFrameType(ft) && ((0xf<<16)&ft == 0)
}
const FramePayloadShift = 19
var bufpool = base2bufpool.New(FramePayloadShift, FramePayloadShift, base2bufpool.Panic)
// if sendStream returns an error, that error will be sent as a trailer to the client
// ok will return nil, though.
func writeStream(ctx context.Context, c *heartbeatconn.Conn, stream io.Reader, stype uint32) (errStream, errConn error) {
debug("writeStream: enter stype=%v", stype)
defer debug("writeStream: return")
if stype == 0 {
panic("stype must be non-zero")
}
if !IsPublicFrameType(stype) {
panic(fmt.Sprintf("stype %v is not public", stype))
}
return doWriteStream(ctx, c, stream, stype)
}
func doWriteStream(ctx context.Context, c *heartbeatconn.Conn, stream io.Reader, stype uint32) (errStream, errConn error) {
// RULE1 (buf == <zero>) XOR (err == nil)
type read struct {
buf base2bufpool.Buffer
err error
}
reads := make(chan read, 5)
go func() {
for {
buffer := bufpool.Get(1 << FramePayloadShift)
bufferBytes := buffer.Bytes()
n, err := io.ReadFull(stream, bufferBytes)
buffer.Shrink(uint(n))
// if we received anything, send one read without an error (RULE 1)
if n > 0 {
reads <- read{buffer, nil}
}
if err == io.ErrUnexpectedEOF {
// happens iff io.ReadFull read io.EOF from stream
err = io.EOF
}
if err != nil {
reads <- read{err: err} // RULE1
close(reads)
return
}
}
}()
for read := range reads {
if read.err == nil {
// RULE 1: read.buf is valid
// next line is the hot path...
writeErr := c.WriteFrame(read.buf.Bytes(), stype)
read.buf.Free()
if writeErr != nil {
return nil, writeErr
}
continue
} else if read.err == io.EOF {
if err := c.WriteFrame([]byte{}, End); err != nil {
return nil, err
}
break
} else {
errReader := strings.NewReader(read.err.Error())
errReadErrReader, errConnWrite := doWriteStream(ctx, c, errReader, StreamErrTrailer)
if errReadErrReader != nil {
panic(errReadErrReader) // in-memory, cannot happen
}
return read.err, errConnWrite
}
}
return nil, nil
}
type ReadStreamErrorKind int
const (
ReadStreamErrorKindConn ReadStreamErrorKind = 1 + iota
ReadStreamErrorKindWrite
ReadStreamErrorKindSource
ReadStreamErrorKindStreamErrTrailerEncoding
ReadStreamErrorKindUnexpectedFrameType
)
type ReadStreamError struct {
Kind ReadStreamErrorKind
Err error
}
func (e *ReadStreamError) Error() string {
kindStr := ""
switch e.Kind {
case ReadStreamErrorKindConn:
kindStr = " read error: "
case ReadStreamErrorKindWrite:
kindStr = " write error: "
case ReadStreamErrorKindSource:
kindStr = " source error: "
case ReadStreamErrorKindStreamErrTrailerEncoding:
kindStr = " source implementation error: "
case ReadStreamErrorKindUnexpectedFrameType:
kindStr = " protocol error: "
}
return fmt.Sprintf("stream:%s%s", kindStr, e.Err)
}
var _ net.Error = &ReadStreamError{}
func (e ReadStreamError) netErr() net.Error {
if netErr, ok := e.Err.(net.Error); ok {
return netErr
}
return nil
}
func (e ReadStreamError) Timeout() bool {
if netErr := e.netErr(); netErr != nil {
return netErr.Timeout()
}
return false
}
func (e ReadStreamError) Temporary() bool {
if netErr := e.netErr(); netErr != nil {
return netErr.Temporary()
}
return false
}
var _ zfs.StreamCopierError = &ReadStreamError{}
func (e ReadStreamError) IsReadError() bool {
return e.Kind != ReadStreamErrorKindWrite
}
func (e ReadStreamError) IsWriteError() bool {
return e.Kind == ReadStreamErrorKindWrite
}
type readFrameResult struct {
f frameconn.Frame
err error
}
func readFrames(reads chan<- readFrameResult, c *heartbeatconn.Conn) {
for {
var r readFrameResult
r.f, r.err = c.ReadFrame()
reads <- r
if r.err != nil {
return
}
}
}
// ReadStream will close c if an error reading from c or writing to receiver occurs
//
// readStream calls itself recursively to read multi-frame error trailers
// Thus, the reads channel needs to be a parameter.
func readStream(reads <-chan readFrameResult, c *heartbeatconn.Conn, receiver io.Writer, stype uint32) *ReadStreamError {
var f frameconn.Frame
for read := range reads {
debug("readStream: read frame %v %v", read.f.Header, read.err)
f = read.f
if read.err != nil {
return &ReadStreamError{ReadStreamErrorKindConn, read.err}
}
if f.Header.Type != stype {
break
}
n, err := receiver.Write(f.Buffer.Bytes())
if err != nil {
f.Buffer.Free()
return &ReadStreamError{ReadStreamErrorKindWrite, err} // FIXME wrap as writer error
}
if n != len(f.Buffer.Bytes()) {
f.Buffer.Free()
return &ReadStreamError{ReadStreamErrorKindWrite, io.ErrShortWrite}
}
f.Buffer.Free()
}
if f.Header.Type == End {
debug("readStream: End reached")
return nil
}
if f.Header.Type == StreamErrTrailer {
debug("readStream: begin of StreamErrTrailer")
var errBuf bytes.Buffer
if n, err := errBuf.Write(f.Buffer.Bytes()); n != len(f.Buffer.Bytes()) || err != nil {
panic(fmt.Sprintf("unexpected bytes.Buffer write error: %v %v", n, err))
}
// recursion ftw! we won't enter this if stmt because stype == StreamErrTrailer in the following call
rserr := readStream(reads, c, &errBuf, StreamErrTrailer)
if rserr != nil && rserr.Kind == ReadStreamErrorKindWrite {
panic(fmt.Sprintf("unexpected bytes.Buffer write error: %s", rserr))
} else if rserr != nil {
debug("readStream: rserr != nil && != ReadStreamErrorKindWrite: %v %v\n", rserr.Kind, rserr.Err)
return rserr
}
if !utf8.Valid(errBuf.Bytes()) {
return &ReadStreamError{ReadStreamErrorKindStreamErrTrailerEncoding, fmt.Errorf("source error, but not encoded as UTF-8")}
}
return &ReadStreamError{ReadStreamErrorKindSource, fmt.Errorf("%s", errBuf.String())}
}
return &ReadStreamError{ReadStreamErrorKindUnexpectedFrameType, fmt.Errorf("unexpected frame type %v (expected %v)", f.Header.Type, stype)}
}

View File

@ -0,0 +1,194 @@
package stream
import (
"bytes"
"context"
"fmt"
"io"
"sync"
"time"
"github.com/zrepl/zrepl/rpc/dataconn/heartbeatconn"
"github.com/zrepl/zrepl/rpc/dataconn/timeoutconn"
"github.com/zrepl/zrepl/zfs"
)
type Conn struct {
hc *heartbeatconn.Conn
// whether the per-conn readFrames goroutine completed
waitReadFramesDone chan struct{}
// filled by per-conn readFrames goroutine
frameReads chan readFrameResult
// readMtx serializes read stream operations because we inherently only
// support a single stream at a time over hc.
readMtx sync.Mutex
readClean bool
allowWriteStreamTo bool
// writeMtx serializes write stream operations because we inherently only
// support a single stream at a time over hc.
writeMtx sync.Mutex
writeClean bool
}
var readMessageSentinel = fmt.Errorf("read stream complete")
type writeStreamToErrorUnknownState struct{}
func (e writeStreamToErrorUnknownState) Error() string {
return "dataconn read stream: connection is in unknown state"
}
func (e writeStreamToErrorUnknownState) IsReadError() bool { return true }
func (e writeStreamToErrorUnknownState) IsWriteError() bool { return false }
func Wrap(nc timeoutconn.Wire, sendHeartbeatInterval, peerTimeout time.Duration) *Conn {
hc := heartbeatconn.Wrap(nc, sendHeartbeatInterval, peerTimeout)
conn := &Conn{
hc: hc, readClean: true, writeClean: true,
waitReadFramesDone: make(chan struct{}),
frameReads: make(chan readFrameResult, 5), // FIXME constant
}
go conn.readFrames()
return conn
}
func isConnCleanAfterRead(res *ReadStreamError) bool {
return res == nil || res.Kind == ReadStreamErrorKindSource || res.Kind == ReadStreamErrorKindStreamErrTrailerEncoding
}
func isConnCleanAfterWrite(err error) bool {
return err == nil
}
var ErrReadFramesStopped = fmt.Errorf("stream: reading frames stopped")
func (c *Conn) readFrames() {
defer close(c.waitReadFramesDone)
defer close(c.frameReads)
readFrames(c.frameReads, c.hc)
}
func (c *Conn) ReadStreamedMessage(ctx context.Context, maxSize uint32, frameType uint32) ([]byte, *ReadStreamError) {
c.readMtx.Lock()
defer c.readMtx.Unlock()
if !c.readClean {
return nil, &ReadStreamError{
Kind: ReadStreamErrorKindConn,
Err: fmt.Errorf("dataconn read message: connection is in unknown state"),
}
}
r, w := io.Pipe()
var buf bytes.Buffer
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
lr := io.LimitReader(r, int64(maxSize))
if _, err := io.Copy(&buf, lr); err != nil && err != readMessageSentinel {
panic(err)
}
}()
err := readStream(c.frameReads, c.hc, w, frameType)
c.readClean = isConnCleanAfterRead(err)
w.CloseWithError(readMessageSentinel)
wg.Wait()
if err != nil {
return nil, err
} else {
return buf.Bytes(), nil
}
}
// WriteStreamTo reads a stream from Conn and writes it to w.
func (c *Conn) ReadStreamInto(w io.Writer, frameType uint32) zfs.StreamCopierError {
c.readMtx.Lock()
defer c.readMtx.Unlock()
if !c.readClean {
return writeStreamToErrorUnknownState{}
}
var err *ReadStreamError = readStream(c.frameReads, c.hc, w, frameType)
c.readClean = isConnCleanAfterRead(err)
// https://golang.org/doc/faq#nil_error
if err == nil {
return nil
}
return err
}
func (c *Conn) WriteStreamedMessage(ctx context.Context, buf io.Reader, frameType uint32) error {
c.writeMtx.Lock()
defer c.writeMtx.Unlock()
if !c.writeClean {
return fmt.Errorf("dataconn write message: connection is in unknown state")
}
errBuf, errConn := writeStream(ctx, c.hc, buf, frameType)
if errBuf != nil {
panic(errBuf)
}
c.writeClean = isConnCleanAfterWrite(errConn)
return errConn
}
func (c *Conn) SendStream(ctx context.Context, src zfs.StreamCopier, frameType uint32) error {
c.writeMtx.Lock()
defer c.writeMtx.Unlock()
if !c.writeClean {
return fmt.Errorf("dataconn send stream: connection is in unknown state")
}
// avoid io.Pipe if zfs.StreamCopier is an io.Reader
var r io.Reader
var w *io.PipeWriter
streamCopierErrChan := make(chan zfs.StreamCopierError, 1)
if reader, ok := src.(io.Reader); ok {
r = reader
streamCopierErrChan <- nil
close(streamCopierErrChan)
} else {
r, w = io.Pipe()
go func() {
streamCopierErrChan <- src.WriteStreamTo(w)
w.Close()
}()
}
type writeStreamRes struct {
errStream, errConn error
}
writeStreamErrChan := make(chan writeStreamRes, 1)
go func() {
var res writeStreamRes
res.errStream, res.errConn = writeStream(ctx, c.hc, r, frameType)
if w != nil {
w.CloseWithError(res.errStream)
}
writeStreamErrChan <- res
}()
writeRes := <-writeStreamErrChan
streamCopierErr := <-streamCopierErrChan
c.writeClean = isConnCleanAfterWrite(writeRes.errConn) // TODO correct?
if streamCopierErr != nil && streamCopierErr.IsReadError() {
return streamCopierErr // something on our side is bad
} else {
if writeRes.errStream != nil {
return writeRes.errStream
} else if writeRes.errConn != nil {
return writeRes.errConn
}
// TODO combined error?
return streamCopierErr
}
}
func (c *Conn) Close() error {
err := c.hc.Shutdown()
<-c.waitReadFramesDone
return err
}

View File

@ -0,0 +1,20 @@
package stream
import (
"fmt"
"os"
)
var debugEnabled bool = false
func init() {
if os.Getenv("ZREPL_RPC_DATACONN_STREAM_DEBUG") != "" {
debugEnabled = true
}
}
func debug(format string, args ...interface{}) {
if debugEnabled {
fmt.Fprintf(os.Stderr, "rpc/dataconn/stream: %s\n", fmt.Sprintf(format, args...))
}
}

View File

@ -0,0 +1,131 @@
package stream
import (
"bytes"
"context"
"fmt"
"strings"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zrepl/zrepl/logger"
"github.com/zrepl/zrepl/rpc/dataconn/heartbeatconn"
"github.com/zrepl/zrepl/util/socketpair"
)
func TestFrameTypesOk(t *testing.T) {
t.Logf("%v", End)
assert.True(t, heartbeatconn.IsPublicFrameType(End))
assert.True(t, heartbeatconn.IsPublicFrameType(StreamErrTrailer))
}
func TestStreamer(t *testing.T) {
anc, bnc, err := socketpair.SocketPair()
require.NoError(t, err)
hto := 1 * time.Hour
a := heartbeatconn.Wrap(anc, hto, hto)
b := heartbeatconn.Wrap(bnc, hto, hto)
log := logger.NewStderrDebugLogger()
ctx := WithLogger(context.Background(), log)
stype := uint32(0x23)
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
var buf bytes.Buffer
buf.Write(
bytes.Repeat([]byte{1, 2}, 1<<25),
)
writeStream(ctx, a, &buf, stype)
log.Debug("WriteStream returned")
a.Shutdown()
}()
go func() {
defer wg.Done()
var buf bytes.Buffer
ch := make(chan readFrameResult, 5)
wg.Add(1)
go func() {
defer wg.Done()
readFrames(ch, b)
}()
err := readStream(ch, b, &buf, stype)
log.WithField("errType", fmt.Sprintf("%T %v", err, err)).Debug("ReadStream returned")
assert.Nil(t, err)
expected := bytes.Repeat([]byte{1, 2}, 1<<25)
assert.True(t, bytes.Equal(expected, buf.Bytes()))
b.Shutdown()
}()
wg.Wait()
}
type errReader struct {
t *testing.T
readErr error
}
func (er errReader) Read(p []byte) (n int, err error) {
er.t.Logf("errReader.Read called")
return 0, er.readErr
}
func TestMultiFrameStreamErrTraileror(t *testing.T) {
anc, bnc, err := socketpair.SocketPair()
require.NoError(t, err)
hto := 1 * time.Hour
a := heartbeatconn.Wrap(anc, hto, hto)
b := heartbeatconn.Wrap(bnc, hto, hto)
log := logger.NewStderrDebugLogger()
ctx := WithLogger(context.Background(), log)
longErr := fmt.Errorf("an error that definitley spans more than one frame:\n%s", strings.Repeat("a\n", 1<<4))
stype := uint32(0x23)
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
r := errReader{t, longErr}
writeStream(ctx, a, &r, stype)
a.Shutdown()
}()
go func() {
defer wg.Done()
defer b.Shutdown()
var buf bytes.Buffer
ch := make(chan readFrameResult, 5)
wg.Add(1)
go func() {
defer wg.Done()
readFrames(ch, b)
}()
err := readStream(ch, b, &buf, stype)
t.Logf("%s", err)
require.NotNil(t, err)
assert.True(t, buf.Len() == 0)
assert.Equal(t, err.Kind, ReadStreamErrorKindSource)
receivedErr := err.Err.Error()
expectedErr := longErr.Error()
assert.True(t, receivedErr == expectedErr) // builtin Equals is too slow
if receivedErr != expectedErr {
t.Logf("lengths: %v %v", len(receivedErr), len(expectedErr))
}
}()
wg.Wait()
}

View File

@ -0,0 +1,12 @@
# setup-specific
inventory
*.retry
# generated by gen_files.sh
files/*ssh_client_identity
files/*ssh_client_identity.pub
files/*.tls.*.key
files/*.tls.*.csr
files/*.tls.*.crt
files/wireevaluator

View File

@ -0,0 +1,15 @@
This directory contains very hacky test automation for wireevaluator based on nested Ansible playbooks.
* Copy `inventory.example` to `inventory`
* Adjust `inventory` IP addresses as needed
* Make sure there's an OpenSSH server running on the serve host
* Make sure there's no firewalling whatsoever between the hosts
* Run `GENKEYS=1 ./gen_files.sh` to re-generate self-signed TLS certs
* Run the following command, adjusting the `wireevaluator_repeat` value to the number of times you want to repeat each test
```
ansible-playbook -i inventory all.yml -e `wireevaluator_repeat=3`
```
Generally, things are fine if the playbook doesn't show any panics from wireevaluator.

View File

@ -0,0 +1,17 @@
- hosts: connect,serve
tasks:
- name: "run test"
include: internal_prepare_and_run_repeated.yml
wireevaluator_transport: "{{config.0}}"
wireevaluator_case: "{{config.1}}"
wireevaluator_repeat: "{{wireevaluator_repeat}}"
with_cartesian:
- [ tls, ssh, tcp ]
-
- closewrite_server
- closewrite_client
- readdeadline_server
- readdeadline_client
loop_control:
loop_var: config

View File

@ -0,0 +1,55 @@
#!/usr/bin/env bash
set -e
cd "$( dirname "${BASH_SOURCE[0]}")"
FILESDIR="$(pwd)"/files
echo "[INFO] compile binary"
pushd .. >/dev/null
go build -o $FILESDIR/wireevaluator
popd >/dev/null
if [ "$GENKEYS" == "" ]; then
echo "[INFO] GENKEYS environment variable not set, assumed to be valid"
exit 0
fi
echo "[INFO] gen ssh key"
ssh-keygen -f "$FILESDIR/wireevaluator.ssh_client_identity" -t ed25519
echo "[INFO] gen tls keys"
cakey="$FILESDIR/wireevaluator.tls.ca.key"
cacrt="$FILESDIR/wireevaluator.tls.ca.crt"
hostprefix="$FILESDIR/wireevaluator.tls"
openssl genrsa -out "$cakey" 4096
openssl req -x509 -new -nodes -key "$cakey" -sha256 -days 1 -out "$cacrt"
declare -a HOSTS
HOSTS+=("theserver")
HOSTS+=("theclient")
for host in "${HOSTS[@]}"; do
key="${hostprefix}.${host}.key"
csr="${hostprefix}.${host}.csr"
crt="${hostprefix}.${host}.crt"
openssl genrsa -out "$key" 2048
(
echo "."
echo "."
echo "."
echo "."
echo "."
echo $host
echo "."
echo "."
echo "."
echo "."
) | openssl req -new -key "$key" -out "$csr"
openssl x509 -req -in "$csr" -CA "$cacrt" -CAkey "$cakey" -CAcreateserial -out "$crt" -days 1 -sha256
done

View File

@ -0,0 +1,54 @@
---
- name: compile binary and any key files required
local_action: command ./gen_files.sh
- name: Kill test binary
shell: "killall -9 wireevaluator || true"
- name: Deploy new binary
copy:
src: "files/wireevaluator"
dest: "/opt/wireevaluator"
mode: 0755
- set_fact:
wireevaluator_connect_ip: "{{hostvars['connect'].ansible_host}}"
wireevaluator_serve_ip: "{{hostvars['serve'].ansible_host}}"
- name: Deploy config
template:
src: "templates/{{wireevaluator_transport}}.yml.j2"
dest: "/opt/wireevaluator.yml"
- name: Deploy client identity
copy:
src: "files/wireevaluator.{{item}}"
dest: "/opt/wireevaluator.{{item}}"
mode: 0400
with_items:
- ssh_client_identity
- ssh_client_identity.pub
- tls.ca.key
- tls.ca.crt
- tls.theserver.key
- tls.theserver.crt
- tls.theclient.key
- tls.theclient.crt
- name: Setup server ssh client identity access
when: inventory_hostname == "serve"
block:
- authorized_key:
user: root
state: present
key: "{{ lookup('file', 'files/wireevaluator.ssh_client_identity.pub') }}"
key_options: 'command="/opt/wireevaluator -mode stdinserver -config /opt/wireevaluator.yml client1"'
- file:
state: directory
mode: 0700
path: /tmp/wireevaluator_stdinserver
- name: repeated test
include: internal_run_test_prepared_single.yml
with_sequence: start=1 end={{wireevaluator_repeat}}

View File

@ -0,0 +1,38 @@
---
- debug:
msg: "run test transport={{wireevaluator_transport}} case={{wireevaluator_case}} repeatedly"
- name: Run Server
when: inventory_hostname == "serve"
command: /opt/wireevaluator -config /opt/wireevaluator.yml -mode serve -testcase {{wireevaluator_case}}
register: spawn_servers
async: 60
poll: 0
- name: Run Client
when: inventory_hostname == "connect"
command: /opt/wireevaluator -config /opt/wireevaluator.yml -mode connect -testcase {{wireevaluator_case}}
register: spawn_clients
async: 60
poll: 0
- name: Wait for server shutdown
when: inventory_hostname == "serve"
async_status:
jid: "{{ spawn_servers.ansible_job_id}}"
delay: 0.5
retries: 10
- name: Wait for client shutdown
when: inventory_hostname == "connect"
async_status:
jid: "{{ spawn_clients.ansible_job_id}}"
delay: 0.5
retries: 10
- name: Wait for connections to die (TIME_WAIT conns)
command: sleep 4
changed_when: false

View File

@ -0,0 +1,2 @@
connect ansible_user=root ansible_host=192.168.122.128 wireevaluator_mode="connect"
serve ansible_user=root ansible_host=192.168.122.129 wireevaluator_mode="serve"

View File

@ -0,0 +1,13 @@
connect:
type: ssh+stdinserver
host: {{wireevaluator_serve_ip}}
user: root
port: 22
identity_file: /opt/wireevaluator.ssh_client_identity
options: # optional, default [], `-o` arguments passed to ssh
- "Compression=yes"
serve:
type: stdinserver
client_identities:
- "client1"

View File

@ -0,0 +1,10 @@
connect:
type: tcp
address: "{{wireevaluator_serve_ip}}:8888"
serve:
type: tcp
listen: ":8888"
clients: {
"{{wireevaluator_connect_ip}}" : "client1"
}

View File

@ -0,0 +1,16 @@
connect:
type: tls
address: "{{wireevaluator_serve_ip}}:8888"
ca: "/opt/wireevaluator.tls.ca.crt"
cert: "/opt/wireevaluator.tls.theclient.crt"
key: "/opt/wireevaluator.tls.theclient.key"
server_cn: "theserver"
serve:
type: tls
listen: ":8888"
ca: "/opt/wireevaluator.tls.ca.crt"
cert: "/opt/wireevaluator.tls.theserver.crt"
key: "/opt/wireevaluator.tls.theserver.key"
client_cns:
- "theclient"

View File

@ -0,0 +1,111 @@
// a tool to test whether a given transport implements the timeoutconn.Wire interface
package main
import (
"context"
"flag"
"fmt"
"io/ioutil"
"os"
"path"
netssh "github.com/problame/go-netssh"
"github.com/zrepl/yaml-config"
"github.com/zrepl/zrepl/config"
"github.com/zrepl/zrepl/transport"
transportconfig "github.com/zrepl/zrepl/transport/fromconfig"
)
func noerror(err error) {
if err != nil {
panic(err)
}
}
type Config struct {
Connect config.ConnectEnum
Serve config.ServeEnum
}
var args struct {
mode string
configPath string
testCase string
}
var conf Config
type TestCase interface {
Client(wire transport.Wire)
Server(wire transport.Wire)
}
func main() {
flag.StringVar(&args.mode, "mode", "", "connect|serve")
flag.StringVar(&args.configPath, "config", "", "config file path")
flag.StringVar(&args.testCase, "testcase", "", "")
flag.Parse()
bytes, err := ioutil.ReadFile(args.configPath)
noerror(err)
err = yaml.UnmarshalStrict(bytes, &conf)
noerror(err)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
global := &config.Global{
Serve: &config.GlobalServe{
StdinServer: &config.GlobalStdinServer{
SockDir: "/tmp/wireevaluator_stdinserver",
},
},
}
switch args.mode {
case "connect":
tc, err := getTestCase(args.testCase)
noerror(err)
connecter, err := transportconfig.ConnecterFromConfig(global, conf.Connect)
noerror(err)
wire, err := connecter.Connect(ctx)
noerror(err)
tc.Client(wire)
case "serve":
tc, err := getTestCase(args.testCase)
noerror(err)
lf, err := transportconfig.ListenerFactoryFromConfig(global, conf.Serve)
noerror(err)
l, err := lf()
noerror(err)
conn, err := l.Accept(ctx)
noerror(err)
tc.Server(conn)
case "stdinserver":
identity := flag.Arg(0)
unixaddr := path.Join(global.Serve.StdinServer.SockDir, identity)
err := netssh.Proxy(ctx, unixaddr)
if err == nil {
os.Exit(0)
}
panic(err)
default:
panic(fmt.Sprintf("unknown mode %q", args.mode))
}
}
func getTestCase(tcName string) (TestCase, error) {
switch tcName {
case "closewrite_server":
return &CloseWrite{mode: CloseWriteServerSide}, nil
case "closewrite_client":
return &CloseWrite{mode: CloseWriteClientSide}, nil
case "readdeadline_client":
return &Deadlines{mode: DeadlineModeClientTimeout}, nil
case "readdeadline_server":
return &Deadlines{mode: DeadlineModeServerTimeout}, nil
default:
return nil, fmt.Errorf("unknown test case %q", tcName)
}
}

View File

@ -0,0 +1,110 @@
package main
import (
"bytes"
"io"
"io/ioutil"
"log"
"github.com/zrepl/zrepl/transport"
)
type CloseWriteMode uint
const (
CloseWriteClientSide CloseWriteMode = 1 + iota
CloseWriteServerSide
)
type CloseWrite struct {
mode CloseWriteMode
}
// sent repeatedly
var closeWriteTestSendData = bytes.Repeat([]byte{0x23, 0x42}, 1<<24)
var closeWriteErrorMsg = []byte{0xb, 0xa, 0xd, 0xf, 0x0, 0x0, 0xd}
func (m CloseWrite) Client(wire transport.Wire) {
switch m.mode {
case CloseWriteClientSide:
m.receiver(wire)
case CloseWriteServerSide:
m.sender(wire)
default:
panic(m.mode)
}
}
func (m CloseWrite) Server(wire transport.Wire) {
switch m.mode {
case CloseWriteClientSide:
m.sender(wire)
case CloseWriteServerSide:
m.receiver(wire)
default:
panic(m.mode)
}
}
func (CloseWrite) sender(wire transport.Wire) {
defer func() {
closeErr := wire.Close()
log.Printf("closeErr=%T %s", closeErr, closeErr)
}()
type opResult struct {
err error
}
writeDone := make(chan struct{}, 1)
go func() {
close(writeDone)
for {
_, err := wire.Write(closeWriteTestSendData)
if err != nil {
return
}
}
}()
defer func() {
<-writeDone
}()
var respBuf bytes.Buffer
_, err := io.Copy(&respBuf, wire)
if err != nil {
log.Fatalf("should have received io.EOF, which is masked by io.Copy, got: %s", err)
}
if !bytes.Equal(respBuf.Bytes(), closeWriteErrorMsg) {
log.Fatalf("did not receive error message, got response with len %v:\n%v", respBuf.Len(), respBuf.Bytes())
}
}
func (CloseWrite) receiver(wire transport.Wire) {
// consume half the test data, then detect an error, send it and CloseWrite
r := io.LimitReader(wire, int64(5 * len(closeWriteTestSendData)/3))
_, err := io.Copy(ioutil.Discard, r)
noerror(err)
var errBuf bytes.Buffer
errBuf.Write(closeWriteErrorMsg)
_, err = io.Copy(wire, &errBuf)
noerror(err)
err = wire.CloseWrite()
noerror(err)
// drain wire, as documented in transport.Wire, this is the only way we know the client closed the conn
_, err = io.Copy(ioutil.Discard, wire)
if err != nil {
// io.Copy masks io.EOF to nil, and we expect io.EOF from the client's Close() call
log.Panicf("unexpected error returned from reading conn: %s", err)
}
closeErr := wire.Close()
log.Printf("closeErr=%T %s", closeErr, closeErr)
}

View File

@ -0,0 +1,138 @@
package main
import (
"bytes"
"fmt"
"io"
"log"
"net"
"time"
"github.com/zrepl/zrepl/transport"
)
type DeadlineMode uint
const (
DeadlineModeClientTimeout DeadlineMode = 1 + iota
DeadlineModeServerTimeout
)
type Deadlines struct {
mode DeadlineMode
}
func (d Deadlines) Client(wire transport.Wire) {
switch d.mode {
case DeadlineModeClientTimeout:
d.sleepThenSend(wire)
case DeadlineModeServerTimeout:
d.sendThenRead(wire)
default:
panic(d.mode)
}
}
func (d Deadlines) Server(wire transport.Wire) {
switch d.mode {
case DeadlineModeClientTimeout:
d.sendThenRead(wire)
case DeadlineModeServerTimeout:
d.sleepThenSend(wire)
default:
panic(d.mode)
}
}
var deadlinesTimeout = 1 * time.Second
func (d Deadlines) sleepThenSend(wire transport.Wire) {
defer wire.Close()
log.Print("sleepThenSend")
// exceed timeout of peer (do not respond to their hi msg)
time.Sleep(3 * deadlinesTimeout)
// expect that the client has hung up on us by now
err := d.sendMsg(wire, "hi")
log.Printf("err=%s", err)
log.Printf("err=%#v", err)
if err == nil {
log.Panic("no error")
}
if _, ok := err.(net.Error); !ok {
log.Panic("not a net error")
}
}
func (d Deadlines) sendThenRead(wire transport.Wire) {
log.Print("sendThenRead")
err := d.sendMsg(wire, "hi")
noerror(err)
err = wire.SetReadDeadline(time.Now().Add(deadlinesTimeout))
noerror(err)
m, err := d.recvMsg(wire)
log.Printf("m=%q", m)
log.Printf("err=%s", err)
log.Printf("err=%#v", err)
// close asap so that the peer get's a 'connection reset by peer' error or similar
closeErr := wire.Close()
if closeErr != nil {
panic(closeErr)
}
var neterr net.Error
var ok bool
if err == nil {
goto unexpErr // works for nil, too
}
neterr, ok = err.(net.Error)
if !ok {
log.Println("not a net error")
goto unexpErr
}
if !neterr.Timeout() {
log.Println("not a timeout")
}
return
unexpErr:
panic(fmt.Sprintf("sendThenRead: client should have hung up but got error %T %s", err, err))
}
const deadlinesMsgLen = 40
func (d Deadlines) sendMsg(wire transport.Wire, msg string) error {
if len(msg) > deadlinesMsgLen {
panic(len(msg))
}
var buf [deadlinesMsgLen]byte
copy(buf[:], []byte(msg))
n, err := wire.Write(buf[:])
if err != nil {
return err
}
if n != len(buf) {
panic("short write not allowed")
}
return nil
}
func (d Deadlines) recvMsg(wire transport.Wire) (string, error) {
var buf bytes.Buffer
r := io.LimitReader(wire, deadlinesMsgLen)
_, err := io.Copy(&buf, r)
if err != nil {
return "", err
}
return buf.String(), nil
}

View File

@ -0,0 +1,288 @@
// package timeoutconn wraps a Wire to provide idle timeouts
// based on Set{Read,Write}Deadline.
// Additionally, it exports abstractions for vectored I/O.
package timeoutconn
// NOTE
// Readv and Writev are not split-off into a separate package
// because we use raw syscalls, bypassing Conn's Read / Write methods.
import (
"errors"
"io"
"net"
"sync/atomic"
"syscall"
"time"
"unsafe"
)
type Wire interface {
net.Conn
// A call to CloseWrite indicates that no further Write calls will be made to Wire.
// The implementation must return an error in case of Write calls after CloseWrite.
// On the peer's side, after it read all data written to Wire prior to the call to
// CloseWrite on our side, the peer's Read calls must return io.EOF.
// CloseWrite must not affect the read-direction of Wire: specifically, the
// peer must continue to be able to send, and our side must continue be
// able to receive data over Wire.
//
// Note that CloseWrite may (and most likely will) return sooner than the
// peer having received all data written to Wire prior to CloseWrite.
// Note further that buffering happening in the network stacks on either side
// mandates an explicit acknowledgement from the peer that the connection may
// be fully shut down: If we call Close without such acknowledgement, any data
// from peer to us that was already in flight may cause connection resets to
// be sent from us to the peer via the specific transport protocol. Those
// resets (e.g. RST frames) may erase all connection context on the peer,
// including data in its receive buffers. Thus, those resets are in race with
// a) transmission of data written prior to CloseWrite and
// b) the peer application reading from those buffers.
//
// The WaitForPeerClose method can be used to wait for connection termination,
// iff the implementation supports it. If it does not, the only reliable way
// to wait for a peer to have read all data from Wire (until io.EOF), is to
// expect it to close the wire at that point as well, and to drain Wire until
// we also read io.EOF.
CloseWrite() error
// Wait for the peer to close the connection.
// No data that could otherwise be Read is lost as a consequence of this call.
// The use case for this API is abortive connection shutdown.
// To provide any value over draining Wire using io.Read, an implementation
// will likely use out-of-bounds messaging mechanisms.
// TODO WaitForPeerClose() (supported bool, err error)
}
type Conn struct {
Wire
renewDeadlinesDisabled int32
idleTimeout time.Duration
}
func Wrap(conn Wire, idleTimeout time.Duration) Conn {
return Conn{Wire: conn, idleTimeout: idleTimeout}
}
// DisableTimeouts disables the idle timeout behavior provided by this package.
// Existing deadlines are cleared iff the call is the first call to this method.
func (c *Conn) DisableTimeouts() error {
if atomic.CompareAndSwapInt32(&c.renewDeadlinesDisabled, 0, 1) {
return c.SetDeadline(time.Time{})
}
return nil
}
func (c *Conn) renewReadDeadline() error {
if atomic.LoadInt32(&c.renewDeadlinesDisabled) != 0 {
return nil
}
return c.SetReadDeadline(time.Now().Add(c.idleTimeout))
}
func (c *Conn) renewWriteDeadline() error {
if atomic.LoadInt32(&c.renewDeadlinesDisabled) != 0 {
return nil
}
return c.SetWriteDeadline(time.Now().Add(c.idleTimeout))
}
func (c Conn) Read(p []byte) (n int, err error) {
n = 0
err = nil
restart:
if err := c.renewReadDeadline(); err != nil {
return n, err
}
var nCurRead int
nCurRead, err = c.Wire.Read(p[n:len(p)])
n += nCurRead
if netErr, ok := err.(net.Error); ok && netErr.Timeout() && nCurRead > 0 {
err = nil
goto restart
}
return n, err
}
func (c Conn) Write(p []byte) (n int, err error) {
n = 0
restart:
if err := c.renewWriteDeadline(); err != nil {
return n, err
}
var nCurWrite int
nCurWrite, err = c.Wire.Write(p[n:len(p)])
n += nCurWrite
if netErr, ok := err.(net.Error); ok && netErr.Timeout() && nCurWrite > 0 {
err = nil
goto restart
}
return n, err
}
// Writes the given buffers to Conn, following the sematincs of io.Copy,
// but is guaranteed to use the writev system call if the wrapped Wire
// support it.
// Note the Conn does not support writev through io.Copy(aConn, aNetBuffers).
func (c Conn) WritevFull(bufs net.Buffers) (n int64, err error) {
n = 0
restart:
if err := c.renewWriteDeadline(); err != nil {
return n, err
}
var nCurWrite int64
nCurWrite, err = io.Copy(c.Wire, &bufs)
n += nCurWrite
if netErr, ok := err.(net.Error); ok && netErr.Timeout() && nCurWrite > 0 {
err = nil
goto restart
}
return n, err
}
var SyscallConnNotSupported = errors.New("SyscallConn not supported")
// The interface that must be implemented for vectored I/O support.
// If the wrapped Wire does not implement it, a less efficient
// fallback implementation is used.
// Rest assured that Go's *net.TCPConn implements this interface.
type SyscallConner interface {
// The sentinel error value SyscallConnNotSupported can be returned
// if the support for SyscallConn depends on runtime conditions and
// that runtime condition is not met.
SyscallConn() (syscall.RawConn, error)
}
var _ SyscallConner = (*net.TCPConn)(nil)
func buildIovecs(buffers net.Buffers) (totalLen int64, vecs []syscall.Iovec) {
vecs = make([]syscall.Iovec, 0, len(buffers))
for i := range buffers {
totalLen += int64(len(buffers[i]))
if len(buffers[i]) == 0 {
continue
}
vecs = append(vecs, syscall.Iovec{
Base: &buffers[i][0],
Len: uint64(len(buffers[i])),
})
}
return totalLen, vecs
}
// Reads the given buffers full:
// Think of io.ReadvFull, but for net.Buffers + using the readv syscall.
//
// If the underlying Wire is not a SyscallConner, a fallback
// ipmlementation based on repeated Conn.Read invocations is used.
//
// If the connection returned io.EOF, the number of bytes up ritten until
// then + io.EOF is returned. This behavior is different to io.ReadFull
// which returns io.ErrUnexpectedEOF.
func (c Conn) ReadvFull(buffers net.Buffers) (n int64, err error) {
totalLen, iovecs := buildIovecs(buffers)
if debugReadvNoShortReadsAssertEnable {
defer debugReadvNoShortReadsAssert(totalLen, n, err)
}
scc, ok := c.Wire.(SyscallConner)
if !ok {
return c.readvFallback(buffers)
}
raw, err := scc.SyscallConn()
if err == SyscallConnNotSupported {
return c.readvFallback(buffers)
}
if err != nil {
return 0, err
}
n, err = c.readv(raw, iovecs)
return
}
func (c Conn) readvFallback(nbuffers net.Buffers) (n int64, err error) {
buffers := [][]byte(nbuffers)
for i := range buffers {
curBuf := buffers[i]
inner:
for len(curBuf) > 0 {
if err := c.renewReadDeadline(); err != nil {
return n, err
}
var oneN int
oneN, err = c.Read(curBuf[:]) // WE WANT NO SHADOWING
curBuf = curBuf[oneN:]
n += int64(oneN)
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() && oneN > 0 {
continue inner
}
return n, err
}
}
}
return n, nil
}
func (c Conn) readv(rawConn syscall.RawConn, iovecs []syscall.Iovec) (n int64, err error) {
for len(iovecs) > 0 {
if err := c.renewReadDeadline(); err != nil {
return n, err
}
oneN, oneErr := c.doOneReadv(rawConn, &iovecs)
n += oneN
if netErr, ok := oneErr.(net.Error); ok && netErr.Timeout() && oneN > 0 { // TODO likely not working
continue
} else if oneErr == nil && oneN > 0 {
continue
} else {
return n, oneErr
}
}
return n, nil
}
func (c Conn) doOneReadv(rawConn syscall.RawConn, iovecs *[]syscall.Iovec) (n int64, err error) {
rawReadErr := rawConn.Read(func(fd uintptr) (done bool) {
// iovecs, n and err must not be shadowed!
thisReadN, _, errno := syscall.Syscall(
syscall.SYS_READV,
fd,
uintptr(unsafe.Pointer(&(*iovecs)[0])),
uintptr(len(*iovecs)),
)
if thisReadN == ^uintptr(0) {
if errno == syscall.EAGAIN {
return false
}
err = syscall.Errno(errno)
return true
}
if int(thisReadN) < 0 {
panic("unexpected return value")
}
n += int64(thisReadN)
// shift iovecs forward
for left := int64(thisReadN); left > 0; {
curVecNewLength := int64((*iovecs)[0].Len) - left // TODO assert conversion
if curVecNewLength <= 0 {
left -= int64((*iovecs)[0].Len)
*iovecs = (*iovecs)[1:]
} else {
(*iovecs)[0].Base = (*byte)(unsafe.Pointer(uintptr(unsafe.Pointer((*iovecs)[0].Base)) + uintptr(left)))
(*iovecs)[0].Len = uint64(curVecNewLength)
break // inner
}
}
if thisReadN == 0 {
err = io.EOF
return true
}
return true
})
if rawReadErr != nil {
err = rawReadErr
}
return n, err
}

View File

@ -0,0 +1,19 @@
package timeoutconn
import (
"fmt"
"io"
)
const debugReadvNoShortReadsAssertEnable = false
func debugReadvNoShortReadsAssert(expectedLen, returnedLen int64, returnedErr error) {
readShort := expectedLen != returnedLen
if !readShort {
return
}
if returnedErr != io.EOF {
return
}
panic(fmt.Sprintf("ReadvFull short and error is not EOF%v\n", returnedErr))
}

View File

@ -0,0 +1,177 @@
package timeoutconn
import (
"bytes"
"io"
"net"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zrepl/zrepl/util/socketpair"
)
func TestReadTimeout(t *testing.T) {
a, b, err := socketpair.SocketPair()
require.NoError(t, err)
defer a.Close()
defer b.Close()
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
var buf bytes.Buffer
buf.WriteString("tooktoolong")
time.Sleep(500 * time.Millisecond)
_, err := io.Copy(a, &buf)
require.NoError(t, err)
}()
go func() {
defer wg.Done()
conn := Wrap(b, 100*time.Millisecond)
buf := [4]byte{} // shorter than message put on wire
n, err := conn.Read(buf[:])
assert.Equal(t, 0, n)
assert.Error(t, err)
netErr, ok := err.(net.Error)
require.True(t, ok)
assert.True(t, netErr.Timeout())
}()
wg.Wait()
}
type writeBlockConn struct {
net.Conn
blockTime time.Duration
}
func (c writeBlockConn) Write(p []byte) (int, error) {
time.Sleep(c.blockTime)
return c.Conn.Write(p)
}
func (c writeBlockConn) CloseWrite() error {
return c.Conn.Close()
}
func TestWriteTimeout(t *testing.T) {
a, b, err := socketpair.SocketPair()
require.NoError(t, err)
defer a.Close()
defer b.Close()
var buf bytes.Buffer
buf.WriteString("message")
blockConn := writeBlockConn{a, 500 * time.Millisecond}
conn := Wrap(blockConn, 100*time.Millisecond)
n, err := conn.Write(buf.Bytes())
assert.Equal(t, 0, n)
assert.Error(t, err)
netErr, ok := err.(net.Error)
require.True(t, ok)
assert.True(t, netErr.Timeout())
}
func TestNoPartialReadsDueToDeadline(t *testing.T) {
a, b, err := socketpair.SocketPair()
require.NoError(t, err)
defer a.Close()
defer b.Close()
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
a.Write([]byte{1, 2, 3, 4, 5})
// sleep to provoke a partial read in the consumer goroutine
time.Sleep(50 * time.Millisecond)
a.Write([]byte{6, 7, 8, 9, 10})
}()
go func() {
defer wg.Done()
bc := Wrap(b, 100*time.Millisecond)
var buf bytes.Buffer
beginRead := time.Now()
// io.Copy will encounter a partial read, then wait ~50ms until the other 5 bytes are written
// It is still going to fail with deadline err because it expects EOF
n, err := io.Copy(&buf, bc)
readDuration := time.Now().Sub(beginRead)
t.Logf("read duration=%s", readDuration)
t.Logf("recv done n=%v err=%v", n, err)
t.Logf("buf=%v", buf.Bytes())
neterr, ok := err.(net.Error)
require.True(t, ok)
assert.True(t, neterr.Timeout())
assert.Equal(t, int64(10), n)
assert.Equal(t, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, buf.Bytes())
// 50ms for the second read, 100ms after that one for the deadline
// allow for some jitter
assert.True(t, readDuration > 140*time.Millisecond)
assert.True(t, readDuration < 160*time.Millisecond)
}()
wg.Wait()
}
type partialWriteMockConn struct {
net.Conn // to satisfy interface
buf bytes.Buffer
writeDuration time.Duration
returnAfterBytesWritten int
}
func newPartialWriteMockConn(writeDuration time.Duration, returnAfterBytesWritten int) *partialWriteMockConn {
return &partialWriteMockConn{
writeDuration: writeDuration,
returnAfterBytesWritten: returnAfterBytesWritten,
}
}
func (c *partialWriteMockConn) Write(p []byte) (int, error) {
time.Sleep(c.writeDuration)
consumeBytes := len(p)
if consumeBytes > c.returnAfterBytesWritten {
consumeBytes = c.returnAfterBytesWritten
}
n, err := c.buf.Write(p[0:consumeBytes])
if err != nil || n != consumeBytes {
panic("bytes.Buffer behaves unexpectedly")
}
return n, nil
}
func TestPartialWriteMockConn(t *testing.T) {
mc := newPartialWriteMockConn(100*time.Millisecond, 5)
buf := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
begin := time.Now()
n, err := mc.Write(buf[:])
duration := time.Now().Sub(begin)
assert.NoError(t, err)
assert.Equal(t, 5, n)
assert.True(t, duration > 100*time.Millisecond)
assert.True(t, duration < 150*time.Millisecond)
}
func TestNoPartialWritesDueToDeadline(t *testing.T) {
a, b, err := socketpair.SocketPair()
require.NoError(t, err)
defer a.Close()
defer b.Close()
var buf bytes.Buffer
buf.WriteString("message")
blockConn := writeBlockConn{a, 150 * time.Millisecond}
conn := Wrap(blockConn, 100*time.Millisecond)
n, err := conn.Write(buf.Bytes())
assert.Equal(t, 0, n)
assert.Error(t, err)
netErr, ok := err.(net.Error)
require.True(t, ok)
assert.True(t, netErr.Timeout())
}

View File

@ -0,0 +1,118 @@
// Package grpcclientidentity makes the client identity
// provided by github.com/zrepl/zrepl/daemon/transport/serve.{AuthenticatedListener,AuthConn}
// available to gRPC service handlers.
//
// This goal is achieved through the combination of custom gRPC transport credentials and two interceptors
// (i.e. middleware).
//
// For gRPC clients, the TransportCredentials + Dialer can be used to construct a gRPC client (grpc.ClientConn)
// that uses a github.com/zrepl/zrepl/daemon/transport/connect.Connecter to connect to a server.
//
// The adaptors exposed by this package must be used together, and panic if they are not.
// See package grpchelper for a more restrictive but safe example on how the adaptors should be composed.
package grpcclientidentity
import (
"context"
"fmt"
"net"
"time"
"github.com/zrepl/zrepl/logger"
"github.com/zrepl/zrepl/transport"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/peer"
)
type Logger = logger.Logger
type GRPCDialFunction = func(string, time.Duration) (net.Conn, error)
func NewDialer(logger Logger, connecter transport.Connecter) GRPCDialFunction {
return func(s string, duration time.Duration) (conn net.Conn, e error) {
ctx, cancel := context.WithTimeout(context.Background(), duration)
defer cancel()
nc, err := connecter.Connect(ctx)
// TODO find better place (callback from gRPC?) where to log errors
// we want the users to know, though
if err != nil {
logger.WithError(err).Error("cannot connect")
}
return nc, err
}
}
type authConnAuthType struct {
clientIdentity string
}
func (authConnAuthType) AuthType() string {
return "AuthConn"
}
type connecterAuthType struct{}
func (connecterAuthType) AuthType() string {
return "connecter"
}
type transportCredentials struct {
logger Logger
}
// Use on both sides as ServerOption or ClientOption.
func NewTransportCredentials(log Logger) credentials.TransportCredentials {
if log == nil {
log = logger.NewNullLogger()
}
return &transportCredentials{log}
}
func (c *transportCredentials) ClientHandshake(ctx context.Context, s string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
c.logger.WithField("url", s).WithField("connType", fmt.Sprintf("%T", rawConn)).Debug("ClientHandshake")
// do nothing, client credential is only for WithInsecure warning to go away
// the authentication is done by the connecter
return rawConn, &connecterAuthType{}, nil
}
func (c *transportCredentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
c.logger.WithField("connType", fmt.Sprintf("%T", rawConn)).Debug("ServerHandshake")
authConn, ok := rawConn.(*transport.AuthConn)
if !ok {
panic(fmt.Sprintf("NewTransportCredentials must be used with a listener that returns *transport.AuthConn, got %T", rawConn))
}
return rawConn, &authConnAuthType{authConn.ClientIdentity()}, nil
}
func (*transportCredentials) Info() credentials.ProtocolInfo {
return credentials.ProtocolInfo{} // TODO
}
func (t *transportCredentials) Clone() credentials.TransportCredentials {
var x = *t
return &x
}
func (*transportCredentials) OverrideServerName(string) error {
panic("not implemented")
}
func NewInterceptors(logger Logger, clientIdentityKey interface{}) (unary grpc.UnaryServerInterceptor, stream grpc.StreamServerInterceptor) {
unary = func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
logger.WithField("fullMethod", info.FullMethod).Debug("request")
p, ok := peer.FromContext(ctx)
if !ok {
panic("peer.FromContext expected to return a peer in grpc.UnaryServerInterceptor")
}
logger.WithField("peer", fmt.Sprintf("%v", p)).Debug("peer")
a, ok := p.AuthInfo.(*authConnAuthType)
if !ok {
panic(fmt.Sprintf("NewInterceptors must be used in combination with grpc.NewTransportCredentials, but got auth type %T", p.AuthInfo))
}
ctx = context.WithValue(ctx, clientIdentityKey, a.clientIdentity)
return handler(ctx, req)
}
stream = nil
return
}

View File

@ -0,0 +1,16 @@
syntax = "proto3";
package pdu;
service Greeter {
rpc Greet(GreetRequest) returns (GreetResponse) {}
}
message GreetRequest {
string name = 1;
}
message GreetResponse {
string msg = 1;
}

View File

@ -0,0 +1,107 @@
// This package demonstrates how the grpcclientidentity package can be used
// to set up a gRPC greeter service.
package main
import (
"context"
"flag"
"fmt"
"os"
"time"
"github.com/zrepl/zrepl/config"
"github.com/zrepl/zrepl/logger"
"github.com/zrepl/zrepl/rpc/grpcclientidentity/example/pdu"
"github.com/zrepl/zrepl/rpc/grpcclientidentity/grpchelper"
"github.com/zrepl/zrepl/transport/tcp"
)
var args struct {
mode string
}
var log = logger.NewStderrDebugLogger()
func main() {
flag.StringVar(&args.mode, "mode", "", "client|server")
flag.Parse()
switch args.mode {
case "client":
client()
case "server":
server()
default:
log.Printf("unknown mode %q")
os.Exit(1)
}
}
func onErr(err error, format string, args ...interface{}) {
log.WithError(err).Error(fmt.Sprintf("%s: %s", fmt.Sprintf(format, args...), err))
os.Exit(1)
}
func client() {
cn, err := tcp.TCPConnecterFromConfig(&config.TCPConnect{
ConnectCommon: config.ConnectCommon{
Type: "tcp",
},
Address: "127.0.0.1:8080",
DialTimeout: 10 * time.Second,
})
if err != nil {
onErr(err, "build connecter error")
}
clientConn := grpchelper.ClientConn(cn, log)
defer clientConn.Close()
// normal usage from here on
client := pdu.NewGreeterClient(clientConn)
resp, err := client.Greet(context.Background(), &pdu.GreetRequest{Name: "somethingimadeup"})
if err != nil {
onErr(err, "RPC error")
}
fmt.Printf("got response:\n\t%s\n", resp.GetMsg())
}
const clientIdentityKey = "clientIdentity"
func server() {
authListenerFactory, err := tcp.TCPListenerFactoryFromConfig(nil, &config.TCPServe{
ServeCommon: config.ServeCommon{
Type: "tcp",
},
Listen: "127.0.0.1:8080",
Clients: map[string]string{
"127.0.0.1": "localclient",
"::1": "localclient",
},
})
if err != nil {
onErr(err, "cannot build listener factory")
}
log := logger.NewStderrDebugLogger()
srv, serve, err := grpchelper.NewServer(authListenerFactory, clientIdentityKey, log)
svc := &greeter{"hello "}
pdu.RegisterGreeterServer(srv, svc)
if err := serve(); err != nil {
onErr(err, "error serving")
}
}
type greeter struct {
prepend string
}
func (g *greeter) Greet(ctx context.Context, r *pdu.GreetRequest) (*pdu.GreetResponse, error) {
ci, _ := ctx.Value(clientIdentityKey).(string)
log.WithField("clientIdentity", ci).Info("Greet() request") // show that we got the client identity
return &pdu.GreetResponse{Msg: fmt.Sprintf("%s%s (clientIdentity=%q)", g.prepend, r.GetName(), ci)}, nil
}

View File

@ -0,0 +1,193 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: grpcauth.proto
package pdu
import (
fmt "fmt"
proto "github.com/golang/protobuf/proto"
context "golang.org/x/net/context"
grpc "google.golang.org/grpc"
math "math"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
type GreetRequest struct {
Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *GreetRequest) Reset() { *m = GreetRequest{} }
func (m *GreetRequest) String() string { return proto.CompactTextString(m) }
func (*GreetRequest) ProtoMessage() {}
func (*GreetRequest) Descriptor() ([]byte, []int) {
return fileDescriptor_1dfba7be0cf69353, []int{0}
}
func (m *GreetRequest) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_GreetRequest.Unmarshal(m, b)
}
func (m *GreetRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_GreetRequest.Marshal(b, m, deterministic)
}
func (m *GreetRequest) XXX_Merge(src proto.Message) {
xxx_messageInfo_GreetRequest.Merge(m, src)
}
func (m *GreetRequest) XXX_Size() int {
return xxx_messageInfo_GreetRequest.Size(m)
}
func (m *GreetRequest) XXX_DiscardUnknown() {
xxx_messageInfo_GreetRequest.DiscardUnknown(m)
}
var xxx_messageInfo_GreetRequest proto.InternalMessageInfo
func (m *GreetRequest) GetName() string {
if m != nil {
return m.Name
}
return ""
}
type GreetResponse struct {
Msg string `protobuf:"bytes,1,opt,name=msg,proto3" json:"msg,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *GreetResponse) Reset() { *m = GreetResponse{} }
func (m *GreetResponse) String() string { return proto.CompactTextString(m) }
func (*GreetResponse) ProtoMessage() {}
func (*GreetResponse) Descriptor() ([]byte, []int) {
return fileDescriptor_1dfba7be0cf69353, []int{1}
}
func (m *GreetResponse) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_GreetResponse.Unmarshal(m, b)
}
func (m *GreetResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_GreetResponse.Marshal(b, m, deterministic)
}
func (m *GreetResponse) XXX_Merge(src proto.Message) {
xxx_messageInfo_GreetResponse.Merge(m, src)
}
func (m *GreetResponse) XXX_Size() int {
return xxx_messageInfo_GreetResponse.Size(m)
}
func (m *GreetResponse) XXX_DiscardUnknown() {
xxx_messageInfo_GreetResponse.DiscardUnknown(m)
}
var xxx_messageInfo_GreetResponse proto.InternalMessageInfo
func (m *GreetResponse) GetMsg() string {
if m != nil {
return m.Msg
}
return ""
}
func init() {
proto.RegisterType((*GreetRequest)(nil), "pdu.GreetRequest")
proto.RegisterType((*GreetResponse)(nil), "pdu.GreetResponse")
}
func init() { proto.RegisterFile("grpcauth.proto", fileDescriptor_1dfba7be0cf69353) }
var fileDescriptor_1dfba7be0cf69353 = []byte{
// 137 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x4b, 0x2f, 0x2a, 0x48,
0x4e, 0x2c, 0x2d, 0xc9, 0xd0, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2e, 0x48, 0x29, 0x55,
0x52, 0xe2, 0xe2, 0x71, 0x2f, 0x4a, 0x4d, 0x2d, 0x09, 0x4a, 0x2d, 0x2c, 0x4d, 0x2d, 0x2e, 0x11,
0x12, 0xe2, 0x62, 0xc9, 0x4b, 0xcc, 0x4d, 0x95, 0x60, 0x54, 0x60, 0xd4, 0xe0, 0x0c, 0x02, 0xb3,
0x95, 0x14, 0xb9, 0x78, 0xa1, 0x6a, 0x8a, 0x0b, 0xf2, 0xf3, 0x8a, 0x53, 0x85, 0x04, 0xb8, 0x98,
0x73, 0x8b, 0xd3, 0xa1, 0x6a, 0x40, 0x4c, 0x23, 0x6b, 0x2e, 0x76, 0xb0, 0x92, 0xd4, 0x22, 0x21,
0x03, 0x2e, 0x56, 0x30, 0x53, 0x48, 0x50, 0xaf, 0x20, 0xa5, 0x54, 0x0f, 0xd9, 0x74, 0x29, 0x21,
0x64, 0x21, 0x88, 0x61, 0x4a, 0x0c, 0x49, 0x6c, 0x60, 0xf7, 0x18, 0x03, 0x02, 0x00, 0x00, 0xff,
0xff, 0xa8, 0x53, 0x2f, 0x4c, 0xa1, 0x00, 0x00, 0x00,
}
// Reference imports to suppress errors if they are not otherwise used.
var _ context.Context
var _ grpc.ClientConn
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
const _ = grpc.SupportPackageIsVersion4
// GreeterClient is the client API for Greeter service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.
type GreeterClient interface {
Greet(ctx context.Context, in *GreetRequest, opts ...grpc.CallOption) (*GreetResponse, error)
}
type greeterClient struct {
cc *grpc.ClientConn
}
func NewGreeterClient(cc *grpc.ClientConn) GreeterClient {
return &greeterClient{cc}
}
func (c *greeterClient) Greet(ctx context.Context, in *GreetRequest, opts ...grpc.CallOption) (*GreetResponse, error) {
out := new(GreetResponse)
err := c.cc.Invoke(ctx, "/pdu.Greeter/Greet", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// GreeterServer is the server API for Greeter service.
type GreeterServer interface {
Greet(context.Context, *GreetRequest) (*GreetResponse, error)
}
func RegisterGreeterServer(s *grpc.Server, srv GreeterServer) {
s.RegisterService(&_Greeter_serviceDesc, srv)
}
func _Greeter_Greet_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(GreetRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(GreeterServer).Greet(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/pdu.Greeter/Greet",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(GreeterServer).Greet(ctx, req.(*GreetRequest))
}
return interceptor(ctx, in, info, handler)
}
var _Greeter_serviceDesc = grpc.ServiceDesc{
ServiceName: "pdu.Greeter",
HandlerType: (*GreeterServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "Greet",
Handler: _Greeter_Greet_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "grpcauth.proto",
}

View File

@ -0,0 +1,76 @@
// Package grpchelper wraps the adaptors implemented by package grpcclientidentity into a less flexible API
// which, however, ensures that the individual adaptor primitive's expectations are met and hence do not panic.
package grpchelper
import (
"context"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"github.com/zrepl/zrepl/logger"
"github.com/zrepl/zrepl/rpc/grpcclientidentity"
"github.com/zrepl/zrepl/rpc/netadaptor"
"github.com/zrepl/zrepl/transport"
)
// The following constants are relevant for interoperability.
// We use the same values for client & server, because zrepl is more
// symmetrical ("one source, one sink") instead of the typical
// gRPC scenario ("many clients, single server")
const (
StartKeepalivesAfterInactivityDuration = 5 * time.Second
KeepalivePeerTimeout = 10 * time.Second
)
type Logger = logger.Logger
// ClientConn is an easy-to-use wrapper around the Dialer and TransportCredentials interface
// to produce a grpc.ClientConn
func ClientConn(cn transport.Connecter, log Logger) *grpc.ClientConn {
ka := grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: StartKeepalivesAfterInactivityDuration,
Timeout: KeepalivePeerTimeout,
PermitWithoutStream: true,
})
dialerOption := grpc.WithDialer(grpcclientidentity.NewDialer(log, cn))
cred := grpc.WithTransportCredentials(grpcclientidentity.NewTransportCredentials(log))
ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second)
defer cancel()
cc, err := grpc.DialContext(ctx, "doesn't matter done by dialer", dialerOption, cred, ka)
if err != nil {
log.WithError(err).Error("cannot create gRPC client conn (non-blocking)")
// It's ok to panic here: the we call grpc.DialContext without the
// (grpc.WithBlock) dial option, and at the time of writing, the grpc
// docs state that no connection attempt is made in that case.
// Hence, any error that occurs is due to DialOptions or similar,
// and thus indicative of an implementation error.
panic(err)
}
return cc
}
// NewServer is a convenience interface around the TransportCredentials and Interceptors interface.
func NewServer(authListenerFactory transport.AuthenticatedListenerFactory, clientIdentityKey interface{}, logger grpcclientidentity.Logger) (srv *grpc.Server, serve func() error, err error) {
ka := grpc.KeepaliveParams(keepalive.ServerParameters{
Time: StartKeepalivesAfterInactivityDuration,
Timeout: KeepalivePeerTimeout,
})
tcs := grpcclientidentity.NewTransportCredentials(logger)
unary, stream := grpcclientidentity.NewInterceptors(logger, clientIdentityKey)
srv = grpc.NewServer(grpc.Creds(tcs), grpc.UnaryInterceptor(unary), grpc.StreamInterceptor(stream), ka)
serve = func() error {
l, err := authListenerFactory()
if err != nil {
return err
}
if err := srv.Serve(netadaptor.New(l, logger)); err != nil {
return err
}
return nil
}
return srv, serve, nil
}

View File

@ -0,0 +1,102 @@
// Package netadaptor implements an adaptor from
// transport.AuthenticatedListener to net.Listener.
//
// In contrast to transport.AuthenticatedListener,
// net.Listener is commonly expected (e.g. by net/http.Server.Serve),
// to return errors that fulfill the Temporary interface:
// interface Temporary { Temporary() bool }
// Common behavior of packages consuming net.Listener is to return
// from the serve-loop if an error returned by Accept is not Temporary,
// i.e., does not implement the interface or is !net.Error.Temporary().
//
// The zrepl transport infrastructure was written with the
// idea that Accept() may return any kind of error, and the consumer
// would just log the error and continue calling Accept().
// We have to adapt these listeners' behavior to the expectations
// of net/http.Server.
//
// Hence, Listener does not return an error at all but blocks the
// caller of Accept() until we get a (successfully authenticated)
// connection without errors from the transport.
// Accept errors returned from the transport are logged as errors
// to the logger passed on initialization.
package netadaptor
import (
"context"
"fmt"
"github.com/zrepl/zrepl/logger"
"net"
"github.com/zrepl/zrepl/transport"
)
type Logger = logger.Logger
type acceptReq struct {
callback chan net.Conn
}
type Listener struct {
al transport.AuthenticatedListener
logger Logger
accepts chan acceptReq
stop chan struct{}
}
// Consume the given authListener and wrap it as a *Listener, which implements net.Listener.
// The returned net.Listener must be closed.
// The wrapped authListener is closed when the returned net.Listener is closed.
func New(authListener transport.AuthenticatedListener, l Logger) *Listener {
if l == nil {
l = logger.NewNullLogger()
}
a := &Listener{
al: authListener,
logger: l,
accepts: make(chan acceptReq),
stop: make(chan struct{}),
}
go a.handleAccept()
return a
}
// The returned net.Conn is guaranteed to be *transport.AuthConn, i.e., the type of connection
// returned by the wrapped transport.AuthenticatedListener.
func (a Listener) Accept() (net.Conn, error) {
req := acceptReq{make(chan net.Conn, 1)}
a.accepts <- req
conn := <-req.callback
return conn, nil
}
func (a Listener) handleAccept() {
for {
select {
case <-a.stop:
a.logger.Debug("handleAccept stop accepting")
return
case req := <-a.accepts:
for {
a.logger.Debug("accept authListener")
authConn, err := a.al.Accept(context.Background())
if err != nil {
a.logger.WithError(err).Error("accept error")
continue
}
a.logger.WithField("type", fmt.Sprintf("%T", authConn)).
Debug("accept complete")
req.callback <- authConn
break
}
}
}
}
func (a Listener) Addr() net.Addr {
return a.al.Addr()
}
func (a Listener) Close() error {
close(a.stop)
return a.al.Close()
}

106
rpc/rpc_client.go Normal file
View File

@ -0,0 +1,106 @@
package rpc
import (
"context"
"net"
"time"
"google.golang.org/grpc"
"github.com/zrepl/zrepl/replication"
"github.com/zrepl/zrepl/replication/pdu"
"github.com/zrepl/zrepl/rpc/dataconn"
"github.com/zrepl/zrepl/rpc/grpcclientidentity/grpchelper"
"github.com/zrepl/zrepl/rpc/versionhandshake"
"github.com/zrepl/zrepl/transport"
"github.com/zrepl/zrepl/util/envconst"
"github.com/zrepl/zrepl/zfs"
)
// Client implements the active side of a replication setup.
// It satisfies the Endpoint, Sender and Receiver interface defined by package replication.
type Client struct {
dataClient *dataconn.Client
controlClient pdu.ReplicationClient // this the grpc client instance, see constructor
controlConn *grpc.ClientConn
loggers Loggers
}
var _ replication.Endpoint = &Client{}
var _ replication.Sender = &Client{}
var _ replication.Receiver = &Client{}
type DialContextFunc = func(ctx context.Context, network string, addr string) (net.Conn, error)
// config must be validated, NewClient will panic if it is not valid
func NewClient(cn transport.Connecter, loggers Loggers) *Client {
cn = versionhandshake.Connecter(cn, envconst.Duration("ZREPL_RPC_CLIENT_VERSIONHANDSHAKE_TIMEOUT", 10*time.Second))
muxedConnecter := mux(cn)
c := &Client{
loggers: loggers,
}
grpcConn := grpchelper.ClientConn(muxedConnecter.control, loggers.Control)
go func() {
for {
state := grpcConn.GetState()
loggers.General.WithField("grpc_state", state.String()).Debug("grpc state change")
grpcConn.WaitForStateChange(context.TODO(), state)
}
}()
c.controlClient = pdu.NewReplicationClient(grpcConn)
c.controlConn = grpcConn
c.dataClient = dataconn.NewClient(muxedConnecter.data, loggers.Data)
return c
}
func (c *Client) Close() {
if err := c.controlConn.Close(); err != nil {
c.loggers.General.WithError(err).Error("cannot cloe control connection")
}
// TODO c.dataClient should have Close()
}
// callers must ensure that the returned io.ReadCloser is closed
// TODO expose dataClient interface to the outside world
func (c *Client) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) {
// TODO the returned sendStream may return a read error created by the remote side
res, streamCopier, err := c.dataClient.ReqSend(ctx, r)
if err != nil {
return nil, nil, nil
}
if streamCopier == nil {
return res, nil, nil
}
return res, streamCopier, nil
}
func (c *Client) Receive(ctx context.Context, req *pdu.ReceiveReq, streamCopier zfs.StreamCopier) (*pdu.ReceiveRes, error) {
return c.dataClient.ReqRecv(ctx, req, streamCopier)
}
func (c *Client) ListFilesystems(ctx context.Context, in *pdu.ListFilesystemReq) (*pdu.ListFilesystemRes, error) {
return c.controlClient.ListFilesystems(ctx, in)
}
func (c *Client) ListFilesystemVersions(ctx context.Context, in *pdu.ListFilesystemVersionsReq) (*pdu.ListFilesystemVersionsRes, error) {
return c.controlClient.ListFilesystemVersions(ctx, in)
}
func (c *Client) DestroySnapshots(ctx context.Context, in *pdu.DestroySnapshotsReq) (*pdu.DestroySnapshotsRes, error) {
return c.controlClient.DestroySnapshots(ctx, in)
}
func (c *Client) ReplicationCursor(ctx context.Context, in *pdu.ReplicationCursorReq) (*pdu.ReplicationCursorRes, error) {
return c.controlClient.ReplicationCursor(ctx, in)
}
func (c *Client) ResetConnectBackoff() {
c.controlConn.ResetConnectBackoff()
}

20
rpc/rpc_debug.go Normal file
View File

@ -0,0 +1,20 @@
package rpc
import (
"fmt"
"os"
)
var debugEnabled bool = false
func init() {
if os.Getenv("ZREPL_RPC_DEBUG") != "" {
debugEnabled = true
}
}
func debug(format string, args ...interface{}) {
if debugEnabled {
fmt.Fprintf(os.Stderr, "rpc: %s\n", fmt.Sprintf(format, args...))
}
}

118
rpc/rpc_doc.go Normal file
View File

@ -0,0 +1,118 @@
// Package rpc implements zrepl daemon-to-daemon RPC protocol
// on top of a transport provided by package transport.
// The zrepl documentation refers to the client as the
// `active side` and to the server as the `passive side`.
//
// Design Considerations
//
// zrepl has non-standard requirements to remote procedure calls (RPC):
// whereas the coordination of replication (the planning phase) mostly
// consists of regular unary calls, the actual data transfer requires
// a high-throughput, low-overhead solution.
//
// Specifically, the requirements for data transfer is to perform
// a copy of an io.Reader over the wire, such that an io.EOF of the original
// reader corresponds to an io.EOF on the receiving side.
// If any other error occurs while reading from the original io.Reader
// on the sender, the receiver should receive the contents of that error
// in some form (usually as a trailer message)
// A common implementation technique for above data transfer is chunking,
// for example in HTTP:
// https://tools.ietf.org/html/rfc2616#section-3.6.1
//
// With regards to regular RPCs, gRPC is a popular implementation based
// on protocol buffers and thus code generation.
// gRPC also supports bi-directional streaming RPC, and it is possible
// to implement chunked transfer through the use of streams.
//
// For zrepl however, both HTTP and manually implemented chunked transfer
// using gRPC were found to have significant CPU overhead at transfer
// speeds to be expected even with hobbyist users.
//
// However, it is nice to piggyback on the code generation provided
// by protobuf / gRPC, in particular since the majority of call types
// are regular unary RPCs for which the higher overhead of gRPC is acceptable.
//
// Hence, this package attempts to combine the best of both worlds:
//
// GRPC for Coordination and Dataconn for Bulk Data Transfer
//
// This package's Client uses its transport.Connecter to maintain
// separate control and data connections to the Server.
// The control connection is used by an instance of pdu.ReplicationClient
// whose 'regular' unary RPC calls are re-exported.
// The data connection is used by an instance of dataconn.Client and
// is used for bulk data transfer, namely `Send` and `Receive`.
//
// The following ASCII diagram gives an overview of how the individual
// building blocks are glued together:
//
// +------------+
// | rpc.Client |
// +------------+
// | |
// +--------+ +------------+
// | |
// +---------v-----------+ +--------v------+
// |pdu.ReplicationClient| |dataconn.Client|
// +---------------------+ +--------v------+
// | label: label: |
// | zrepl_control zrepl_data |
// +--------+ +------------+
// | |
// +--v---------v---+
// | transportmux |
// +-------+--------+
// | uses
// +-------v--------+
// |versionhandshake|
// +-------+--------+
// | uses
// +------v------+
// | transport |
// +------+------+
// |
// NETWORK
// |
// +------+------+
// | transport |
// +------^------+
// | uses
// +-------+--------+
// |versionhandshake|
// +-------^--------+
// | uses
// +-------+--------+
// | transportmux |
// +--^--------^----+
// | |
// +--------+ --------------+ ---
// | | |
// | label: label: | |
// | zrepl_control zrepl_data | |
// +-----+----+ +-----------+---+ |
// |netadaptor| |dataconn.Server| | rpc.Server
// | + | +------+--------+ |
// |grpcclient| | |
// |identity | | |
// +-----+----+ | |
// | | |
// +---------v-----------+ | |
// |pdu.ReplicationServer| | |
// +---------+-----------+ | |
// | | ---
// +----------+ +------------+
// | |
// +---v--v-----+
// | Handler |
// +------------+
// (usually endpoint.{Sender,Receiver})
//
//
package rpc
// edit trick for the ASCII art above:
// - remove the comments //
// - vim: set virtualedit+=all
// - vim: set ft=text

34
rpc/rpc_logging.go Normal file
View File

@ -0,0 +1,34 @@
package rpc
import (
"context"
"github.com/zrepl/zrepl/logger"
)
type Logger = logger.Logger
type contextKey int
const (
contextKeyLoggers contextKey = iota
contextKeyGeneralLogger
contextKeyControlLogger
contextKeyDataLogger
)
/// All fields must be non-nil
type Loggers struct {
General Logger
Control Logger
Data Logger
}
func WithLoggers(ctx context.Context, loggers Loggers) context.Context {
ctx = context.WithValue(ctx, contextKeyLoggers, loggers)
return ctx
}
func GetLoggersOrPanic(ctx context.Context) Loggers {
return ctx.Value(contextKeyLoggers).(Loggers)
}

57
rpc/rpc_mux.go Normal file
View File

@ -0,0 +1,57 @@
package rpc
import (
"context"
"time"
"github.com/zrepl/zrepl/transport"
"github.com/zrepl/zrepl/rpc/transportmux"
"github.com/zrepl/zrepl/util/envconst"
)
type demuxedListener struct {
control, data transport.AuthenticatedListener
}
const (
transportmuxLabelControl string = "zrepl_control"
transportmuxLabelData string = "zrepl_data"
)
func demux(serveCtx context.Context, listener transport.AuthenticatedListener) demuxedListener {
listeners, err := transportmux.Demux(
serveCtx, listener,
[]string{transportmuxLabelControl, transportmuxLabelData},
envconst.Duration("ZREPL_TRANSPORT_DEMUX_TIMEOUT", 10*time.Second),
)
if err != nil {
// transportmux API guarantees that the returned error can only be due
// to invalid API usage (i.e. labels too long)
panic(err)
}
return demuxedListener{
control: listeners[transportmuxLabelControl],
data: listeners[transportmuxLabelData],
}
}
type muxedConnecter struct {
control, data transport.Connecter
}
func mux(rawConnecter transport.Connecter) muxedConnecter {
muxedConnecters, err := transportmux.MuxConnecter(
rawConnecter,
[]string{transportmuxLabelControl, transportmuxLabelData},
envconst.Duration("ZREPL_TRANSPORT_MUX_TIMEOUT", 10*time.Second),
)
if err != nil {
// transportmux API guarantees that the returned error can only be due
// to invalid API usage (i.e. labels too long)
panic(err)
}
return muxedConnecter{
control: muxedConnecters[transportmuxLabelControl],
data: muxedConnecters[transportmuxLabelData],
}
}

119
rpc/rpc_server.go Normal file
View File

@ -0,0 +1,119 @@
package rpc
import (
"context"
"time"
"google.golang.org/grpc"
"github.com/zrepl/zrepl/endpoint"
"github.com/zrepl/zrepl/replication/pdu"
"github.com/zrepl/zrepl/rpc/dataconn"
"github.com/zrepl/zrepl/rpc/grpcclientidentity"
"github.com/zrepl/zrepl/rpc/netadaptor"
"github.com/zrepl/zrepl/rpc/versionhandshake"
"github.com/zrepl/zrepl/transport"
"github.com/zrepl/zrepl/util/envconst"
)
type Handler interface {
pdu.ReplicationServer
dataconn.Handler
}
type serveFunc func(ctx context.Context, demuxedListener transport.AuthenticatedListener, errOut chan<- error)
// Server abstracts the accept and request routing infrastructure for the
// passive side of a replication setup.
type Server struct {
logger Logger
handler Handler
controlServer *grpc.Server
controlServerServe serveFunc
dataServer *dataconn.Server
dataServerServe serveFunc
}
type serverContextKey int
type HandlerContextInterceptor func(ctx context.Context) context.Context
// config must be valid (use its Validate function).
func NewServer(handler Handler, loggers Loggers, ctxInterceptor HandlerContextInterceptor) *Server {
// setup control server
tcs := grpcclientidentity.NewTransportCredentials(loggers.Control) // TODO different subsystem for log
unary, stream := grpcclientidentity.NewInterceptors(loggers.Control, endpoint.ClientIdentityKey)
controlServer := grpc.NewServer(grpc.Creds(tcs), grpc.UnaryInterceptor(unary), grpc.StreamInterceptor(stream))
pdu.RegisterReplicationServer(controlServer, handler)
controlServerServe := func(ctx context.Context, controlListener transport.AuthenticatedListener, errOut chan<- error) {
// give time for graceful stop until deadline expires, then hard stop
go func() {
<-ctx.Done()
if dl, ok := ctx.Deadline(); ok {
go time.AfterFunc(dl.Sub(dl), controlServer.Stop)
}
loggers.Control.Debug("shutting down control server")
controlServer.GracefulStop()
}()
errOut <- controlServer.Serve(netadaptor.New(controlListener, loggers.Control))
}
// setup data server
dataServerClientIdentitySetter := func(ctx context.Context, wire *transport.AuthConn) (context.Context, *transport.AuthConn) {
ci := wire.ClientIdentity()
ctx = context.WithValue(ctx, endpoint.ClientIdentityKey, ci)
if ctxInterceptor != nil {
ctx = ctxInterceptor(ctx) // SHADOWING
}
return ctx, wire
}
dataServer := dataconn.NewServer(dataServerClientIdentitySetter, loggers.Data, handler)
dataServerServe := func(ctx context.Context, dataListener transport.AuthenticatedListener, errOut chan<- error) {
dataServer.Serve(ctx, dataListener)
errOut <- nil // TODO bad design of dataServer?
}
server := &Server{
logger: loggers.General,
handler: handler,
controlServer: controlServer,
controlServerServe: controlServerServe,
dataServer: dataServer,
dataServerServe: dataServerServe,
}
return server
}
// The context is used for cancellation only.
// Serve never returns an error, it logs them to the Server's logger.
func (s *Server) Serve(ctx context.Context, l transport.AuthenticatedListener) {
ctx, cancel := context.WithCancel(ctx)
l = versionhandshake.Listener(l, envconst.Duration("ZREPL_RPC_SERVER_VERSIONHANDSHAKE_TIMEOUT", 10*time.Second))
// it is important that demux's context is cancelled,
// it has background goroutines attached
demuxListener := demux(ctx, l)
serveErrors := make(chan error, 2)
go s.controlServerServe(ctx, demuxListener.control, serveErrors)
go s.dataServerServe(ctx, demuxListener.data, serveErrors)
select {
case serveErr := <-serveErrors:
s.logger.WithError(serveErr).Error("serve error")
s.logger.Debug("wait for other server to shut down")
cancel()
secondServeErr := <-serveErrors
s.logger.WithError(secondServeErr).Error("serve error")
case <-ctx.Done():
s.logger.Debug("context cancelled, wait for control and data servers")
cancel()
for i := 0; i < 2; i++ {
<-serveErrors
}
s.logger.Debug("control and data server shut down, returning from Serve")
}
}

View File

@ -0,0 +1,205 @@
// Package transportmux wraps a transport.{Connecter,AuthenticatedListener}
// to distinguish different connection types based on a label
// sent from client to server on connection establishment.
//
// Labels are plain text and fixed length.
package transportmux
import (
"context"
"io"
"net"
"time"
"fmt"
"github.com/zrepl/zrepl/logger"
"github.com/zrepl/zrepl/transport"
)
type contextKey int
const (
contextKeyLog contextKey = 1 + iota
)
type Logger = logger.Logger
func WithLogger(ctx context.Context, log Logger) context.Context {
return context.WithValue(ctx, contextKeyLog, log)
}
func getLog(ctx context.Context) Logger {
if l, ok := ctx.Value(contextKeyLog).(Logger); ok {
return l
}
return logger.NewNullLogger()
}
type acceptRes struct {
conn *transport.AuthConn
err error
}
type demuxListener struct {
conns chan acceptRes
}
func (l *demuxListener) Accept(ctx context.Context) (*transport.AuthConn, error) {
res := <-l.conns
return res.conn, res.err
}
type demuxAddr struct {}
func (demuxAddr) Network() string { return "demux" }
func (demuxAddr) String() string { return "demux" }
func (l *demuxListener) Addr() net.Addr {
return demuxAddr{}
}
func (l *demuxListener) Close() error { return nil } // TODO
// Exact length of a label in bytes (0-byte padded if it is shorter).
// This is a protocol constant, changing it breaks the wire protocol.
const LabelLen = 64
func padLabel(out []byte, label string) (error) {
if len(label) > LabelLen {
return fmt.Errorf("label %q exceeds max length (is %d, max %d)", label, len(label), LabelLen)
}
if len(out) != LabelLen {
panic(fmt.Sprintf("implementation error: %d", out))
}
labelBytes := []byte(label)
copy(out[:], labelBytes)
return nil
}
func Demux(ctx context.Context, rawListener transport.AuthenticatedListener, labels []string, timeout time.Duration) (map[string]transport.AuthenticatedListener, error) {
padded := make(map[[64]byte]*demuxListener, len(labels))
ret := make(map[string]transport.AuthenticatedListener, len(labels))
for _, label := range labels {
var labelPadded [LabelLen]byte
err := padLabel(labelPadded[:], label)
if err != nil {
return nil, err
}
if _, ok := padded[labelPadded]; ok {
return nil, fmt.Errorf("duplicate label %q", label)
}
dl := &demuxListener{make(chan acceptRes)}
padded[labelPadded] = dl
ret[label] = dl
}
// invariant: padded contains same-length, non-duplicate labels
go func() {
<-ctx.Done()
getLog(ctx).Debug("context cancelled, closing listener")
if err := rawListener.Close(); err != nil {
getLog(ctx).WithError(err).Error("error closing listener")
}
}()
go func() {
for {
rawConn, err := rawListener.Accept(ctx)
if err != nil {
if ctx.Err() != nil {
return
}
getLog(ctx).WithError(err).Error("accept error")
continue
}
closeConn := func() {
if err := rawConn.Close(); err != nil {
getLog(ctx).WithError(err).Error("cannot close conn")
}
}
if err := rawConn.SetDeadline(time.Now().Add(timeout)); err != nil {
getLog(ctx).WithError(err).Error("SetDeadline failed")
closeConn()
continue
}
var labelBuf [LabelLen]byte
if _, err := io.ReadFull(rawConn, labelBuf[:]); err != nil {
getLog(ctx).WithError(err).Error("error reading label")
closeConn()
continue
}
demuxListener, ok := padded[labelBuf]
if !ok {
getLog(ctx).WithError(err).
WithField("client_label", fmt.Sprintf("%q", labelBuf)).
Error("unknown client label")
closeConn()
continue
}
rawConn.SetDeadline(time.Time{})
// blocking is intentional
demuxListener.conns <- acceptRes{conn: rawConn, err: nil}
}
}()
return ret, nil
}
type labeledConnecter struct {
label []byte
transport.Connecter
}
func (c labeledConnecter) Connect(ctx context.Context) (transport.Wire, error) {
conn, err := c.Connecter.Connect(ctx)
if err != nil {
return nil, err
}
closeConn := func(why error) {
getLog(ctx).WithField("reason", why.Error()).Debug("closing connection")
if err := conn.Close(); err != nil {
getLog(ctx).WithError(err).Error("error closing connection after label write error")
}
}
if dl, ok := ctx.Deadline(); ok {
defer conn.SetDeadline(time.Time{})
if err := conn.SetDeadline(dl); err != nil {
closeConn(err)
return nil, err
}
}
n, err := conn.Write(c.label)
if err != nil {
closeConn(err)
return nil, err
}
if n != len(c.label) {
closeConn(fmt.Errorf("short label write"))
return nil, io.ErrShortWrite
}
return conn, nil
}
func MuxConnecter(rawConnecter transport.Connecter, labels []string, timeout time.Duration) (map[string]transport.Connecter, error) {
ret := make(map[string]transport.Connecter, len(labels))
for _, label := range labels {
var paddedLabel [LabelLen]byte
if err := padLabel(paddedLabel[:], label); err != nil {
return nil, err
}
lc := &labeledConnecter{paddedLabel[:], rawConnecter}
if _, ok := ret[label]; ok {
return nil, fmt.Errorf("duplicate label %q", label)
}
ret[label] = lc
}
return ret, nil
}

View File

@ -0,0 +1,181 @@
// Package versionhandshake wraps a transport.{Connecter,AuthenticatedListener}
// to add an exchange of protocol version information on connection establishment.
//
// The protocol version information (banner) is plain text, thus making it
// easy to diagnose issues with standard tools.
package versionhandshake
import (
"bytes"
"fmt"
"io"
"net"
"strings"
"time"
"unicode/utf8"
)
type HandshakeMessage struct {
ProtocolVersion int
Extensions []string
}
// A HandshakeError describes what went wrong during the handshake.
// It implements net.Error and is always temporary.
type HandshakeError struct {
msg string
// If not nil, the underlying IO error that caused the handshake to fail.
IOError error
}
var _ net.Error = &HandshakeError{}
func (e HandshakeError) Error() string { return e.msg }
// Always true to enable usage in a net.Listener.
func (e HandshakeError) Temporary() bool { return true }
// If the underlying IOError was net.Error.Timeout(), Timeout() returns that value.
// Otherwise false.
func (e HandshakeError) Timeout() bool {
if neterr, ok := e.IOError.(net.Error); ok {
return neterr.Timeout()
}
return false
}
func hsErr(format string, args... interface{}) *HandshakeError {
return &HandshakeError{msg: fmt.Sprintf(format, args...)}
}
func hsIOErr(err error, format string, args... interface{}) *HandshakeError {
return &HandshakeError{IOError: err, msg: fmt.Sprintf(format, args...)}
}
// MaxProtocolVersion is the maximum allowed protocol version.
// This is a protocol constant, changing it may break the wire format.
const MaxProtocolVersion = 9999
// Only returns *HandshakeError as error.
func (m *HandshakeMessage) Encode() ([]byte, error) {
if m.ProtocolVersion <= 0 || m.ProtocolVersion > MaxProtocolVersion {
return nil, hsErr(fmt.Sprintf("protocol version must be in [1, %d]", MaxProtocolVersion))
}
if len(m.Extensions) >= MaxProtocolVersion {
return nil, hsErr(fmt.Sprintf("protocol only supports [0, %d] extensions", MaxProtocolVersion))
}
// EXTENSIONS is a count of subsequent \n separated lines that contain protocol extensions
var extensions strings.Builder
for i, ext := range m.Extensions {
if strings.ContainsAny(ext, "\n") {
return nil, hsErr("Extension #%d contains forbidden newline character", i)
}
if !utf8.ValidString(ext) {
return nil, hsErr("Extension #%d is not valid UTF-8", i)
}
extensions.WriteString(ext)
extensions.WriteString("\n")
}
withoutLen := fmt.Sprintf("ZREPL_ZFS_REPLICATION PROTOVERSION=%04d EXTENSIONS=%04d\n%s",
m.ProtocolVersion, len(m.Extensions), extensions.String())
withLen := fmt.Sprintf("%010d %s", len(withoutLen), withoutLen)
return []byte(withLen), nil
}
func (m *HandshakeMessage) DecodeReader(r io.Reader, maxLen int) error {
var lenAndSpace [11]byte
if _, err := io.ReadFull(r, lenAndSpace[:]); err != nil {
return hsIOErr(err, "error reading protocol banner length: %s", err)
}
if !utf8.Valid(lenAndSpace[:]) {
return hsErr("invalid start of handshake message: not valid UTF-8")
}
var followLen int
n, err := fmt.Sscanf(string(lenAndSpace[:]), "%010d ", &followLen)
if n != 1 || err != nil {
return hsErr("could not parse handshake message length")
}
if followLen > maxLen {
return hsErr("handshake message length exceeds max length (%d vs %d)",
followLen, maxLen)
}
var buf bytes.Buffer
_, err = io.Copy(&buf, io.LimitReader(r, int64(followLen)))
if err != nil {
return hsIOErr(err, "error reading protocol banner body: %s", err)
}
var (
protoVersion, extensionCount int
)
n, err = fmt.Fscanf(&buf, "ZREPL_ZFS_REPLICATION PROTOVERSION=%04d EXTENSIONS=%4d\n",
&protoVersion, &extensionCount)
if n != 2 || err != nil {
return hsErr("could not parse handshake message: %s", err)
}
if protoVersion < 1 {
return hsErr("invalid protocol version %q", protoVersion)
}
m.ProtocolVersion = protoVersion
if extensionCount < 0 {
return hsErr("invalid extension count %q", extensionCount)
}
if extensionCount == 0 {
if buf.Len() != 0 {
return hsErr("unexpected data trailing after header")
}
m.Extensions = nil
return nil
}
s := buf.String()
if strings.Count(s, "\n") != extensionCount {
return hsErr("inconsistent extension count: found %d, header says %d", len(m.Extensions), extensionCount)
}
exts := strings.Split(s, "\n")
if exts[len(exts)-1] != "" {
return hsErr("unexpected data trailing after last extension newline")
}
m.Extensions = exts[0:len(exts)-1]
return nil
}
func DoHandshakeCurrentVersion(conn net.Conn, deadline time.Time) error {
// current protocol version is hardcoded here
return DoHandshakeVersion(conn, deadline, 1)
}
const HandshakeMessageMaxLen = 16 * 4096
func DoHandshakeVersion(conn net.Conn, deadline time.Time, version int) error {
ours := HandshakeMessage{
ProtocolVersion: version,
Extensions: nil,
}
hsb, err := ours.Encode()
if err != nil {
return hsErr("could not encode protocol banner: %s", err)
}
defer conn.SetDeadline(time.Time{})
conn.SetDeadline(deadline)
_, err = io.Copy(conn, bytes.NewBuffer(hsb))
if err != nil {
return hsErr("could not send protocol banner: %s", err)
}
theirs := HandshakeMessage{}
if err := theirs.DecodeReader(conn, HandshakeMessageMaxLen); err != nil {
return hsErr("could not decode protocol banner: %s", err)
}
if theirs.ProtocolVersion != ours.ProtocolVersion {
return hsErr("protocol versions do not match: ours is %d, theirs is %d",
ours.ProtocolVersion, theirs.ProtocolVersion)
}
// ignore extensions, we don't use them
return nil
}

View File

@ -1,4 +1,4 @@
package transport package versionhandshake
import ( import (
"bytes" "bytes"

View File

@ -0,0 +1,66 @@
package versionhandshake
import (
"context"
"net"
"time"
"github.com/zrepl/zrepl/transport"
)
type HandshakeConnecter struct {
connecter transport.Connecter
timeout time.Duration
}
func (c HandshakeConnecter) Connect(ctx context.Context) (transport.Wire, error) {
conn, err := c.connecter.Connect(ctx)
if err != nil {
return nil, err
}
dl, ok := ctx.Deadline()
if !ok {
dl = time.Now().Add(c.timeout)
}
if err := DoHandshakeCurrentVersion(conn, dl); err != nil {
conn.Close()
return nil, err
}
return conn, nil
}
func Connecter(connecter transport.Connecter, timeout time.Duration) HandshakeConnecter {
return HandshakeConnecter{
connecter: connecter,
timeout: timeout,
}
}
// wrapper type that performs a a protocol version handshake before returning the connection
type HandshakeListener struct {
l transport.AuthenticatedListener
timeout time.Duration
}
func (l HandshakeListener) Addr() (net.Addr) { return l.l.Addr() }
func (l HandshakeListener) Close() error { return l.l.Close() }
func (l HandshakeListener) Accept(ctx context.Context) (*transport.AuthConn, error) {
conn, err := l.l.Accept(ctx)
if err != nil {
return nil, err
}
dl, ok := ctx.Deadline()
if !ok {
dl = time.Now().Add(l.timeout) // shadowing
}
if err := DoHandshakeCurrentVersion(conn, dl); err != nil {
conn.Close()
return nil, err
}
return conn, nil
}
func Listener(l transport.AuthenticatedListener, timeout time.Duration) transport.AuthenticatedListener {
return HandshakeListener{l, timeout}
}

View File

@ -4,8 +4,11 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"errors" "errors"
"fmt"
"io"
"io/ioutil" "io/ioutil"
"net" "net"
"os"
"time" "time"
) )
@ -22,12 +25,13 @@ func ParseCAFile(certfile string) (*x509.CertPool, error) {
} }
type ClientAuthListener struct { type ClientAuthListener struct {
l net.Listener l *net.TCPListener
c *tls.Config
handshakeTimeout time.Duration handshakeTimeout time.Duration
} }
func NewClientAuthListener( func NewClientAuthListener(
l net.Listener, ca *x509.CertPool, serverCert tls.Certificate, l *net.TCPListener, ca *x509.CertPool, serverCert tls.Certificate,
handshakeTimeout time.Duration) *ClientAuthListener { handshakeTimeout time.Duration) *ClientAuthListener {
if ca == nil { if ca == nil {
@ -37,29 +41,35 @@ func NewClientAuthListener(
panic(serverCert) panic(serverCert)
} }
tlsConf := tls.Config{ tlsConf := &tls.Config{
Certificates: []tls.Certificate{serverCert}, Certificates: []tls.Certificate{serverCert},
ClientCAs: ca, ClientCAs: ca,
ClientAuth: tls.RequireAndVerifyClientCert, ClientAuth: tls.RequireAndVerifyClientCert,
PreferServerCipherSuites: true, PreferServerCipherSuites: true,
KeyLogWriter: keylogFromEnv(),
} }
l = tls.NewListener(l, &tlsConf)
return &ClientAuthListener{ return &ClientAuthListener{
l, l,
tlsConf,
handshakeTimeout, handshakeTimeout,
} }
} }
func (l *ClientAuthListener) Accept() (c net.Conn, clientCN string, err error) { // Accept() accepts a connection from the *net.TCPListener passed to the constructor
c, err = l.l.Accept() // and sets up the TLS connection, including handshake and peer CommmonName validation
// within the specified handshakeTimeout.
//
// It returns both the raw TCP connection (tcpConn) and the TLS connection (tlsConn) on top of it.
// Access to the raw tcpConn might be necessary if CloseWrite semantics are desired:
// tlsConn.CloseWrite does NOT call tcpConn.CloseWrite, hence we provide access to tcpConn to
// allow the caller to do this by themselves.
func (l *ClientAuthListener) Accept() (tcpConn *net.TCPConn, tlsConn *tls.Conn, clientCN string, err error) {
tcpConn, err = l.l.AcceptTCP()
if err != nil { if err != nil {
return nil, "", err return nil, nil, "", err
}
tlsConn, ok := c.(*tls.Conn)
if !ok {
return c, "", err
} }
tlsConn = tls.Server(tcpConn, l.c)
var ( var (
cn string cn string
peerCerts []*x509.Certificate peerCerts []*x509.Certificate
@ -70,6 +80,7 @@ func (l *ClientAuthListener) Accept() (c net.Conn, clientCN string, err error) {
if err = tlsConn.Handshake(); err != nil { if err = tlsConn.Handshake(); err != nil {
goto CloseAndErr goto CloseAndErr
} }
tlsConn.SetDeadline(time.Time{})
peerCerts = tlsConn.ConnectionState().PeerCertificates peerCerts = tlsConn.ConnectionState().PeerCertificates
if len(peerCerts) < 1 { if len(peerCerts) < 1 {
@ -77,10 +88,11 @@ func (l *ClientAuthListener) Accept() (c net.Conn, clientCN string, err error) {
goto CloseAndErr goto CloseAndErr
} }
cn = peerCerts[0].Subject.CommonName cn = peerCerts[0].Subject.CommonName
return c, cn, nil return tcpConn, tlsConn, cn, nil
CloseAndErr: CloseAndErr:
c.Close() // unlike CloseWrite, Close on *tls.Conn actually closes the underlying connection
return nil, "", err tlsConn.Close() // TODO log error
return nil, nil, "", err
} }
func (l *ClientAuthListener) Addr() net.Addr { func (l *ClientAuthListener) Addr() net.Addr {
@ -105,7 +117,21 @@ func ClientAuthClient(serverName string, rootCA *x509.CertPool, clientCert tls.C
Certificates: []tls.Certificate{clientCert}, Certificates: []tls.Certificate{clientCert},
RootCAs: rootCA, RootCAs: rootCA,
ServerName: serverName, ServerName: serverName,
KeyLogWriter: keylogFromEnv(),
} }
tlsConfig.BuildNameToCertificate() tlsConfig.BuildNameToCertificate()
return tlsConfig, nil return tlsConfig, nil
} }
func keylogFromEnv() io.Writer {
var keyLog io.Writer = nil
if outfile := os.Getenv("ZREPL_KEYLOG_FILE"); outfile != "" {
fmt.Fprintf(os.Stderr, "writing to key log %s\n", outfile)
var err error
keyLog, err = os.OpenFile(outfile, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
if err != nil {
panic(err)
}
}
return keyLog
}

View File

@ -0,0 +1,58 @@
// Package fromconfig instantiates transports based on zrepl config structures
// (see package config).
package fromconfig
import (
"fmt"
"github.com/pkg/errors"
"github.com/zrepl/zrepl/config"
"github.com/zrepl/zrepl/transport"
"github.com/zrepl/zrepl/transport/local"
"github.com/zrepl/zrepl/transport/ssh"
"github.com/zrepl/zrepl/transport/tcp"
"github.com/zrepl/zrepl/transport/tls"
)
func ListenerFactoryFromConfig(g *config.Global, in config.ServeEnum) (transport.AuthenticatedListenerFactory,error) {
var (
l transport.AuthenticatedListenerFactory
err error
)
switch v := in.Ret.(type) {
case *config.TCPServe:
l, err = tcp.TCPListenerFactoryFromConfig(g, v)
case *config.TLSServe:
l, err = tls.TLSListenerFactoryFromConfig(g, v)
case *config.StdinserverServer:
l, err = ssh.MultiStdinserverListenerFactoryFromConfig(g, v)
case *config.LocalServe:
l, err = local.LocalListenerFactoryFromConfig(g, v)
default:
return nil, errors.Errorf("internal error: unknown serve type %T", v)
}
return l, err
}
func ConnecterFromConfig(g *config.Global, in config.ConnectEnum) (transport.Connecter, error) {
var (
connecter transport.Connecter
err error
)
switch v := in.Ret.(type) {
case *config.SSHStdinserverConnect:
connecter, err = ssh.SSHStdinserverConnecterFromConfig(v)
case *config.TCPConnect:
connecter, err = tcp.TCPConnecterFromConfig(v)
case *config.TLSConnect:
connecter, err = tls.TLSConnecterFromConfig(v)
case *config.LocalConnect:
connecter, err = local.LocalConnecterFromConfig(v)
default:
panic(fmt.Sprintf("implementation error: unknown connecter type %T", v))
}
return connecter, err
}

View File

@ -1,11 +1,10 @@
package connecter package local
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/zrepl/zrepl/config" "github.com/zrepl/zrepl/config"
"github.com/zrepl/zrepl/daemon/transport/serve" "github.com/zrepl/zrepl/transport"
"net"
) )
type LocalConnecter struct { type LocalConnecter struct {
@ -23,8 +22,8 @@ func LocalConnecterFromConfig(in *config.LocalConnect) (*LocalConnecter, error)
return &LocalConnecter{listenerName: in.ListenerName, clientIdentity: in.ClientIdentity}, nil return &LocalConnecter{listenerName: in.ListenerName, clientIdentity: in.ClientIdentity}, nil
} }
func (c *LocalConnecter) Connect(dialCtx context.Context) (conn net.Conn, err error) { func (c *LocalConnecter) Connect(dialCtx context.Context) (transport.Wire, error) {
l := serve.GetLocalListener(c.listenerName) l := GetLocalListener(c.listenerName)
return l.Connect(dialCtx, c.clientIdentity) return l.Connect(dialCtx, c.clientIdentity)
} }

View File

@ -1,4 +1,4 @@
package serve package local
import ( import (
"context" "context"
@ -7,6 +7,7 @@ import (
"github.com/zrepl/zrepl/util/socketpair" "github.com/zrepl/zrepl/util/socketpair"
"net" "net"
"sync" "sync"
"github.com/zrepl/zrepl/transport"
) )
var localListeners struct { var localListeners struct {
@ -39,7 +40,7 @@ type connectRequest struct {
} }
type connectResult struct { type connectResult struct {
conn net.Conn conn transport.Wire
err error err error
} }
@ -54,7 +55,7 @@ func newLocalListener() *LocalListener {
} }
// Connect to the LocalListener from a client with identity clientIdentity // Connect to the LocalListener from a client with identity clientIdentity
func (l *LocalListener) Connect(dialCtx context.Context, clientIdentity string) (conn net.Conn, err error) { func (l *LocalListener) Connect(dialCtx context.Context, clientIdentity string) (conn transport.Wire, err error) {
// place request // place request
req := connectRequest{ req := connectRequest{
@ -89,21 +90,14 @@ func (a localAddr) String() string { return a.S }
func (l *LocalListener) Addr() (net.Addr) { return localAddr{"<listening>"} } func (l *LocalListener) Addr() (net.Addr) { return localAddr{"<listening>"} }
type localConn struct { func (l *LocalListener) Accept(ctx context.Context) (*transport.AuthConn, error) {
net.Conn
clientIdentity string
}
func (l localConn) ClientIdentity() string { return l.clientIdentity }
func (l *LocalListener) Accept(ctx context.Context) (AuthenticatedConn, error) {
respondToRequest := func(req connectRequest, res connectResult) (err error) { respondToRequest := func(req connectRequest, res connectResult) (err error) {
getLogger(ctx). transport.GetLogger(ctx).
WithField("res.conn", res.conn).WithField("res.err", res.err). WithField("res.conn", res.conn).WithField("res.err", res.err).
Debug("responding to client request") Debug("responding to client request")
defer func() { defer func() {
errv := recover() errv := recover()
getLogger(ctx).WithField("recover_err", errv). transport.GetLogger(ctx).WithField("recover_err", errv).
Debug("panic on send to client callback, likely a legitimate client-side timeout") Debug("panic on send to client callback, likely a legitimate client-side timeout")
}() }()
select { select {
@ -116,7 +110,7 @@ func (l *LocalListener) Accept(ctx context.Context) (AuthenticatedConn, error) {
return err return err
} }
getLogger(ctx).Debug("waiting for local client connect requests") transport.GetLogger(ctx).Debug("waiting for local client connect requests")
var req connectRequest var req connectRequest
select { select {
case req = <-l.connects: case req = <-l.connects:
@ -124,7 +118,7 @@ func (l *LocalListener) Accept(ctx context.Context) (AuthenticatedConn, error) {
return nil, ctx.Err() return nil, ctx.Err()
} }
getLogger(ctx).WithField("client_identity", req.clientIdentity).Debug("got connect request") transport.GetLogger(ctx).WithField("client_identity", req.clientIdentity).Debug("got connect request")
if req.clientIdentity == "" { if req.clientIdentity == "" {
res := connectResult{nil, fmt.Errorf("client identity must not be empty")} res := connectResult{nil, fmt.Errorf("client identity must not be empty")}
if err := respondToRequest(req, res); err != nil { if err := respondToRequest(req, res); err != nil {
@ -133,31 +127,31 @@ func (l *LocalListener) Accept(ctx context.Context) (AuthenticatedConn, error) {
return nil, fmt.Errorf("client connected with empty client identity") return nil, fmt.Errorf("client connected with empty client identity")
} }
getLogger(ctx).Debug("creating socketpair") transport.GetLogger(ctx).Debug("creating socketpair")
left, right, err := socketpair.SocketPair() left, right, err := socketpair.SocketPair()
if err != nil { if err != nil {
res := connectResult{nil, fmt.Errorf("server error: %s", err)} res := connectResult{nil, fmt.Errorf("server error: %s", err)}
if respErr := respondToRequest(req, res); respErr != nil { if respErr := respondToRequest(req, res); respErr != nil {
// returning the socketpair error properly is more important than the error sent to the client // returning the socketpair error properly is more important than the error sent to the client
getLogger(ctx).WithError(respErr).Error("error responding to client") transport.GetLogger(ctx).WithError(respErr).Error("error responding to client")
} }
return nil, err return nil, err
} }
getLogger(ctx).Debug("responding with left side of socketpair") transport.GetLogger(ctx).Debug("responding with left side of socketpair")
res := connectResult{left, nil} res := connectResult{left, nil}
if err := respondToRequest(req, res); err != nil { if err := respondToRequest(req, res); err != nil {
getLogger(ctx).WithError(err).Error("error responding to client") transport.GetLogger(ctx).WithError(err).Error("error responding to client")
if err := left.Close(); err != nil { if err := left.Close(); err != nil {
getLogger(ctx).WithError(err).Error("cannot close left side of socketpair") transport.GetLogger(ctx).WithError(err).Error("cannot close left side of socketpair")
} }
if err := right.Close(); err != nil { if err := right.Close(); err != nil {
getLogger(ctx).WithError(err).Error("cannot close right side of socketpair") transport.GetLogger(ctx).WithError(err).Error("cannot close right side of socketpair")
} }
return nil, err return nil, err
} }
return localConn{right, req.clientIdentity}, nil return transport.NewAuthConn(right, req.clientIdentity), nil
} }
func (l *LocalListener) Close() error { func (l *LocalListener) Close() error {
@ -169,19 +163,13 @@ func (l *LocalListener) Close() error {
return nil return nil
} }
type LocalListenerFactory struct { func LocalListenerFactoryFromConfig(g *config.Global, in *config.LocalServe) (transport.AuthenticatedListenerFactory,error) {
listenerName string
}
func LocalListenerFactoryFromConfig(g *config.Global, in *config.LocalServe) (f *LocalListenerFactory, err error) {
if in.ListenerName == "" { if in.ListenerName == "" {
return nil, fmt.Errorf("ListenerName must not be empty") return nil, fmt.Errorf("ListenerName must not be empty")
} }
return &LocalListenerFactory{listenerName: in.ListenerName}, nil listenerName := in.ListenerName
lf := func() (transport.AuthenticatedListener,error) {
return GetLocalListener(listenerName), nil
} }
return lf, nil
func (lf *LocalListenerFactory) Listen() (AuthenticatedListener, error) {
return GetLocalListener(lf.listenerName), nil
} }

View File

@ -1,13 +1,12 @@
package connecter package ssh
import ( import (
"context" "context"
"github.com/jinzhu/copier" "github.com/jinzhu/copier"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/problame/go-netssh" "github.com/problame/go-netssh"
"github.com/problame/go-streamrpc"
"github.com/zrepl/zrepl/config" "github.com/zrepl/zrepl/config"
"net" "github.com/zrepl/zrepl/transport"
"time" "time"
) )
@ -22,8 +21,6 @@ type SSHStdinserverConnecter struct {
dialTimeout time.Duration dialTimeout time.Duration
} }
var _ streamrpc.Connecter = &SSHStdinserverConnecter{}
func SSHStdinserverConnecterFromConfig(in *config.SSHStdinserverConnect) (c *SSHStdinserverConnecter, err error) { func SSHStdinserverConnecterFromConfig(in *config.SSHStdinserverConnect) (c *SSHStdinserverConnecter, err error) {
c = &SSHStdinserverConnecter{ c = &SSHStdinserverConnecter{
@ -39,15 +36,7 @@ func SSHStdinserverConnecterFromConfig(in *config.SSHStdinserverConnect) (c *SSH
} }
type netsshConnToConn struct{ *netssh.SSHConn } func (c *SSHStdinserverConnecter) Connect(dialCtx context.Context) (transport.Wire, error) {
var _ net.Conn = netsshConnToConn{}
func (netsshConnToConn) SetDeadline(dl time.Time) error { return nil }
func (netsshConnToConn) SetReadDeadline(dl time.Time) error { return nil }
func (netsshConnToConn) SetWriteDeadline(dl time.Time) error { return nil }
func (c *SSHStdinserverConnecter) Connect(dialCtx context.Context) (net.Conn, error) {
var endpoint netssh.Endpoint var endpoint netssh.Endpoint
if err := copier.Copy(&endpoint, c); err != nil { if err := copier.Copy(&endpoint, c); err != nil {
@ -62,5 +51,5 @@ func (c *SSHStdinserverConnecter) Connect(dialCtx context.Context) (net.Conn, er
} }
return nil, err return nil, err
} }
return netsshConnToConn{nconn}, nil return nconn, nil
} }

View File

@ -1,50 +1,38 @@
package serve package ssh
import ( import (
"github.com/problame/go-netssh" "github.com/problame/go-netssh"
"github.com/zrepl/zrepl/config" "github.com/zrepl/zrepl/config"
"github.com/zrepl/zrepl/daemon/nethelpers" "github.com/zrepl/zrepl/daemon/nethelpers"
"io" "github.com/zrepl/zrepl/transport"
"fmt"
"net" "net"
"path" "path"
"time"
"context" "context"
"github.com/pkg/errors" "github.com/pkg/errors"
"sync/atomic" "sync/atomic"
) )
type StdinserverListenerFactory struct { func MultiStdinserverListenerFactoryFromConfig(g *config.Global, in *config.StdinserverServer) (transport.AuthenticatedListenerFactory,error) {
ClientIdentities []string
Sockdir string
}
func MultiStdinserverListenerFactoryFromConfig(g *config.Global, in *config.StdinserverServer) (f *multiStdinserverListenerFactory, err error) {
for _, ci := range in.ClientIdentities { for _, ci := range in.ClientIdentities {
if err := ValidateClientIdentity(ci); err != nil { if err := transport.ValidateClientIdentity(ci); err != nil {
return nil, errors.Wrapf(err, "invalid client identity %q", ci) return nil, errors.Wrapf(err, "invalid client identity %q", ci)
} }
} }
f = &multiStdinserverListenerFactory{ clientIdentities := in.ClientIdentities
ClientIdentities: in.ClientIdentities, sockdir := g.Serve.StdinServer.SockDir
Sockdir: g.Serve.StdinServer.SockDir,
lf := func() (transport.AuthenticatedListener,error) {
return multiStdinserverListenerFromClientIdentities(sockdir, clientIdentities)
} }
return return lf, nil
}
type multiStdinserverListenerFactory struct {
ClientIdentities []string
Sockdir string
}
func (f *multiStdinserverListenerFactory) Listen() (AuthenticatedListener, error) {
return multiStdinserverListenerFromClientIdentities(f.Sockdir, f.ClientIdentities)
} }
type multiStdinserverAcceptRes struct { type multiStdinserverAcceptRes struct {
conn AuthenticatedConn conn *transport.AuthConn
err error err error
} }
@ -78,7 +66,7 @@ func multiStdinserverListenerFromClientIdentities(sockdir string, cis []string)
return &MultiStdinserverListener{listeners: listeners}, nil return &MultiStdinserverListener{listeners: listeners}, nil
} }
func (m *MultiStdinserverListener) Accept(ctx context.Context) (AuthenticatedConn, error){ func (m *MultiStdinserverListener) Accept(ctx context.Context) (*transport.AuthConn, error){
if m.accepts == nil { if m.accepts == nil {
m.accepts = make(chan multiStdinserverAcceptRes, len(m.listeners)) m.accepts = make(chan multiStdinserverAcceptRes, len(m.listeners))
@ -97,8 +85,22 @@ func (m *MultiStdinserverListener) Accept(ctx context.Context) (AuthenticatedCon
} }
func (m *MultiStdinserverListener) Addr() (net.Addr) { type multiListenerAddr struct {
return netsshAddr{} clients []string
}
func (multiListenerAddr) Network() string { return "netssh" }
func (l multiListenerAddr) String() string {
return fmt.Sprintf("netssh:clients=%v", l.clients)
}
func (m *MultiStdinserverListener) Addr() net.Addr {
cis := make([]string, len(m.listeners))
for i := range cis {
cis[i] = m.listeners[i].clientIdentity
}
return multiListenerAddr{cis}
} }
func (m *MultiStdinserverListener) Close() error { func (m *MultiStdinserverListener) Close() error {
@ -118,41 +120,28 @@ type stdinserverListener struct {
clientIdentity string clientIdentity string
} }
func (l stdinserverListener) Addr() net.Addr { type listenerAddr struct {
return netsshAddr{} clientIdentity string
} }
func (l stdinserverListener) Accept(ctx context.Context) (AuthenticatedConn, error) { func (listenerAddr) Network() string { return "netssh" }
func (a listenerAddr) String() string {
return fmt.Sprintf("netssh:client=%q", a.clientIdentity)
}
func (l stdinserverListener) Addr() net.Addr {
return listenerAddr{l.clientIdentity}
}
func (l stdinserverListener) Accept(ctx context.Context) (*transport.AuthConn, error) {
c, err := l.l.Accept() c, err := l.l.Accept()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return netsshConnToNetConnAdatper{c, l.clientIdentity}, nil return transport.NewAuthConn(c, l.clientIdentity), nil
} }
func (l stdinserverListener) Close() (err error) { func (l stdinserverListener) Close() (err error) {
return l.l.Close() return l.l.Close()
} }
type netsshAddr struct{}
func (netsshAddr) Network() string { return "netssh" }
func (netsshAddr) String() string { return "???" }
type netsshConnToNetConnAdatper struct {
io.ReadWriteCloser // works for both netssh.SSHConn and netssh.ServeConn
clientIdentity string
}
func (a netsshConnToNetConnAdatper) ClientIdentity() string { return a.clientIdentity }
func (netsshConnToNetConnAdatper) LocalAddr() net.Addr { return netsshAddr{} }
func (netsshConnToNetConnAdatper) RemoteAddr() net.Addr { return netsshAddr{} }
// FIXME log warning once!
func (netsshConnToNetConnAdatper) SetDeadline(t time.Time) error { return nil }
func (netsshConnToNetConnAdatper) SetReadDeadline(t time.Time) error { return nil }
func (netsshConnToNetConnAdatper) SetWriteDeadline(t time.Time) error { return nil }

View File

@ -1,9 +1,11 @@
package connecter package tcp
import ( import (
"context" "context"
"github.com/zrepl/zrepl/config"
"net" "net"
"github.com/zrepl/zrepl/config"
"github.com/zrepl/zrepl/transport"
) )
type TCPConnecter struct { type TCPConnecter struct {
@ -19,6 +21,10 @@ func TCPConnecterFromConfig(in *config.TCPConnect) (*TCPConnecter, error) {
return &TCPConnecter{in.Address, dialer}, nil return &TCPConnecter{in.Address, dialer}, nil
} }
func (c *TCPConnecter) Connect(dialCtx context.Context) (conn net.Conn, err error) { func (c *TCPConnecter) Connect(dialCtx context.Context) (transport.Wire, error) {
return c.dialer.DialContext(dialCtx, "tcp", c.Address) conn, err := c.dialer.DialContext(dialCtx, "tcp", c.Address)
if err != nil {
return nil, err
}
return conn.(*net.TCPConn), nil
} }

View File

@ -1,17 +1,13 @@
package serve package tcp
import ( import (
"github.com/zrepl/zrepl/config" "github.com/zrepl/zrepl/config"
"net" "net"
"github.com/pkg/errors" "github.com/pkg/errors"
"context" "context"
"github.com/zrepl/zrepl/transport"
) )
type TCPListenerFactory struct {
address *net.TCPAddr
clientMap *ipMap
}
type ipMapEntry struct { type ipMapEntry struct {
ip net.IP ip net.IP
ident string ident string
@ -28,7 +24,7 @@ func ipMapFromConfig(clients map[string]string) (*ipMap, error) {
if clientIP == nil { if clientIP == nil {
return nil, errors.Errorf("cannot parse client IP %q", clientIPString) return nil, errors.Errorf("cannot parse client IP %q", clientIPString)
} }
if err := ValidateClientIdentity(clientIdent); err != nil { if err := transport.ValidateClientIdentity(clientIdent); err != nil {
return nil, errors.Wrapf(err,"invalid client identity for IP %q", clientIPString) return nil, errors.Wrapf(err,"invalid client identity for IP %q", clientIPString)
} }
entries = append(entries, ipMapEntry{clientIP, clientIdent}) entries = append(entries, ipMapEntry{clientIP, clientIdent})
@ -45,7 +41,7 @@ func (m *ipMap) Get(ip net.IP) (string, error) {
return "", errors.Errorf("no identity mapping for client IP %s", ip) return "", errors.Errorf("no identity mapping for client IP %s", ip)
} }
func TCPListenerFactoryFromConfig(c *config.Global, in *config.TCPServe) (*TCPListenerFactory, error) { func TCPListenerFactoryFromConfig(c *config.Global, in *config.TCPServe) (transport.AuthenticatedListenerFactory, error) {
addr, err := net.ResolveTCPAddr("tcp", in.Listen) addr, err := net.ResolveTCPAddr("tcp", in.Listen)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "cannot parse listen address") return nil, errors.Wrap(err, "cannot parse listen address")
@ -54,19 +50,14 @@ func TCPListenerFactoryFromConfig(c *config.Global, in *config.TCPServe) (*TCPLi
if err != nil { if err != nil {
return nil, errors.Wrap(err, "cannot parse client IP map") return nil, errors.Wrap(err, "cannot parse client IP map")
} }
lf := &TCPListenerFactory{ lf := func() (transport.AuthenticatedListener, error) {
address: addr, l, err := net.ListenTCP("tcp", addr)
clientMap: clientMap,
}
return lf, nil
}
func (f *TCPListenerFactory) Listen() (AuthenticatedListener, error) {
l, err := net.ListenTCP("tcp", f.address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &TCPAuthListener{l, f.clientMap}, nil return &TCPAuthListener{l, clientMap}, nil
}
return lf, nil
} }
type TCPAuthListener struct { type TCPAuthListener struct {
@ -74,18 +65,18 @@ type TCPAuthListener struct {
clientMap *ipMap clientMap *ipMap
} }
func (f *TCPAuthListener) Accept(ctx context.Context) (AuthenticatedConn, error) { func (f *TCPAuthListener) Accept(ctx context.Context) (*transport.AuthConn, error) {
nc, err := f.TCPListener.Accept() nc, err := f.TCPListener.AcceptTCP()
if err != nil { if err != nil {
return nil, err return nil, err
} }
clientIP := nc.RemoteAddr().(*net.TCPAddr).IP clientIP := nc.RemoteAddr().(*net.TCPAddr).IP
clientIdent, err := f.clientMap.Get(clientIP) clientIdent, err := f.clientMap.Get(clientIP)
if err != nil { if err != nil {
getLogger(ctx).WithField("ip", clientIP).Error("client IP not in client map") transport.GetLogger(ctx).WithField("ip", clientIP).Error("client IP not in client map")
nc.Close() nc.Close()
return nil, err return nil, err
} }
return authConn{nc, clientIdent}, nil return transport.NewAuthConn(nc, clientIdent), nil
} }

View File

@ -1,12 +1,14 @@
package connecter package tls
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"net"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/zrepl/zrepl/config" "github.com/zrepl/zrepl/config"
"github.com/zrepl/zrepl/tlsconf" "github.com/zrepl/zrepl/tlsconf"
"net" "github.com/zrepl/zrepl/transport"
) )
type TLSConnecter struct { type TLSConnecter struct {
@ -38,10 +40,12 @@ func TLSConnecterFromConfig(in *config.TLSConnect) (*TLSConnecter, error) {
return &TLSConnecter{in.Address, dialer, tlsConfig}, nil return &TLSConnecter{in.Address, dialer, tlsConfig}, nil
} }
func (c *TLSConnecter) Connect(dialCtx context.Context) (conn net.Conn, err error) { func (c *TLSConnecter) Connect(dialCtx context.Context) (transport.Wire, error) {
conn, err = c.dialer.DialContext(dialCtx, "tcp", c.Address) conn, err := c.dialer.DialContext(dialCtx, "tcp", c.Address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return tls.Client(conn, c.tlsConfig), nil tcpConn := conn.(*net.TCPConn)
tlsConn := tls.Client(conn, c.tlsConfig)
return newWireAdaptor(tlsConn, tcpConn), nil
} }

View File

@ -0,0 +1,89 @@
package tls
import (
"crypto/tls"
"crypto/x509"
"fmt"
"github.com/pkg/errors"
"github.com/zrepl/zrepl/config"
"github.com/zrepl/zrepl/transport"
"github.com/zrepl/zrepl/tlsconf"
"net"
"time"
"context"
)
type TLSListenerFactory struct {
address string
clientCA *x509.CertPool
serverCert tls.Certificate
handshakeTimeout time.Duration
clientCNs map[string]struct{}
}
func TLSListenerFactoryFromConfig(c *config.Global, in *config.TLSServe) (transport.AuthenticatedListenerFactory,error) {
address := in.Listen
handshakeTimeout := in.HandshakeTimeout
if in.Ca == "" || in.Cert == "" || in.Key == "" {
return nil, errors.New("fields 'ca', 'cert' and 'key'must be specified")
}
clientCA, err := tlsconf.ParseCAFile(in.Ca)
if err != nil {
return nil, errors.Wrap(err, "cannot parse ca file")
}
serverCert, err := tls.LoadX509KeyPair(in.Cert, in.Key)
if err != nil {
return nil, errors.Wrap(err, "cannot parse cer/key pair")
}
clientCNs := make(map[string]struct{}, len(in.ClientCNs))
for i, cn := range in.ClientCNs {
if err := transport.ValidateClientIdentity(cn); err != nil {
return nil, errors.Wrapf(err, "unsuitable client_cn #%d %q", i, cn)
}
// dupes are ok fr now
clientCNs[cn] = struct{}{}
}
lf := func() (transport.AuthenticatedListener, error) {
l, err := net.Listen("tcp", address)
if err != nil {
return nil, err
}
tcpL := l.(*net.TCPListener)
tl := tlsconf.NewClientAuthListener(tcpL, clientCA, serverCert, handshakeTimeout)
return &tlsAuthListener{tl, clientCNs}, nil
}
return lf, nil
}
type tlsAuthListener struct {
*tlsconf.ClientAuthListener
clientCNs map[string]struct{}
}
func (l tlsAuthListener) Accept(ctx context.Context) (*transport.AuthConn, error) {
tcpConn, tlsConn, cn, err := l.ClientAuthListener.Accept()
if err != nil {
return nil, err
}
if _, ok := l.clientCNs[cn]; !ok {
if dl, ok := ctx.Deadline(); ok {
defer tlsConn.SetDeadline(time.Time{})
tlsConn.SetDeadline(dl)
}
if err := tlsConn.Close(); err != nil {
transport.GetLogger(ctx).WithError(err).Error("error closing connection with unauthorized common name")
}
return nil, fmt.Errorf("unauthorized client common name %q from %s", cn, tlsConn.RemoteAddr())
}
adaptor := newWireAdaptor(tlsConn, tcpConn)
return transport.NewAuthConn(adaptor, cn), nil
}

View File

@ -0,0 +1,47 @@
package tls
import (
"crypto/tls"
"fmt"
"net"
"os"
)
// adapts a tls.Conn and its underlying net.TCPConn into a valid transport.Wire
type transportWireAdaptor struct {
*tls.Conn
tcpConn *net.TCPConn
}
func newWireAdaptor(tlsConn *tls.Conn, tcpConn *net.TCPConn) transportWireAdaptor {
return transportWireAdaptor{tlsConn, tcpConn}
}
// CloseWrite implements transport.Wire.CloseWrite which is different from *tls.Conn.CloseWrite:
// the former requires that the other side observes io.EOF, but *tls.Conn.CloseWrite does not
// close the underlying connection so no io.EOF would be observed.
func (w transportWireAdaptor) CloseWrite() error {
if err := w.Conn.CloseWrite(); err != nil {
// TODO log error
fmt.Fprintf(os.Stderr, "transport/tls.CloseWrite() error: %s\n", err)
}
return w.tcpConn.CloseWrite()
}
// Close implements transport.Wire.Close which is different from a *tls.Conn.Close:
// At the time of writing (Go 1.11), closing tls.Conn closes the TCP connection immediately,
// which results in io.ErrUnexpectedEOF on the other side.
// We assume that w.Conn has a deadline set for the close, so the CloseWrite will time out if it blocks,
// falling through to the actual Close()
func (w transportWireAdaptor) Close() error {
// var buf [1<<15]byte
// w.Conn.Write(buf[:])
// CloseWrite will send a TLS alert record down the line which
// in the Go implementation acts like a flush...?
// if err := w.Conn.CloseWrite(); err != nil {
// // TODO log error
// fmt.Fprintf(os.Stderr, "transport/tls.Close() close write error: %s\n", err)
// }
// time.Sleep(1 * time.Second)
return w.Conn.Close()
}

84
transport/transport.go Normal file
View File

@ -0,0 +1,84 @@
// Package transport defines a common interface for
// network connections that have an associated client identity.
package transport
import (
"context"
"errors"
"net"
"syscall"
"github.com/zrepl/zrepl/logger"
"github.com/zrepl/zrepl/rpc/dataconn/timeoutconn"
"github.com/zrepl/zrepl/zfs"
)
type AuthConn struct {
Wire
clientIdentity string
}
var _ timeoutconn.SyscallConner = AuthConn{}
func (a AuthConn) SyscallConn() (rawConn syscall.RawConn, err error) {
scc, ok := a.Wire.(timeoutconn.SyscallConner)
if !ok {
return nil, timeoutconn.SyscallConnNotSupported
}
return scc.SyscallConn()
}
func NewAuthConn(conn Wire, clientIdentity string) *AuthConn {
return &AuthConn{conn, clientIdentity}
}
func (c *AuthConn) ClientIdentity() string {
if err := ValidateClientIdentity(c.clientIdentity); err != nil {
panic(err)
}
return c.clientIdentity
}
// like net.Listener, but with an AuthenticatedConn instead of net.Conn
type AuthenticatedListener interface {
Addr() net.Addr
Accept(ctx context.Context) (*AuthConn, error)
Close() error
}
type AuthenticatedListenerFactory func() (AuthenticatedListener, error)
type Wire = timeoutconn.Wire
type Connecter interface {
Connect(ctx context.Context) (Wire, error)
}
// A client identity must be a single component in a ZFS filesystem path
func ValidateClientIdentity(in string) (err error) {
path, err := zfs.NewDatasetPath(in)
if err != nil {
return err
}
if path.Length() != 1 {
return errors.New("client identity must be a single path comonent (not empty, no '/')")
}
return nil
}
type contextKey int
const contextKeyLog contextKey = 0
type Logger = logger.Logger
func WithLogger(ctx context.Context, log Logger) context.Context {
return context.WithValue(ctx, contextKeyLog, log)
}
func GetLogger(ctx context.Context) Logger {
if log, ok := ctx.Value(contextKeyLog).(Logger); ok {
return log
}
return logger.NewNullLogger()
}

View File

@ -0,0 +1,71 @@
package bytecounter
import (
"io"
"sync/atomic"
"github.com/zrepl/zrepl/zfs"
)
// StreamCopier wraps a zfs.StreamCopier, reimplemening
// its interface and counting the bytes written to during copying.
type StreamCopier interface {
zfs.StreamCopier
Count() int64
}
// NewStreamCopier wraps sc into a StreamCopier.
// If sc is io.Reader, it is guaranteed that the returned StreamCopier
// implements that interface, too.
func NewStreamCopier(sc zfs.StreamCopier) StreamCopier {
bsc := &streamCopier{sc, 0}
if scr, ok := sc.(io.Reader); ok {
return streamCopierAndReader{bsc, scr}
} else {
return bsc
}
}
type streamCopier struct {
sc zfs.StreamCopier
count int64
}
// proxy writer used by streamCopier
type streamCopierWriter struct {
parent *streamCopier
w io.Writer
}
func (w streamCopierWriter) Write(p []byte) (n int, err error) {
n, err = w.w.Write(p)
atomic.AddInt64(&w.parent.count, int64(n))
return
}
func (s *streamCopier) Count() int64 {
return atomic.LoadInt64(&s.count)
}
var _ zfs.StreamCopier = &streamCopier{}
func (s streamCopier) Close() error {
return s.sc.Close()
}
func (s *streamCopier) WriteStreamTo(w io.Writer) zfs.StreamCopierError {
ww := streamCopierWriter{s, w}
return s.sc.WriteStreamTo(ww)
}
// a streamCopier whose underlying sc is an io.Reader
type streamCopierAndReader struct {
*streamCopier
asReader io.Reader
}
func (scr streamCopierAndReader) Read(p []byte) (int, error) {
n, err := scr.asReader.Read(p)
atomic.AddInt64(&scr.streamCopier.count, int64(n))
return n, err
}

View File

@ -0,0 +1,38 @@
package bytecounter
import (
"io"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zrepl/zrepl/zfs"
)
type mockStreamCopierAndReader struct {
zfs.StreamCopier // to satisfy interface
reads int
}
func (r *mockStreamCopierAndReader) Read(p []byte) (int, error) {
r.reads++
return len(p), nil
}
var _ io.Reader = &mockStreamCopierAndReader{}
func TestNewStreamCopierReexportsReader(t *testing.T) {
mock := &mockStreamCopierAndReader{}
x := NewStreamCopier(mock)
r, ok := x.(io.Reader)
if !ok {
t.Fatalf("%T does not implement io.Reader, hence reader cannout have been wrapped", x)
}
var buf [23]byte
n, err := r.Read(buf[:])
assert.True(t, mock.reads == 1)
assert.True(t, n == len(buf))
assert.NoError(t, err)
assert.True(t, x.Count() == 23)
}

14
util/devnoop/devnoop.go Normal file
View File

@ -0,0 +1,14 @@
// package devnoop provides an io.ReadWriteCloser that never errors
// and always reports reads / writes to / from buffers as complete.
// The buffers themselves are never touched.
package devnoop
type Dev struct{}
func Get() Dev {
return Dev{}
}
func (Dev) Write(p []byte) (n int, err error) { return len(p), nil }
func (Dev) Read(p []byte) (n int, err error) { return len(p), nil }
func (Dev) Close() error { return nil }

View File

@ -2,6 +2,7 @@ package envconst
import ( import (
"os" "os"
"strconv"
"sync" "sync"
"time" "time"
) )
@ -23,3 +24,19 @@ func Duration(varname string, def time.Duration) time.Duration {
cache.Store(varname, d) cache.Store(varname, d)
return d return d
} }
func Int64(varname string, def int64) int64 {
if v, ok := cache.Load(varname); ok {
return v.(int64)
}
e := os.Getenv(varname)
if e == "" {
return def
}
d, err := strconv.ParseInt(e, 10, 64)
if err != nil {
panic(err)
}
cache.Store(varname, d)
return d
}

View File

@ -4,15 +4,18 @@ import (
"bytes" "bytes"
"context" "context"
"fmt" "fmt"
"github.com/zrepl/zrepl/util/envconst"
"io" "io"
"os" "os"
"os/exec" "os/exec"
"syscall" "syscall"
"time"
) )
// An IOCommand exposes a forked process's std(in|out|err) through the io.ReadWriteCloser interface. // An IOCommand exposes a forked process's std(in|out|err) through the io.ReadWriteCloser interface.
type IOCommand struct { type IOCommand struct {
Cmd *exec.Cmd Cmd *exec.Cmd
kill context.CancelFunc
Stdin io.WriteCloser Stdin io.WriteCloser
Stdout io.ReadCloser Stdout io.ReadCloser
StderrBuf *bytes.Buffer StderrBuf *bytes.Buffer
@ -52,6 +55,7 @@ func NewIOCommand(ctx context.Context, command string, args []string, stderrBufS
c = &IOCommand{} c = &IOCommand{}
ctx, c.kill = context.WithCancel(ctx)
c.Cmd = exec.CommandContext(ctx, command, args...) c.Cmd = exec.CommandContext(ctx, command, args...)
if c.Stdout, err = c.Cmd.StdoutPipe(); err != nil { if c.Stdout, err = c.Cmd.StdoutPipe(); err != nil {
@ -81,14 +85,24 @@ func (c *IOCommand) Start() (err error) {
func (c *IOCommand) Read(buf []byte) (n int, err error) { func (c *IOCommand) Read(buf []byte) (n int, err error) {
n, err = c.Stdout.Read(buf) n, err = c.Stdout.Read(buf)
if err == io.EOF { if err == io.EOF {
if waitErr := c.doWait(); waitErr != nil { if waitErr := c.doWait(context.Background()); waitErr != nil {
err = waitErr err = waitErr
} }
} }
return return
} }
func (c *IOCommand) doWait() (err error) { func (c *IOCommand) doWait(ctx context.Context) (err error) {
go func() {
dl, ok := ctx.Deadline()
if !ok {
return
}
time.Sleep(dl.Sub(time.Now()))
c.kill()
c.Stdout.Close()
c.Stdin.Close()
}()
waitErr := c.Cmd.Wait() waitErr := c.Cmd.Wait()
var wasUs bool = false var wasUs bool = false
var waitStatus syscall.WaitStatus var waitStatus syscall.WaitStatus
@ -133,10 +147,9 @@ func (c *IOCommand) Close() (err error) {
if c.Cmd.ProcessState == nil { if c.Cmd.ProcessState == nil {
// racy... // racy...
err = syscall.Kill(c.Cmd.Process.Pid, syscall.SIGTERM) err = syscall.Kill(c.Cmd.Process.Pid, syscall.SIGTERM)
if err != nil { ctx, cancel := context.WithTimeout(context.Background(), envconst.Duration("IOCOMMAND_TIMEOUT", 10*time.Second))
return defer cancel()
} return c.doWait(ctx)
return c.doWait()
} else { } else {
return c.ExitResult.Error return c.ExitResult.Error
} }

View File

@ -1,42 +1,32 @@
package socketpair package socketpair
import ( import (
"golang.org/x/sys/unix"
"net" "net"
"os" "os"
"golang.org/x/sys/unix"
) )
type fileConn struct {
net.Conn // net.FileConn
f *os.File
}
func (c fileConn) Close() error { func SocketPair() (a, b *net.UnixConn, err error) {
if err := c.Conn.Close(); err != nil {
return err
}
if err := c.f.Close(); err != nil {
return err
}
return nil
}
func SocketPair() (a, b net.Conn, err error) {
// don't use net.Pipe, as it doesn't implement things like lingering, which our code relies on // don't use net.Pipe, as it doesn't implement things like lingering, which our code relies on
sockpair, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_STREAM, 0) sockpair, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_STREAM, 0)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
toConn := func(fd int) (net.Conn, error) { toConn := func(fd int) (*net.UnixConn, error) {
f := os.NewFile(uintptr(fd), "fileconn") f := os.NewFile(uintptr(fd), "fileconn")
if f == nil { if f == nil {
panic(fd) panic(fd)
} }
c, err := net.FileConn(f) c, err := net.FileConn(f)
f.Close() // net.FileConn uses dup under the hood
if err != nil { if err != nil {
f.Close()
return nil, err return nil, err
} }
return fileConn{Conn: c, f: f}, nil // strictly, the following type assertion is an implementation detail
// however, will be caught by test TestSocketPairWorks
fileConnIsUnixConn := c.(*net.UnixConn)
return fileConnIsUnixConn, nil
} }
if a, err = toConn(sockpair[0]); err != nil { // shadowing if a, err = toConn(sockpair[0]); err != nil { // shadowing
return nil, nil, err return nil, nil, err

View File

@ -0,0 +1,18 @@
package socketpair
import (
"testing"
"github.com/stretchr/testify/assert"
)
// This is test is mostly to verify that the assumption about
// net.FileConn returning *net.UnixConn for AF_UNIX FDs works.
func TestSocketPairWorks(t *testing.T) {
assert.NotPanics(t, func() {
a, b, err := SocketPair()
assert.NoError(t, err)
a.Close()
b.Close()
})
}

View File

@ -274,7 +274,7 @@ func ZFSCreatePlaceholderFilesystem(p *DatasetPath) (err error) {
} }
if err = cmd.Wait(); err != nil { if err = cmd.Wait(); err != nil {
err = ZFSError{ err = &ZFSError{
Stderr: stderr.Bytes(), Stderr: stderr.Bytes(),
WaitErr: err, WaitErr: err,
} }

Some files were not shown because too many files have changed in this diff Show More