diff --git a/.travis.yml b/.travis.yml index 8aafc51..8ced506 100644 --- a/.travis.yml +++ b/.travis.yml @@ -16,25 +16,6 @@ matrix: zrepl_build make vendordeps release # all go entries vary only by go version - - language: go - go: - - "1.10" - go_import_path: github.com/zrepl/zrepl - before_install: - - wget https://github.com/protocolbuffers/protobuf/releases/download/v3.6.1/protoc-3.6.1-linux-x86_64.zip - - echo "6003de742ea3fcf703cfec1cd4a3380fd143081a2eb0e559065563496af27807 protoc-3.6.1-linux-x86_64.zip" | sha256sum -c - - sudo unzip -d /usr protoc-3.6.1-linux-x86_64.zip - - ./lazy.sh godep - - make vendordeps - script: - - make - - make vet - - make test - - go test ./... - - make artifacts/zrepl-freebsd-amd64 - - make artifacts/zrepl-linux-amd64 - - make artifacts/zrepl-darwin-amd64 - - language: go go: - "1.11" @@ -49,7 +30,24 @@ matrix: - make - make vet - make test - - go test ./... + - make artifacts/zrepl-freebsd-amd64 + - make artifacts/zrepl-linux-amd64 + - make artifacts/zrepl-darwin-amd64 + + - language: go + go: + - "master" + go_import_path: github.com/zrepl/zrepl + before_install: + - wget https://github.com/protocolbuffers/protobuf/releases/download/v3.6.1/protoc-3.6.1-linux-x86_64.zip + - echo "6003de742ea3fcf703cfec1cd4a3380fd143081a2eb0e559065563496af27807 protoc-3.6.1-linux-x86_64.zip" | sha256sum -c + - sudo unzip -d /usr protoc-3.6.1-linux-x86_64.zip + - ./lazy.sh godep + - make vendordeps + script: + - make + - make vet + - make test - make artifacts/zrepl-freebsd-amd64 - make artifacts/zrepl-linux-amd64 - make artifacts/zrepl-darwin-amd64 diff --git a/Gopkg.lock b/Gopkg.lock index 6ad511b..4674a95 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -80,6 +80,10 @@ "protoc-gen-go/generator/internal/remap", "protoc-gen-go/grpc", "protoc-gen-go/plugin", + "ptypes", + "ptypes/any", + "ptypes/duration", + "ptypes/timestamp", ] pruneopts = "" revision = "aa810b61a9c79d51363740d207bb46cf8e620ed5" @@ -173,6 +177,14 @@ revision = "645ef00459ed84a119197bfb8d8205042c6df63d" version = "v0.8.0" +[[projects]] + digest = "1:1cbc6b98173422a756ae79e485952cb37a0a460c710541c75d3e9961c5a60782" + name = "github.com/pkg/profile" + packages = ["."] + pruneopts = "" + revision = "5b67d428864e92711fcbd2f8629456121a56d91f" + version = "v1.2.1" + [[projects]] digest = "1:256484dbbcd271f9ecebc6795b2df8cad4c458dd0f5fd82a8c2fa0c29f233411" name = "github.com/pmezard/go-difflib" @@ -183,30 +195,11 @@ [[projects]] branch = "master" - digest = "1:25559b520313b941b1395cd5d5ee66086b27dc15a1391c0f2aad29d5c2321f4b" + digest = "1:fa72f780ae3b4820ed12cef7a034291ab10d83e2da4ab5ba81afa44d5bf3a529" name = "github.com/problame/go-netssh" packages = ["."] pruneopts = "" - revision = "c56ad38d2c91397ad3c8dd9443d7448e328a9e9e" - -[[projects]] - branch = "master" - digest = "1:8c63c44f018bd52b03ebad65c9df26aabbc6793138e421df1c8c84c285a45bc6" - name = "github.com/problame/go-rwccmd" - packages = ["."] - pruneopts = "" - revision = "391d2c78c8404a9683d79f75dd24ab53040f89f7" - -[[projects]] - digest = "1:1bcbb0a7ad8d3392d446eb583ae5415ff987838a8f7331a36877789be20667e6" - name = "github.com/problame/go-streamrpc" - packages = [ - ".", - "internal/pdu", - ] - pruneopts = "" - revision = "d5d111e014342fe1c37f0b71cc37ec5f2afdfd13" - version = "v0.5" + revision = "09d6bc45d284784cb3e5aaa1998513f37eb19cc6" [[projects]] branch = "master" @@ -279,31 +272,66 @@ revision = "69483b4bd14f5845b5a1e55bca19e954e827f1d0" version = "v1.1.4" +[[projects]] + digest = "1:f80053a92d9ac874ad97d665d3005c1118ed66e9e88401361dc32defe6bef21c" + name = "github.com/theckman/goconstraint" + packages = ["go1.11/gte"] + pruneopts = "" + revision = "93babf24513d0e8277635da8169fcc5a46ae3f6a" + version = "v1.11.0" + [[projects]] branch = "v2" - digest = "1:9d92186f609a73744232323416ddafd56fae67cb552162cc190ab903e36900dd" + digest = "1:6b8a6afafde7ed31cd0c577ba40d88ce39e8f1c5eb76d7836be7d5b74f1c534a" name = "github.com/zrepl/yaml-config" packages = ["."] pruneopts = "" - revision = "af27d27978ad95808723a62d87557d63c3ff0605" + revision = "08227ad854131f7dfcdfb12579fb73dd8a38a03a" [[projects]] branch = "master" - digest = "1:9c286cf11d0ca56368185bada5dd6d97b6be4648fc26c354fcba8df7293718f7" + digest = "1:ea539c13b066dac72a940b62f37600a20ab8e88057397c78f3197c1a48475425" + name = "golang.org/x/net" + packages = [ + "context", + "http/httpguts", + "http2", + "http2/hpack", + "idna", + "internal/timeseries", + "trace", + ] + pruneopts = "" + revision = "351d144fa1fc0bd934e2408202be0c29f25e35a0" + +[[projects]] + branch = "master" + digest = "1:f358024b019f87eecaadcb098113a40852c94fe58ea670ef3c3e2d2c7bd93db1" name = "golang.org/x/sys" packages = ["unix"] pruneopts = "" - revision = "bf42f188b9bc6f2cf5b8ee5a912ef1aedd0eba4c" + revision = "4ed8d59d0b35e1e29334a206d1b3f38b1e5dfb31" [[projects]] digest = "1:5acd3512b047305d49e8763eef7ba423901e85d5dd2fd1e71778a0ea8de10bd4" name = "golang.org/x/text" packages = [ + "collate", + "collate/build", "encoding", "encoding/internal/identifier", + "internal/colltab", "internal/gen", + "internal/tag", + "internal/triegen", + "internal/ucd", + "language", + "secure/bidirule", "transform", + "unicode/bidi", "unicode/cldr", + "unicode/norm", + "unicode/rangetable", ] pruneopts = "" revision = "f21a4dfb5e38f5895301dc265a8def02365cc3d0" @@ -317,6 +345,54 @@ pruneopts = "" revision = "d0ca3933b724e6be513276cc2edb34e10d667438" +[[projects]] + branch = "master" + digest = "1:5fc6c317675b746d0c641b29aa0aab5fcb403c0d07afdbf0de86b0d447a0502a" + name = "google.golang.org/genproto" + packages = ["googleapis/rpc/status"] + pruneopts = "" + revision = "bd91e49a0898e27abb88c339b432fa53d7497ac0" + +[[projects]] + digest = "1:d141efe4aaad714e3059c340901aab3147b6253e58c85dafbcca3dd8b0e88ad6" + name = "google.golang.org/grpc" + packages = [ + ".", + "balancer", + "balancer/base", + "balancer/roundrobin", + "binarylog/grpc_binarylog_v1", + "codes", + "connectivity", + "credentials", + "credentials/internal", + "encoding", + "encoding/proto", + "grpclog", + "internal", + "internal/backoff", + "internal/binarylog", + "internal/channelz", + "internal/envconfig", + "internal/grpcrand", + "internal/grpcsync", + "internal/syscall", + "internal/transport", + "keepalive", + "metadata", + "naming", + "peer", + "resolver", + "resolver/dns", + "resolver/passthrough", + "stats", + "status", + "tap", + ] + pruneopts = "" + revision = "df014850f6dee74ba2fc94874043a9f3f75fbfd8" + version = "v1.17.0" + [solve-meta] analyzer-name = "dep" analyzer-version = 1 @@ -331,9 +407,8 @@ "github.com/kr/pretty", "github.com/mattn/go-isatty", "github.com/pkg/errors", + "github.com/pkg/profile", "github.com/problame/go-netssh", - "github.com/problame/go-rwccmd", - "github.com/problame/go-streamrpc", "github.com/prometheus/client_golang/prometheus", "github.com/prometheus/client_golang/prometheus/promhttp", "github.com/spf13/cobra", @@ -341,8 +416,13 @@ "github.com/stretchr/testify/assert", "github.com/stretchr/testify/require", "github.com/zrepl/yaml-config", + "golang.org/x/net/context", "golang.org/x/sys/unix", "golang.org/x/tools/cmd/stringer", + "google.golang.org/grpc", + "google.golang.org/grpc/credentials", + "google.golang.org/grpc/keepalive", + "google.golang.org/grpc/peer", ] solver-name = "gps-cdcl" solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml index 59ddabb..01e2aae 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -1,8 +1,11 @@ -ignored = [ "github.com/inconshreveable/mousetrap" ] +ignored = [ + "github.com/inconshreveable/mousetrap", +] -[[constraint]] - branch = "master" - name = "github.com/ftrvxmtrx/fd" +required = [ + "golang.org/x/tools/cmd/stringer", + "github.com/alvaroloes/enumer", +] [[constraint]] branch = "master" @@ -12,14 +15,6 @@ ignored = [ "github.com/inconshreveable/mousetrap" ] branch = "master" name = "github.com/kr/pretty" -[[constraint]] - branch = "master" - name = "github.com/mitchellh/go-homedir" - -[[constraint]] - branch = "master" - name = "github.com/mitchellh/mapstructure" - [[constraint]] name = "github.com/pkg/errors" version = "0.8.0" @@ -28,10 +23,6 @@ ignored = [ "github.com/inconshreveable/mousetrap" ] branch = "master" name = "github.com/spf13/cobra" -[[constraint]] - name = "github.com/spf13/viper" - version = "1.0.0" - [[constraint]] name = "github.com/stretchr/testify" version = "1.1.4" @@ -44,10 +35,6 @@ ignored = [ "github.com/inconshreveable/mousetrap" ] name = "github.com/go-logfmt/logfmt" version = "*" -[[constraint]] - name = "github.com/problame/go-rwccmd" - branch = "master" - [[constraint]] name = "github.com/problame/go-netssh" branch = "master" @@ -58,26 +45,17 @@ ignored = [ "github.com/inconshreveable/mousetrap" ] [[constraint]] name = "github.com/golang/protobuf" - version = "1.2.0" + version = "1" [[constraint]] name = "github.com/fatih/color" version = "1.7.0" -[[constraint]] - name = "github.com/problame/go-streamrpc" - version = "0.5.0" - - [[constraint]] name = "github.com/gdamore/tcell" version = "1.0.0" [[constraint]] - branch = "master" - name = "golang.org/x/tools" - -[[constraint]] - branch = "master" - name = "github.com/alvaroloes/enumer" + name = "google.golang.org/grpc" + version = "1" diff --git a/Makefile b/Makefile index a0f31a6..120eb28 100644 --- a/Makefile +++ b/Makefile @@ -1,45 +1,13 @@ .PHONY: generate build test vet cover release docs docs-clean clean vendordeps .DEFAULT_GOAL := build -ROOT := github.com/zrepl/zrepl -SUBPKGS += client -SUBPKGS += config -SUBPKGS += daemon -SUBPKGS += daemon/filters -SUBPKGS += daemon/job -SUBPKGS += daemon/logging -SUBPKGS += daemon/nethelpers -SUBPKGS += daemon/pruner -SUBPKGS += daemon/snapper -SUBPKGS += daemon/streamrpcconfig -SUBPKGS += daemon/transport -SUBPKGS += daemon/transport/connecter -SUBPKGS += daemon/transport/serve -SUBPKGS += endpoint -SUBPKGS += logger -SUBPKGS += pruning -SUBPKGS += pruning/retentiongrid -SUBPKGS += replication -SUBPKGS += replication/fsrep -SUBPKGS += replication/pdu -SUBPKGS += replication/internal/diff -SUBPKGS += tlsconf -SUBPKGS += util -SUBPKGS += util/socketpair -SUBPKGS += util/watchdog -SUBPKGS += util/envconst -SUBPKGS += version -SUBPKGS += zfs - -_TESTPKGS := $(ROOT) $(foreach p,$(SUBPKGS),$(ROOT)/$(p)) - ARTIFACTDIR := artifacts ifdef ZREPL_VERSION _ZREPL_VERSION := $(ZREPL_VERSION) endif ifndef _ZREPL_VERSION - _ZREPL_VERSION := $(shell git describe --dirty 2>/dev/null || echo "ZREPL_BUILD_INVALID_VERSION" ) + _ZREPL_VERSION := $(shell git describe --always --dirty 2>/dev/null || echo "ZREPL_BUILD_INVALID_VERSION" ) ifeq ($(_ZREPL_VERSION),ZREPL_BUILD_INVALID_VERSION) # can't use .SHELLSTATUS because Debian Stretch is still on gmake 4.1 $(error cannot infer variable ZREPL_VERSION using git and variable is not overriden by make invocation) endif @@ -59,35 +27,21 @@ vendordeps: dep ensure -v -vendor-only generate: #not part of the build, must do that manually - protoc -I=replication/pdu --go_out=replication/pdu replication/pdu/pdu.proto - @for pkg in $(_TESTPKGS); do\ - go generate "$$pkg" || exit 1; \ - done; + protoc -I=replication/pdu --go_out=plugins=grpc:replication/pdu replication/pdu/pdu.proto + go generate -x ./... build: @echo "INFO: In case of missing dependencies, run 'make vendordeps'" $(GO_BUILD) -o "$(ARTIFACTDIR)/zrepl" test: - @for pkg in $(_TESTPKGS); do \ - echo "Testing $$pkg"; \ - go test "$$pkg" || exit 1; \ - done; + go test ./... vet: - @for pkg in $(_TESTPKGS); do \ - echo "Vetting $$pkg"; \ - go vet "$$pkg" || exit 1; \ - done; - -cover: artifacts - @for pkg in $(_TESTPKGS); do \ - profile="$(ARTIFACTDIR)/cover-$$(basename $$pkg).out"; \ - go test -coverprofile "$$profile" $$pkg || exit 1; \ - if [ -f "$$profile" ]; then \ - go tool cover -html="$$profile" -o "$${profile}.html" || exit 2; \ - fi; \ - done; + # for each supported platform to cover conditional compilation + GOOS=linux go vet ./... + GOOS=darwin go vet ./... + GOOS=freebsd go vet ./... $(ARTIFACTDIR): mkdir -p "$@" @@ -132,7 +86,7 @@ release: $(RELEASE_BINS) $(RELEASE_NOARCH) cp $^ "$(ARTIFACTDIR)/release" cd "$(ARTIFACTDIR)/release" && sha512sum $$(ls | sort) > sha512sum.txt @# note that we use ZREPL_VERSION and not _ZREPL_VERSION because we want to detect the override - @if git describe --dirty 2>/dev/null | grep dirty >/dev/null; then \ + @if git describe --always --dirty 2>/dev/null | grep dirty >/dev/null; then \ echo '[INFO] either git reports checkout is dirty or git is not installed or this is not a git checkout'; \ if [ "$(ZREPL_VERSION)" = "" ]; then \ echo '[WARN] git checkout is dirty and make variable ZREPL_VERSION was not used to override'; \ diff --git a/config/config.go b/config/config.go index eeea023..3c2131a 100644 --- a/config/config.go +++ b/config/config.go @@ -130,7 +130,6 @@ type Global struct { Monitoring []MonitoringEnum `yaml:"monitoring,optional"` Control *GlobalControl `yaml:"control,optional,fromdefaults"` Serve *GlobalServe `yaml:"serve,optional,fromdefaults"` - RPC *RPCConfig `yaml:"rpc,optional,fromdefaults"` } func Default(i interface{}) { @@ -145,29 +144,18 @@ func Default(i interface{}) { } } -type RPCConfig struct { - Timeout time.Duration `yaml:"timeout,optional,positive,default=10s"` - TxChunkSize uint32 `yaml:"tx_chunk_size,optional,default=32768"` - RxStructuredMaxLen uint32 `yaml:"rx_structured_max,optional,default=16777216"` - RxStreamChunkMaxLen uint32 `yaml:"rx_stream_chunk_max,optional,default=16777216"` - RxHeaderMaxLen uint32 `yaml:"rx_header_max,optional,default=40960"` - SendHeartbeatInterval time.Duration `yaml:"send_heartbeat_interval,optional,positive,default=5s"` - -} - type ConnectEnum struct { Ret interface{} } type ConnectCommon struct { - Type string `yaml:"type"` - RPC *RPCConfig `yaml:"rpc,optional"` + Type string `yaml:"type"` } type TCPConnect struct { ConnectCommon `yaml:",inline"` Address string `yaml:"address"` - DialTimeout time.Duration `yaml:"dial_timeout,positive,default=10s"` + DialTimeout time.Duration `yaml:"dial_timeout,zeropositive,default=10s"` } type TLSConnect struct { @@ -177,7 +165,7 @@ type TLSConnect struct { Cert string `yaml:"cert"` Key string `yaml:"key"` ServerCN string `yaml:"server_cn"` - DialTimeout time.Duration `yaml:"dial_timeout,positive,default=10s"` + DialTimeout time.Duration `yaml:"dial_timeout,zeropositive,default=10s"` } type SSHStdinserverConnect struct { @@ -189,7 +177,7 @@ type SSHStdinserverConnect struct { TransportOpenCommand []string `yaml:"transport_open_command,optional"` //TODO unused SSHCommand string `yaml:"ssh_command,optional"` //TODO unused Options []string `yaml:"options,optional"` - DialTimeout time.Duration `yaml:"dial_timeout,positive,default=10s"` + DialTimeout time.Duration `yaml:"dial_timeout,zeropositive,default=10s"` } type LocalConnect struct { @@ -203,8 +191,7 @@ type ServeEnum struct { } type ServeCommon struct { - Type string `yaml:"type"` - RPC *RPCConfig `yaml:"rpc,optional"` + Type string `yaml:"type"` } type TCPServe struct { @@ -220,7 +207,7 @@ type TLSServe struct { Cert string `yaml:"cert"` Key string `yaml:"key"` ClientCNs []string `yaml:"client_cns"` - HandshakeTimeout time.Duration `yaml:"handshake_timeout,positive,default=10s"` + HandshakeTimeout time.Duration `yaml:"handshake_timeout,zeropositive,default=10s"` } type StdinserverServer struct { diff --git a/config/config_rpc_test.go b/config/config_rpc_test.go deleted file mode 100644 index f02311e..0000000 --- a/config/config_rpc_test.go +++ /dev/null @@ -1,86 +0,0 @@ -package config - -import ( - "github.com/stretchr/testify/assert" - "testing" - "time" -) - -func TestRPC(t *testing.T) { - conf := testValidConfig(t, ` -jobs: -- name: pull_servers - type: pull - connect: - type: tcp - address: "server1.foo.bar:8888" - rpc: - timeout: 20s # different form default, should merge - root_fs: "pool2/backup_servers" - interval: 10m - pruning: - keep_sender: - - type: not_replicated - keep_receiver: - - type: last_n - count: 100 - -- name: pull_servers2 - type: pull - connect: - type: tcp - address: "server1.foo.bar:8888" - rpc: - tx_chunk_size: 0xabcd # different from default, should merge - root_fs: "pool2/backup_servers" - interval: 10m - pruning: - keep_sender: - - type: not_replicated - keep_receiver: - - type: last_n - count: 100 - -- type: sink - name: "laptop_sink" - root_fs: "pool2/backup_laptops" - serve: - type: tcp - listen: "192.168.122.189:8888" - clients: { - "10.23.42.23":"client1" - } - rpc: - rx_structured_max: 0x2342 - -- type: sink - name: "other_sink" - root_fs: "pool2/backup_laptops" - serve: - type: tcp - listen: "192.168.122.189:8888" - clients: { - "10.23.42.23":"client1" - } - rpc: - send_heartbeat_interval: 10s - -`) - - assert.Equal(t, 20*time.Second, conf.Jobs[0].Ret.(*PullJob).Connect.Ret.(*TCPConnect).RPC.Timeout) - assert.Equal(t, uint32(0xabcd), conf.Jobs[1].Ret.(*PullJob).Connect.Ret.(*TCPConnect).RPC.TxChunkSize) - assert.Equal(t, uint32(0x2342), conf.Jobs[2].Ret.(*SinkJob).Serve.Ret.(*TCPServe).RPC.RxStructuredMaxLen) - assert.Equal(t, 10*time.Second, conf.Jobs[3].Ret.(*SinkJob).Serve.Ret.(*TCPServe).RPC.SendHeartbeatInterval) - defConf := RPCConfig{} - Default(&defConf) - assert.Equal(t, defConf.Timeout, conf.Global.RPC.Timeout) -} - -func TestGlobal_DefaultRPCConfig(t *testing.T) { - assert.NotPanics(t, func() { - var c RPCConfig - Default(&c) - assert.NotNil(t, c) - assert.Equal(t, c.TxChunkSize, uint32(1)<<15) - }) -} diff --git a/config/config_test.go b/config/config_test.go index d8e1e2c..d51f3f7 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -1,11 +1,14 @@ package config import ( - "github.com/kr/pretty" - "github.com/stretchr/testify/require" + "bytes" "path" "path/filepath" "testing" + "text/template" + + "github.com/kr/pretty" + "github.com/stretchr/testify/require" ) func TestSampleConfigsAreParsedWithoutErrors(t *testing.T) { @@ -35,8 +38,21 @@ func TestSampleConfigsAreParsedWithoutErrors(t *testing.T) { } +// template must be a template/text template with a single '{{ . }}' as placehodler for val +func testValidConfigTemplate(t *testing.T, tmpl string, val string) *Config { + tmp, err := template.New("master").Parse(tmpl) + if err != nil { + panic(err) + } + var buf bytes.Buffer + err = tmp.Execute(&buf, val) + if err != nil { + panic(err) + } + return testValidConfig(t, buf.String()) +} -func testValidConfig(t *testing.T, input string) (*Config) { +func testValidConfig(t *testing.T, input string) *Config { t.Helper() conf, err := testConfig(t, input) require.NoError(t, err) @@ -47,4 +63,4 @@ func testValidConfig(t *testing.T, input string) (*Config) { func testConfig(t *testing.T, input string) (*Config, error) { t.Helper() return ParseConfigBytes([]byte(input)) -} \ No newline at end of file +} diff --git a/daemon/job/active.go b/daemon/job/active.go index 766736b..907f4be 100644 --- a/daemon/job/active.go +++ b/daemon/job/active.go @@ -2,9 +2,12 @@ package job import ( "context" + "sync" + "time" + "github.com/pkg/errors" - "github.com/problame/go-streamrpc" "github.com/prometheus/client_golang/prometheus" + "github.com/zrepl/zrepl/config" "github.com/zrepl/zrepl/daemon/filters" "github.com/zrepl/zrepl/daemon/job/reset" @@ -12,32 +15,30 @@ import ( "github.com/zrepl/zrepl/daemon/logging" "github.com/zrepl/zrepl/daemon/pruner" "github.com/zrepl/zrepl/daemon/snapper" - "github.com/zrepl/zrepl/daemon/transport/connecter" "github.com/zrepl/zrepl/endpoint" "github.com/zrepl/zrepl/replication" + "github.com/zrepl/zrepl/rpc" + "github.com/zrepl/zrepl/transport" + "github.com/zrepl/zrepl/transport/fromconfig" "github.com/zrepl/zrepl/util/envconst" "github.com/zrepl/zrepl/zfs" - "sync" - "time" ) type ActiveSide struct { - mode activeMode - name string - clientFactory *connecter.ClientFactory + mode activeMode + name string + connecter transport.Connecter prunerFactory *pruner.PrunerFactory - - promRepStateSecs *prometheus.HistogramVec // labels: state - promPruneSecs *prometheus.HistogramVec // labels: prune_side - promBytesReplicated *prometheus.CounterVec // labels: filesystem + promRepStateSecs *prometheus.HistogramVec // labels: state + promPruneSecs *prometheus.HistogramVec // labels: prune_side + promBytesReplicated *prometheus.CounterVec // labels: filesystem tasksMtx sync.Mutex tasks activeSideTasks } - //go:generate enumer -type=ActiveSideState type ActiveSideState int @@ -48,12 +49,11 @@ const ( ActiveSideDone // also errors ) - type activeSideTasks struct { state ActiveSideState // valid for state ActiveSideReplicating, ActiveSidePruneSender, ActiveSidePruneReceiver, ActiveSideDone - replication *replication.Replication + replication *replication.Replication replicationCancel context.CancelFunc // valid for state ActiveSidePruneSender, ActiveSidePruneReceiver, ActiveSideDone @@ -77,28 +77,59 @@ func (a *ActiveSide) updateTasks(u func(*activeSideTasks)) activeSideTasks { } type activeMode interface { - SenderReceiver(client *streamrpc.Client) (replication.Sender, replication.Receiver, error) + ConnectEndpoints(rpcLoggers rpc.Loggers, connecter transport.Connecter) + DisconnectEndpoints() + SenderReceiver() (replication.Sender, replication.Receiver) Type() Type RunPeriodic(ctx context.Context, wakeUpCommon chan<- struct{}) + ResetConnectBackoff() } type modePush struct { - fsfilter endpoint.FSFilter - snapper *snapper.PeriodicOrManual + setupMtx sync.Mutex + sender *endpoint.Sender + receiver *rpc.Client + fsfilter endpoint.FSFilter + snapper *snapper.PeriodicOrManual } -func (m *modePush) SenderReceiver(client *streamrpc.Client) (replication.Sender, replication.Receiver, error) { - sender := endpoint.NewSender(m.fsfilter) - receiver := endpoint.NewRemote(client) - return sender, receiver, nil +func (m *modePush) ConnectEndpoints(loggers rpc.Loggers, connecter transport.Connecter) { + m.setupMtx.Lock() + defer m.setupMtx.Unlock() + if m.receiver != nil || m.sender != nil { + panic("inconsistent use of ConnectEndpoints and DisconnectEndpoints") + } + m.sender = endpoint.NewSender(m.fsfilter) + m.receiver = rpc.NewClient(connecter, loggers) +} + +func (m *modePush) DisconnectEndpoints() { + m.setupMtx.Lock() + defer m.setupMtx.Unlock() + m.receiver.Close() + m.sender = nil + m.receiver = nil +} + +func (m *modePush) SenderReceiver() (replication.Sender, replication.Receiver) { + m.setupMtx.Lock() + defer m.setupMtx.Unlock() + return m.sender, m.receiver } func (m *modePush) Type() Type { return TypePush } -func (m *modePush) RunPeriodic(ctx context.Context, wakeUpCommon chan <- struct{}) { +func (m *modePush) RunPeriodic(ctx context.Context, wakeUpCommon chan<- struct{}) { m.snapper.Run(ctx, wakeUpCommon) } +func (m *modePush) ResetConnectBackoff() { + m.setupMtx.Lock() + defer m.setupMtx.Unlock() + if m.receiver != nil { + m.receiver.ResetConnectBackoff() + } +} func modePushFromConfig(g *config.Global, in *config.PushJob) (*modePush, error) { m := &modePush{} @@ -116,14 +147,35 @@ func modePushFromConfig(g *config.Global, in *config.PushJob) (*modePush, error) } type modePull struct { + setupMtx sync.Mutex + receiver *endpoint.Receiver + sender *rpc.Client rootFS *zfs.DatasetPath interval time.Duration } -func (m *modePull) SenderReceiver(client *streamrpc.Client) (replication.Sender, replication.Receiver, error) { - sender := endpoint.NewRemote(client) - receiver, err := endpoint.NewReceiver(m.rootFS) - return sender, receiver, err +func (m *modePull) ConnectEndpoints(loggers rpc.Loggers, connecter transport.Connecter) { + m.setupMtx.Lock() + defer m.setupMtx.Unlock() + if m.receiver != nil || m.sender != nil { + panic("inconsistent use of ConnectEndpoints and DisconnectEndpoints") + } + m.receiver = endpoint.NewReceiver(m.rootFS, false) + m.sender = rpc.NewClient(connecter, loggers) +} + +func (m *modePull) DisconnectEndpoints() { + m.setupMtx.Lock() + defer m.setupMtx.Unlock() + m.sender.Close() + m.sender = nil + m.receiver = nil +} + +func (m *modePull) SenderReceiver() (replication.Sender, replication.Receiver) { + m.setupMtx.Lock() + defer m.setupMtx.Unlock() + return m.sender, m.receiver } func (*modePull) Type() Type { return TypePull } @@ -148,6 +200,14 @@ func (m *modePull) RunPeriodic(ctx context.Context, wakeUpCommon chan<- struct{} } } +func (m *modePull) ResetConnectBackoff() { + m.setupMtx.Lock() + defer m.setupMtx.Unlock() + if m.sender != nil { + m.sender.ResetConnectBackoff() + } +} + func modePullFromConfig(g *config.Global, in *config.PullJob) (m *modePull, err error) { m = &modePull{} if in.Interval <= 0 { @@ -175,17 +235,17 @@ func activeSide(g *config.Global, in *config.ActiveJob, mode activeMode) (j *Act Subsystem: "replication", Name: "state_time", Help: "seconds spent during replication", - ConstLabels: prometheus.Labels{"zrepl_job":j.name}, + ConstLabels: prometheus.Labels{"zrepl_job": j.name}, }, []string{"state"}) j.promBytesReplicated = prometheus.NewCounterVec(prometheus.CounterOpts{ Namespace: "zrepl", Subsystem: "replication", Name: "bytes_replicated", Help: "number of bytes replicated from sender to receiver per filesystem", - ConstLabels: prometheus.Labels{"zrepl_job":j.name}, + ConstLabels: prometheus.Labels{"zrepl_job": j.name}, }, []string{"filesystem"}) - j.clientFactory, err = connecter.FromConfig(g, in.Connect) + j.connecter, err = fromconfig.ConnecterFromConfig(g, in.Connect) if err != nil { return nil, errors.Wrap(err, "cannot build client") } @@ -195,7 +255,7 @@ func activeSide(g *config.Global, in *config.ActiveJob, mode activeMode) (j *Act Subsystem: "pruning", Name: "time", Help: "seconds spent in pruner", - ConstLabels: prometheus.Labels{"zrepl_job":j.name}, + ConstLabels: prometheus.Labels{"zrepl_job": j.name}, }, []string{"prune_side"}) j.prunerFactory, err = pruner.NewPrunerFactory(in.Pruning, j.promPruneSecs) if err != nil { @@ -214,7 +274,7 @@ func (j *ActiveSide) RegisterMetrics(registerer prometheus.Registerer) { func (j *ActiveSide) Name() string { return j.name } type ActiveSideStatus struct { - Replication *replication.Report + Replication *replication.Report PruningSender, PruningReceiver *pruner.Report } @@ -256,6 +316,7 @@ outer: break outer case <-wakeup.Wait(ctx): + j.mode.ResetConnectBackoff() case <-periodicDone: } invocationCount++ @@ -268,6 +329,9 @@ func (j *ActiveSide) do(ctx context.Context) { log := GetLogger(ctx) ctx = logging.WithSubsystemLoggers(ctx, log) + loggers := rpc.GetLoggersOrPanic(ctx) // filled by WithSubsystemLoggers + j.mode.ConnectEndpoints(loggers, j.connecter) + defer j.mode.DisconnectEndpoints() // allow cancellation of an invocation (this function) ctx, cancelThisRun := context.WithCancel(ctx) @@ -353,13 +417,7 @@ func (j *ActiveSide) do(ctx context.Context) { } }() - client, err := j.clientFactory.NewClient() - if err != nil { - log.WithError(err).Error("factory cannot instantiate streamrpc client") - } - defer client.Close(ctx) - - sender, receiver, err := j.mode.SenderReceiver(client) + sender, receiver := j.mode.SenderReceiver() { select { diff --git a/daemon/job/passive.go b/daemon/job/passive.go index 99071a8..aae3f5a 100644 --- a/daemon/job/passive.go +++ b/daemon/job/passive.go @@ -2,28 +2,30 @@ package job import ( "context" + "fmt" + "github.com/pkg/errors" - "github.com/problame/go-streamrpc" "github.com/prometheus/client_golang/prometheus" + "github.com/zrepl/zrepl/config" "github.com/zrepl/zrepl/daemon/filters" "github.com/zrepl/zrepl/daemon/logging" - "github.com/zrepl/zrepl/daemon/transport/serve" "github.com/zrepl/zrepl/daemon/snapper" "github.com/zrepl/zrepl/endpoint" + "github.com/zrepl/zrepl/rpc" + "github.com/zrepl/zrepl/transport" + "github.com/zrepl/zrepl/transport/fromconfig" "github.com/zrepl/zrepl/zfs" - "path" ) type PassiveSide struct { - mode passiveMode - name string - l serve.ListenerFactory - rpcConf *streamrpc.ConnConfig + mode passiveMode + name string + listen transport.AuthenticatedListenerFactory } type passiveMode interface { - ConnHandleFunc(ctx context.Context, conn serve.AuthenticatedConn) streamrpc.HandlerFunc + Handler() rpc.Handler RunPeriodic(ctx context.Context) Type() Type } @@ -34,26 +36,8 @@ type modeSink struct { func (m *modeSink) Type() Type { return TypeSink } -func (m *modeSink) ConnHandleFunc(ctx context.Context, conn serve.AuthenticatedConn) streamrpc.HandlerFunc { - log := GetLogger(ctx) - - clientRootStr := path.Join(m.rootDataset.ToString(), conn.ClientIdentity()) - clientRoot, err := zfs.NewDatasetPath(clientRootStr) - if err != nil { - log.WithError(err). - WithField("client_identity", conn.ClientIdentity()). - Error("cannot build client filesystem map (client identity must be a valid ZFS FS name") - } - log.WithField("client_root", clientRoot).Debug("client root") - - local, err := endpoint.NewReceiver(clientRoot) - if err != nil { - log.WithError(err).Error("unexpected error: cannot convert mapping to filter") - return nil - } - - h := endpoint.NewHandler(local) - return h.Handle +func (m *modeSink) Handler() rpc.Handler { + return endpoint.NewReceiver(m.rootDataset, true) } func (m *modeSink) RunPeriodic(_ context.Context) {} @@ -72,7 +56,7 @@ func modeSinkFromConfig(g *config.Global, in *config.SinkJob) (m *modeSink, err type modeSource struct { fsfilter zfs.DatasetFilter - snapper *snapper.PeriodicOrManual + snapper *snapper.PeriodicOrManual } func modeSourceFromConfig(g *config.Global, in *config.SourceJob) (m *modeSource, err error) { @@ -93,10 +77,8 @@ func modeSourceFromConfig(g *config.Global, in *config.SourceJob) (m *modeSource func (m *modeSource) Type() Type { return TypeSource } -func (m *modeSource) ConnHandleFunc(ctx context.Context, conn serve.AuthenticatedConn) streamrpc.HandlerFunc { - sender := endpoint.NewSender(m.fsfilter) - h := endpoint.NewHandler(sender) - return h.Handle +func (m *modeSource) Handler() rpc.Handler { + return endpoint.NewSender(m.fsfilter) } func (m *modeSource) RunPeriodic(ctx context.Context) { @@ -106,8 +88,8 @@ func (m *modeSource) RunPeriodic(ctx context.Context) { func passiveSideFromConfig(g *config.Global, in *config.PassiveJob, mode passiveMode) (s *PassiveSide, err error) { s = &PassiveSide{mode: mode, name: in.Name} - if s.l, s.rpcConf, err = serve.FromConfig(g, in.Serve); err != nil { - return nil, errors.Wrap(err, "cannot build server") + if s.listen, err = fromconfig.ListenerFactoryFromConfig(g, in.Serve); err != nil { + return nil, errors.Wrap(err, "cannot build listener factory") } return s, nil @@ -115,7 +97,7 @@ func passiveSideFromConfig(g *config.Global, in *config.PassiveJob, mode passive func (j *PassiveSide) Name() string { return j.name } -type PassiveStatus struct {} +type PassiveStatus struct{} func (s *PassiveSide) Status() *Status { return &Status{Type: s.mode.Type()} // FIXME PassiveStatus @@ -127,70 +109,30 @@ func (j *PassiveSide) Run(ctx context.Context) { log := GetLogger(ctx) defer log.Info("job exiting") - - l, err := j.l.Listen() - if err != nil { - log.WithError(err).Error("cannot listen") - return - } - defer l.Close() - + ctx = logging.WithSubsystemLoggers(ctx, log) { - ctx, cancel := context.WithCancel(logging.WithSubsystemLoggers(ctx, log)) // shadowing + ctx, cancel := context.WithCancel(ctx) // shadowing defer cancel() go j.mode.RunPeriodic(ctx) } - log.WithField("addr", l.Addr()).Debug("accepting connections") - var connId int -outer: - for { - - select { - case res := <-accept(ctx, l): - if res.err != nil { - log.WithError(res.err).Info("accept error") - continue - } - conn := res.conn - connId++ - connLog := log. - WithField("connID", connId) - connLog. - WithField("addr", conn.RemoteAddr()). - WithField("client_identity", conn.ClientIdentity()). - Info("handling connection") - go func() { - defer connLog.Info("finished handling connection") - defer conn.Close() - ctx := logging.WithSubsystemLoggers(ctx, connLog) - handleFunc := j.mode.ConnHandleFunc(ctx, conn) - if handleFunc == nil { - return - } - if err := streamrpc.ServeConn(ctx, conn, j.rpcConf, handleFunc); err != nil { - log.WithError(err).Error("error serving client") - } - }() - - case <-ctx.Done(): - break outer - } - + handler := j.mode.Handler() + if handler == nil { + panic(fmt.Sprintf("implementation error: j.mode.Handler() returned nil: %#v", j)) } -} + ctxInterceptor := func(handlerCtx context.Context) context.Context { + return logging.WithSubsystemLoggers(handlerCtx, log) + } -type acceptResult struct { - conn serve.AuthenticatedConn - err error -} + rpcLoggers := rpc.GetLoggersOrPanic(ctx) // WithSubsystemLoggers above + server := rpc.NewServer(handler, rpcLoggers, ctxInterceptor) -func accept(ctx context.Context, listener serve.AuthenticatedListener) <-chan acceptResult { - c := make(chan acceptResult, 1) - go func() { - conn, err := listener.Accept(ctx) - c <- acceptResult{conn, err} - }() - return c + listener, err := j.listen() + if err != nil { + log.WithError(err).Error("cannot listen") + return + } + + server.Serve(ctx, listener) } diff --git a/daemon/logging/adaptors.go b/daemon/logging/adaptors.go deleted file mode 100644 index c5a7196..0000000 --- a/daemon/logging/adaptors.go +++ /dev/null @@ -1,32 +0,0 @@ -package logging - -import ( - "fmt" - "github.com/problame/go-streamrpc" - "github.com/zrepl/zrepl/logger" - "strings" -) - -type streamrpcLogAdaptor = twoClassLogAdaptor - -type twoClassLogAdaptor struct { - logger.Logger -} - -var _ streamrpc.Logger = twoClassLogAdaptor{} - -func (a twoClassLogAdaptor) Errorf(fmtStr string, args ...interface{}) { - const errorSuffix = ": %s" - if len(args) == 1 { - if err, ok := args[0].(error); ok && strings.HasSuffix(fmtStr, errorSuffix) { - msg := strings.TrimSuffix(fmtStr, errorSuffix) - a.WithError(err).Error(msg) - return - } - } - a.Logger.Error(fmt.Sprintf(fmtStr, args...)) -} - -func (a twoClassLogAdaptor) Infof(fmtStr string, args ...interface{}) { - a.Logger.Debug(fmt.Sprintf(fmtStr, args...)) -} diff --git a/daemon/logging/build_logging.go b/daemon/logging/build_logging.go index fcc1fa4..bcefec5 100644 --- a/daemon/logging/build_logging.go +++ b/daemon/logging/build_logging.go @@ -4,18 +4,21 @@ import ( "context" "crypto/tls" "crypto/x509" + "os" + "github.com/mattn/go-isatty" "github.com/pkg/errors" - "github.com/problame/go-streamrpc" + "github.com/zrepl/zrepl/config" "github.com/zrepl/zrepl/daemon/pruner" + "github.com/zrepl/zrepl/daemon/snapper" "github.com/zrepl/zrepl/endpoint" "github.com/zrepl/zrepl/logger" "github.com/zrepl/zrepl/replication" + "github.com/zrepl/zrepl/rpc" + "github.com/zrepl/zrepl/rpc/transportmux" "github.com/zrepl/zrepl/tlsconf" - "os" - "github.com/zrepl/zrepl/daemon/snapper" - "github.com/zrepl/zrepl/daemon/transport/serve" + "github.com/zrepl/zrepl/transport" ) func OutletsFromConfig(in config.LoggingOutletEnumList) (*logger.Outlets, error) { @@ -60,22 +63,41 @@ func OutletsFromConfig(in config.LoggingOutletEnumList) (*logger.Outlets, error) } +type Subsystem string + const ( - SubsysReplication = "repl" - SubsysStreamrpc = "rpc" - SubsyEndpoint = "endpoint" + SubsysReplication Subsystem = "repl" + SubsyEndpoint Subsystem = "endpoint" + SubsysPruning Subsystem = "pruning" + SubsysSnapshot Subsystem = "snapshot" + SubsysTransport Subsystem = "transport" + SubsysTransportMux Subsystem = "transportmux" + SubsysRPC Subsystem = "rpc" + SubsysRPCControl Subsystem = "rpc.ctrl" + SubsysRPCData Subsystem = "rpc.data" ) func WithSubsystemLoggers(ctx context.Context, log logger.Logger) context.Context { - ctx = replication.WithLogger(ctx, log.WithField(SubsysField, "repl")) - ctx = streamrpc.ContextWithLogger(ctx, streamrpcLogAdaptor{log.WithField(SubsysField, "rpc")}) - ctx = endpoint.WithLogger(ctx, log.WithField(SubsysField, "endpoint")) - ctx = pruner.WithLogger(ctx, log.WithField(SubsysField, "pruning")) - ctx = snapper.WithLogger(ctx, log.WithField(SubsysField, "snapshot")) - ctx = serve.WithLogger(ctx, log.WithField(SubsysField, "serve")) + ctx = replication.WithLogger(ctx, log.WithField(SubsysField, SubsysReplication)) + ctx = endpoint.WithLogger(ctx, log.WithField(SubsysField, SubsyEndpoint)) + ctx = pruner.WithLogger(ctx, log.WithField(SubsysField, SubsysPruning)) + ctx = snapper.WithLogger(ctx, log.WithField(SubsysField, SubsysSnapshot)) + ctx = transport.WithLogger(ctx, log.WithField(SubsysField, SubsysTransport)) + ctx = transportmux.WithLogger(ctx, log.WithField(SubsysField, SubsysTransportMux)) + ctx = rpc.WithLoggers(ctx, + rpc.Loggers{ + General: log.WithField(SubsysField, SubsysRPC), + Control: log.WithField(SubsysField, SubsysRPCControl), + Data: log.WithField(SubsysField, SubsysRPCData), + }, + ) return ctx } +func LogSubsystem(log logger.Logger, subsys Subsystem) logger.Logger { + return log.ReplaceField(SubsysField, subsys) +} + func parseLogFormat(i interface{}) (f EntryFormatter, err error) { var is string switch j := i.(type) { diff --git a/daemon/prometheus.go b/daemon/prometheus.go index 7607b94..1f10e9d 100644 --- a/daemon/prometheus.go +++ b/daemon/prometheus.go @@ -7,6 +7,7 @@ import ( "github.com/zrepl/zrepl/config" "github.com/zrepl/zrepl/daemon/job" "github.com/zrepl/zrepl/logger" + "github.com/zrepl/zrepl/rpc/dataconn/frameconn" "github.com/zrepl/zrepl/zfs" "net" "net/http" @@ -49,6 +50,10 @@ func (j *prometheusJob) Run(ctx context.Context) { panic(err) } + if err := frameconn.PrometheusRegister(prometheus.DefaultRegisterer); err != nil { + panic(err) + } + log := job.GetLogger(ctx) l, err := net.Listen("tcp", j.listen) diff --git a/daemon/pruner/pruner.go b/daemon/pruner/pruner.go index d5705db..bb515ad 100644 --- a/daemon/pruner/pruner.go +++ b/daemon/pruner/pruner.go @@ -11,7 +11,6 @@ import ( "github.com/zrepl/zrepl/replication/pdu" "github.com/zrepl/zrepl/util/envconst" "github.com/zrepl/zrepl/util/watchdog" - "github.com/problame/go-streamrpc" "net" "sort" "strings" @@ -19,14 +18,15 @@ import ( "time" ) -// Try to keep it compatible with gitub.com/zrepl/zrepl/replication.Endpoint +// Try to keep it compatible with gitub.com/zrepl/zrepl/endpoint.Endpoint type History interface { ReplicationCursor(ctx context.Context, req *pdu.ReplicationCursorReq) (*pdu.ReplicationCursorRes, error) } +// Try to keep it compatible with gitub.com/zrepl/zrepl/endpoint.Endpoint type Target interface { - ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, error) - ListFilesystemVersions(ctx context.Context, fs string) ([]*pdu.FilesystemVersion, error) // fix depS + ListFilesystems(ctx context.Context, req *pdu.ListFilesystemReq) (*pdu.ListFilesystemRes, error) + ListFilesystemVersions(ctx context.Context, req *pdu.ListFilesystemVersionsReq) (*pdu.ListFilesystemVersionsRes, error) DestroySnapshots(ctx context.Context, req *pdu.DestroySnapshotsReq) (*pdu.DestroySnapshotsRes, error) } @@ -346,7 +346,6 @@ type Error interface { } var _ Error = net.Error(nil) -var _ Error = streamrpc.Error(nil) func shouldRetry(e error) bool { if neterr, ok := e.(net.Error); ok { @@ -381,10 +380,11 @@ func statePlan(a *args, u updater) state { ka = &pruner.Progress }) - tfss, err := target.ListFilesystems(ctx) + tfssres, err := target.ListFilesystems(ctx, &pdu.ListFilesystemReq{}) if err != nil { return onErr(u, err) } + tfss := tfssres.GetFilesystems() pfss := make([]*fs, len(tfss)) for i, tfs := range tfss { @@ -398,11 +398,12 @@ func statePlan(a *args, u updater) state { } pfss[i] = pfs - tfsvs, err := target.ListFilesystemVersions(ctx, tfs.Path) + tfsvsres, err := target.ListFilesystemVersions(ctx, &pdu.ListFilesystemVersionsReq{Filesystem: tfs.Path}) if err != nil { l.WithError(err).Error("cannot list filesystem versions") return onErr(u, err) } + tfsvs := tfsvsres.GetVersions() // no progress here since we could run in a live-lock (must have used target AND receiver before progress) pfs.snaps = make([]pruning.Snapshot, 0, len(tfsvs)) diff --git a/daemon/pruner/pruner_test.go b/daemon/pruner/pruner_test.go index 47d8f41..23a10e8 100644 --- a/daemon/pruner/pruner_test.go +++ b/daemon/pruner/pruner_test.go @@ -44,7 +44,7 @@ type mockTarget struct { destroyErrs map[string][]error } -func (t *mockTarget) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, error) { +func (t *mockTarget) ListFilesystems(ctx context.Context, req *pdu.ListFilesystemReq) (*pdu.ListFilesystemRes, error) { if len(t.listFilesystemsErr) > 0 { e := t.listFilesystemsErr[0] t.listFilesystemsErr = t.listFilesystemsErr[1:] @@ -54,10 +54,11 @@ func (t *mockTarget) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, er for i := range fss { fss[i] = t.fss[i].Filesystem() } - return fss, nil + return &pdu.ListFilesystemRes{Filesystems: fss}, nil } -func (t *mockTarget) ListFilesystemVersions(ctx context.Context, fs string) ([]*pdu.FilesystemVersion, error) { +func (t *mockTarget) ListFilesystemVersions(ctx context.Context, req *pdu.ListFilesystemVersionsReq) (*pdu.ListFilesystemVersionsRes, error) { + fs := req.Filesystem if len(t.listVersionsErrs[fs]) != 0 { e := t.listVersionsErrs[fs][0] t.listVersionsErrs[fs] = t.listVersionsErrs[fs][1:] @@ -68,7 +69,7 @@ func (t *mockTarget) ListFilesystemVersions(ctx context.Context, fs string) ([]* if mfs.path != fs { continue } - return mfs.FilesystemVersions(), nil + return &pdu.ListFilesystemVersionsRes{Versions: mfs.FilesystemVersions()}, nil } return nil, fmt.Errorf("filesystem %s does not exist", fs) } diff --git a/daemon/snapper/snapper.go b/daemon/snapper/snapper.go index 6cd5b98..c6cc9a8 100644 --- a/daemon/snapper/snapper.go +++ b/daemon/snapper/snapper.go @@ -177,7 +177,7 @@ func onMainCtxDone(ctx context.Context, u updater) state { } func syncUp(a args, u updater) state { - fss, err := listFSes(a.fsf) + fss, err := listFSes(a.ctx, a.fsf) if err != nil { return onErr(err, u) } @@ -204,7 +204,7 @@ func plan(a args, u updater) state { u(func(snapper *Snapper) { snapper.lastInvocation = time.Now() }) - fss, err := listFSes(a.fsf) + fss, err := listFSes(a.ctx, a.fsf) if err != nil { return onErr(err, u) } @@ -299,8 +299,8 @@ func wait(a args, u updater) state { } } -func listFSes(mf *filters.DatasetMapFilter) (fss []*zfs.DatasetPath, err error) { - return zfs.ZFSListMapping(mf) +func listFSes(ctx context.Context, mf *filters.DatasetMapFilter) (fss []*zfs.DatasetPath, err error) { + return zfs.ZFSListMapping(ctx, mf) } func findSyncPoint(log Logger, fss []*zfs.DatasetPath, prefix string, interval time.Duration) (syncPoint time.Time, err error) { diff --git a/daemon/streamrpcconfig/streamrpcconfig.go b/daemon/streamrpcconfig/streamrpcconfig.go deleted file mode 100644 index da28d5d..0000000 --- a/daemon/streamrpcconfig/streamrpcconfig.go +++ /dev/null @@ -1,25 +0,0 @@ -package streamrpcconfig - -import ( - "github.com/problame/go-streamrpc" - "github.com/zrepl/zrepl/config" -) - -func FromDaemonConfig(g *config.Global, in *config.RPCConfig) (*streamrpc.ConnConfig, error) { - conf := in - if conf == nil { - conf = g.RPC - } - srpcConf := &streamrpc.ConnConfig{ - RxHeaderMaxLen: conf.RxHeaderMaxLen, - RxStructuredMaxLen: conf.RxStructuredMaxLen, - RxStreamMaxChunkSize: conf.RxStreamChunkMaxLen, - TxChunkSize: conf.TxChunkSize, - Timeout: conf.Timeout, - SendHeartbeatInterval: conf.SendHeartbeatInterval, - } - if err := srpcConf.Validate(); err != nil { - return nil, err - } - return srpcConf, nil -} diff --git a/daemon/transport/connecter/connecter.go b/daemon/transport/connecter/connecter.go deleted file mode 100644 index fa772a7..0000000 --- a/daemon/transport/connecter/connecter.go +++ /dev/null @@ -1,84 +0,0 @@ -package connecter - -import ( - "context" - "fmt" - "github.com/problame/go-streamrpc" - "github.com/zrepl/zrepl/config" - "github.com/zrepl/zrepl/daemon/streamrpcconfig" - "github.com/zrepl/zrepl/daemon/transport" - "net" - "time" -) - - -type HandshakeConnecter struct { - connecter streamrpc.Connecter -} - -func (c HandshakeConnecter) Connect(ctx context.Context) (net.Conn, error) { - conn, err := c.connecter.Connect(ctx) - if err != nil { - return nil, err - } - dl, ok := ctx.Deadline() - if !ok { - dl = time.Now().Add(10 * time.Second) // FIXME constant - } - if err := transport.DoHandshakeCurrentVersion(conn, dl); err != nil { - conn.Close() - return nil, err - } - return conn, nil -} - - - -func FromConfig(g *config.Global, in config.ConnectEnum) (*ClientFactory, error) { - var ( - connecter streamrpc.Connecter - errConnecter, errRPC error - connConf *streamrpc.ConnConfig - ) - switch v := in.Ret.(type) { - case *config.SSHStdinserverConnect: - connecter, errConnecter = SSHStdinserverConnecterFromConfig(v) - connConf, errRPC = streamrpcconfig.FromDaemonConfig(g, v.RPC) - case *config.TCPConnect: - connecter, errConnecter = TCPConnecterFromConfig(v) - connConf, errRPC = streamrpcconfig.FromDaemonConfig(g, v.RPC) - case *config.TLSConnect: - connecter, errConnecter = TLSConnecterFromConfig(v) - connConf, errRPC = streamrpcconfig.FromDaemonConfig(g, v.RPC) - case *config.LocalConnect: - connecter, errConnecter = LocalConnecterFromConfig(v) - connConf, errRPC = streamrpcconfig.FromDaemonConfig(g, v.RPC) - default: - panic(fmt.Sprintf("implementation error: unknown connecter type %T", v)) - } - - if errConnecter != nil { - return nil, errConnecter - } - if errRPC != nil { - return nil, errRPC - } - - config := streamrpc.ClientConfig{ConnConfig: connConf} - if err := config.Validate(); err != nil { - return nil, err - } - - connecter = HandshakeConnecter{connecter} - - return &ClientFactory{connecter: connecter, config: &config}, nil -} - -type ClientFactory struct { - connecter streamrpc.Connecter - config *streamrpc.ClientConfig -} - -func (f ClientFactory) NewClient() (*streamrpc.Client, error) { - return streamrpc.NewClient(f.connecter, f.config) -} diff --git a/daemon/transport/handshake.go b/daemon/transport/handshake.go deleted file mode 100644 index ecfd495..0000000 --- a/daemon/transport/handshake.go +++ /dev/null @@ -1,136 +0,0 @@ -package transport - -import ( - "bytes" - "fmt" - "io" - "net" - "strings" - "time" - "unicode/utf8" -) - -type HandshakeMessage struct { - ProtocolVersion int - Extensions []string -} - -func (m *HandshakeMessage) Encode() ([]byte, error) { - if m.ProtocolVersion <= 0 || m.ProtocolVersion > 9999 { - return nil, fmt.Errorf("protocol version must be in [1, 9999]") - } - if len(m.Extensions) >= 9999 { - return nil, fmt.Errorf("protocol only supports [0, 9999] extensions") - } - // EXTENSIONS is a count of subsequent \n separated lines that contain protocol extensions - var extensions strings.Builder - for i, ext := range m.Extensions { - if strings.ContainsAny(ext, "\n") { - return nil, fmt.Errorf("Extension #%d contains forbidden newline character", i) - } - if !utf8.ValidString(ext) { - return nil, fmt.Errorf("Extension #%d is not valid UTF-8", i) - } - extensions.WriteString(ext) - extensions.WriteString("\n") - } - withoutLen := fmt.Sprintf("ZREPL_ZFS_REPLICATION PROTOVERSION=%04d EXTENSIONS=%04d\n%s", - m.ProtocolVersion, len(m.Extensions), extensions.String()) - withLen := fmt.Sprintf("%010d %s", len(withoutLen), withoutLen) - return []byte(withLen), nil -} - -func (m *HandshakeMessage) DecodeReader(r io.Reader, maxLen int) error { - var lenAndSpace [11]byte - if _, err := io.ReadFull(r, lenAndSpace[:]); err != nil { - return err - } - if !utf8.Valid(lenAndSpace[:]) { - return fmt.Errorf("invalid start of handshake message: not valid UTF-8") - } - var followLen int - n, err := fmt.Sscanf(string(lenAndSpace[:]), "%010d ", &followLen) - if n != 1 || err != nil { - return fmt.Errorf("could not parse handshake message length") - } - if followLen > maxLen { - return fmt.Errorf("handshake message length exceeds max length (%d vs %d)", - followLen, maxLen) - } - - var buf bytes.Buffer - _, err = io.Copy(&buf, io.LimitReader(r, int64(followLen))) - if err != nil { - return err - } - - var ( - protoVersion, extensionCount int - ) - n, err = fmt.Fscanf(&buf, "ZREPL_ZFS_REPLICATION PROTOVERSION=%04d EXTENSIONS=%4d\n", - &protoVersion, &extensionCount) - if n != 2 || err != nil { - return fmt.Errorf("could not parse handshake message: %s", err) - } - if protoVersion < 1 { - return fmt.Errorf("invalid protocol version %q", protoVersion) - } - m.ProtocolVersion = protoVersion - - if extensionCount < 0 { - return fmt.Errorf("invalid extension count %q", extensionCount) - } - if extensionCount == 0 { - if buf.Len() != 0 { - return fmt.Errorf("unexpected data trailing after header") - } - m.Extensions = nil - return nil - } - s := buf.String() - if strings.Count(s, "\n") != extensionCount { - return fmt.Errorf("inconsistent extension count: found %d, header says %d", len(m.Extensions), extensionCount) - } - exts := strings.Split(s, "\n") - if exts[len(exts)-1] != "" { - return fmt.Errorf("unexpected data trailing after last extension newline") - } - m.Extensions = exts[0:len(exts)-1] - - return nil -} - -func DoHandshakeCurrentVersion(conn net.Conn, deadline time.Time) error { - // current protocol version is hardcoded here - return DoHandshakeVersion(conn, deadline, 1) -} - -func DoHandshakeVersion(conn net.Conn, deadline time.Time, version int) error { - ours := HandshakeMessage{ - ProtocolVersion: version, - Extensions: nil, - } - hsb, err := ours.Encode() - if err != nil { - return fmt.Errorf("could not encode protocol banner: %s", err) - } - - conn.SetDeadline(deadline) - _, err = io.Copy(conn, bytes.NewBuffer(hsb)) - if err != nil { - return fmt.Errorf("could not send protocol banner: %s", err) - } - - theirs := HandshakeMessage{} - if err := theirs.DecodeReader(conn, 16 * 4096); err != nil { // FIXME constant - return fmt.Errorf("could not decode protocol banner: %s", err) - } - - if theirs.ProtocolVersion != ours.ProtocolVersion { - return fmt.Errorf("protocol versions do not match: ours is %d, theirs is %d", - ours.ProtocolVersion, theirs.ProtocolVersion) - } - // ignore extensions, we don't use them - - return nil -} diff --git a/daemon/transport/serve/serve.go b/daemon/transport/serve/serve.go deleted file mode 100644 index c1b3bb1..0000000 --- a/daemon/transport/serve/serve.go +++ /dev/null @@ -1,147 +0,0 @@ -package serve - -import ( - "github.com/pkg/errors" - "github.com/zrepl/zrepl/config" - "github.com/zrepl/zrepl/daemon/transport" - "net" - "github.com/zrepl/zrepl/daemon/streamrpcconfig" - "github.com/problame/go-streamrpc" - "context" - "github.com/zrepl/zrepl/logger" - "github.com/zrepl/zrepl/zfs" - "time" -) - -type contextKey int - -const contextKeyLog contextKey = 0 - -type Logger = logger.Logger - -func WithLogger(ctx context.Context, log Logger) context.Context { - return context.WithValue(ctx, contextKeyLog, log) -} - -func getLogger(ctx context.Context) Logger { - if log, ok := ctx.Value(contextKeyLog).(Logger); ok { - return log - } - return logger.NewNullLogger() -} - -type AuthenticatedConn interface { - net.Conn - // ClientIdentity must be a string that satisfies ValidateClientIdentity - ClientIdentity() string -} - -// A client identity must be a single component in a ZFS filesystem path -func ValidateClientIdentity(in string) (err error) { - path, err := zfs.NewDatasetPath(in) - if err != nil { - return err - } - if path.Length() != 1 { - return errors.New("client identity must be a single path comonent (not empty, no '/')") - } - return nil -} - -type authConn struct { - net.Conn - clientIdentity string -} - -var _ AuthenticatedConn = authConn{} - -func (c authConn) ClientIdentity() string { - if err := ValidateClientIdentity(c.clientIdentity); err != nil { - panic(err) - } - return c.clientIdentity -} - -// like net.Listener, but with an AuthenticatedConn instead of net.Conn -type AuthenticatedListener interface { - Addr() (net.Addr) - Accept(ctx context.Context) (AuthenticatedConn, error) - Close() error -} - -type ListenerFactory interface { - Listen() (AuthenticatedListener, error) -} - -type HandshakeListenerFactory struct { - lf ListenerFactory -} - -func (lf HandshakeListenerFactory) Listen() (AuthenticatedListener, error) { - l, err := lf.lf.Listen() - if err != nil { - return nil, err - } - return HandshakeListener{l}, nil -} - -type HandshakeListener struct { - l AuthenticatedListener -} - -func (l HandshakeListener) Addr() (net.Addr) { return l.l.Addr() } - -func (l HandshakeListener) Close() error { return l.l.Close() } - -func (l HandshakeListener) Accept(ctx context.Context) (AuthenticatedConn, error) { - conn, err := l.l.Accept(ctx) - if err != nil { - return nil, err - } - dl, ok := ctx.Deadline() - if !ok { - dl = time.Now().Add(10*time.Second) // FIXME constant - } - if err := transport.DoHandshakeCurrentVersion(conn, dl); err != nil { - conn.Close() - return nil, err - } - return conn, nil -} - -func FromConfig(g *config.Global, in config.ServeEnum) (lf ListenerFactory, conf *streamrpc.ConnConfig, _ error) { - - var ( - lfError, rpcErr error - ) - switch v := in.Ret.(type) { - case *config.TCPServe: - lf, lfError = TCPListenerFactoryFromConfig(g, v) - conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC) - case *config.TLSServe: - lf, lfError = TLSListenerFactoryFromConfig(g, v) - conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC) - case *config.StdinserverServer: - lf, lfError = MultiStdinserverListenerFactoryFromConfig(g, v) - conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC) - case *config.LocalServe: - lf, lfError = LocalListenerFactoryFromConfig(g, v) - conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC) - default: - return nil, nil, errors.Errorf("internal error: unknown serve type %T", v) - } - - if lfError != nil { - return nil, nil, lfError - } - if rpcErr != nil { - return nil, nil, rpcErr - } - - lf = HandshakeListenerFactory{lf} - - return lf, conf, nil - -} - - diff --git a/daemon/transport/serve/serve_tls.go b/daemon/transport/serve/serve_tls.go deleted file mode 100644 index bc95e41..0000000 --- a/daemon/transport/serve/serve_tls.go +++ /dev/null @@ -1,83 +0,0 @@ -package serve - -import ( - "crypto/tls" - "crypto/x509" - "fmt" - "github.com/pkg/errors" - "github.com/zrepl/zrepl/config" - "github.com/zrepl/zrepl/tlsconf" - "net" - "time" - "context" -) - -type TLSListenerFactory struct { - address string - clientCA *x509.CertPool - serverCert tls.Certificate - handshakeTimeout time.Duration - clientCNs map[string]struct{} -} - -func TLSListenerFactoryFromConfig(c *config.Global, in *config.TLSServe) (lf *TLSListenerFactory, err error) { - lf = &TLSListenerFactory{ - address: in.Listen, - handshakeTimeout: in.HandshakeTimeout, - } - - if in.Ca == "" || in.Cert == "" || in.Key == "" { - return nil, errors.New("fields 'ca', 'cert' and 'key'must be specified") - } - - lf.clientCA, err = tlsconf.ParseCAFile(in.Ca) - if err != nil { - return nil, errors.Wrap(err, "cannot parse ca file") - } - - lf.serverCert, err = tls.LoadX509KeyPair(in.Cert, in.Key) - if err != nil { - return nil, errors.Wrap(err, "cannot parse cer/key pair") - } - - lf.clientCNs = make(map[string]struct{}, len(in.ClientCNs)) - for i, cn := range in.ClientCNs { - if err := ValidateClientIdentity(cn); err != nil { - return nil, errors.Wrapf(err, "unsuitable client_cn #%d %q", i, cn) - } - // dupes are ok fr now - lf.clientCNs[cn] = struct{}{} - } - - return lf, nil -} - -func (f *TLSListenerFactory) Listen() (AuthenticatedListener, error) { - l, err := net.Listen("tcp", f.address) - if err != nil { - return nil, err - } - tl := tlsconf.NewClientAuthListener(l, f.clientCA, f.serverCert, f.handshakeTimeout) - return tlsAuthListener{tl, f.clientCNs}, nil -} - -type tlsAuthListener struct { - *tlsconf.ClientAuthListener - clientCNs map[string]struct{} -} - -func (l tlsAuthListener) Accept(ctx context.Context) (AuthenticatedConn, error) { - c, cn, err := l.ClientAuthListener.Accept() - if err != nil { - return nil, err - } - if _, ok := l.clientCNs[cn]; !ok { - if err := c.Close(); err != nil { - getLogger(ctx).WithError(err).Error("error closing connection with unauthorized common name") - } - return nil, fmt.Errorf("unauthorized client common name %q from %s", cn, c.RemoteAddr()) - } - return authConn{c, cn}, nil -} - - diff --git a/docs/configuration/transports.rst b/docs/configuration/transports.rst index abd5347..20d1b25 100644 --- a/docs/configuration/transports.rst +++ b/docs/configuration/transports.rst @@ -203,13 +203,17 @@ The serve & connect configuration will thus look like the following: ``ssh+stdinserver`` Transport ----------------------------- -``ssh+stdinserver`` is inspired by `git shell `_ and `Borg Backup `_. -It is provided by the Go package ``github.com/problame/go-netssh``. +``ssh+stdinserver`` uses the ``ssh`` command and some features of the server-side SSH ``authorized_keys`` file. +It is less efficient than other transports because the data passes through two more pipes. +However, it is fairly convenient to set up and allows the zrepl daemon to not be directly exposed to the internet, because all traffic passes through the system's SSH server. -.. ATTENTION:: +The concept is inspired by `git shell `_ and `Borg Backup `_. +The implementation is provided by the Go package ``github.com/problame/go-netssh``. - ``ssh+stdinserver`` has inferior error detection and handling compared to the ``tcp`` and ``tls`` transports. - If you require tested timeout & retry handling, use ``tcp`` or ``tls`` transports, or help improve package go-netssh. +.. NOTE:: + + ``ssh+stdinserver`` generally provides inferior error detection and handling compared to the ``tcp`` and ``tls`` transports. + When encountering such problems, consider using ``tcp`` or ``tls`` transports, or help improve package go-netssh. .. _transport-ssh+stdinserver-serve: diff --git a/endpoint/context.go b/endpoint/context.go index 09f9032..20b3296 100644 --- a/endpoint/context.go +++ b/endpoint/context.go @@ -9,6 +9,7 @@ type contextKey int const ( contextKeyLogger contextKey = iota + ClientIdentityKey ) type Logger = logger.Logger diff --git a/endpoint/endpoint.go b/endpoint/endpoint.go index 4b44825..b90f3f7 100644 --- a/endpoint/endpoint.go +++ b/endpoint/endpoint.go @@ -2,21 +2,19 @@ package endpoint import ( - "bytes" "context" "fmt" - "github.com/golang/protobuf/proto" + "path" + "github.com/pkg/errors" - "github.com/problame/go-streamrpc" "github.com/zrepl/zrepl/replication" "github.com/zrepl/zrepl/replication/pdu" "github.com/zrepl/zrepl/zfs" - "io" ) // Sender implements replication.ReplicationEndpoint for a sending side type Sender struct { - FSFilter zfs.DatasetFilter + FSFilter zfs.DatasetFilter } func NewSender(fsf zfs.DatasetFilter) *Sender { @@ -41,8 +39,8 @@ func (s *Sender) filterCheckFS(fs string) (*zfs.DatasetPath, error) { return dp, nil } -func (p *Sender) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, error) { - fss, err := zfs.ZFSListMapping(p.FSFilter) +func (s *Sender) ListFilesystems(ctx context.Context, r *pdu.ListFilesystemReq) (*pdu.ListFilesystemRes, error) { + fss, err := zfs.ZFSListMapping(ctx, s.FSFilter) if err != nil { return nil, err } @@ -53,11 +51,12 @@ func (p *Sender) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, error) // FIXME: not supporting ResumeToken yet } } - return rfss, nil + res := &pdu.ListFilesystemRes{Filesystems: rfss, Empty: len(rfss) == 0} + return res, nil } -func (p *Sender) ListFilesystemVersions(ctx context.Context, fs string) ([]*pdu.FilesystemVersion, error) { - lp, err := p.filterCheckFS(fs) +func (s *Sender) ListFilesystemVersions(ctx context.Context, r *pdu.ListFilesystemVersionsReq) (*pdu.ListFilesystemVersionsRes, error) { + lp, err := s.filterCheckFS(r.GetFilesystem()) if err != nil { return nil, err } @@ -69,32 +68,36 @@ func (p *Sender) ListFilesystemVersions(ctx context.Context, fs string) ([]*pdu. for i := range fsvs { rfsvs[i] = pdu.FilesystemVersionFromZFS(&fsvs[i]) } - return rfsvs, nil + res := &pdu.ListFilesystemVersionsRes{Versions: rfsvs} + return res, nil + } -func (p *Sender) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error) { - _, err := p.filterCheckFS(r.Filesystem) +func (s *Sender) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) { + _, err := s.filterCheckFS(r.Filesystem) if err != nil { return nil, nil, err } - if r.DryRun { - si, err := zfs.ZFSSendDry(r.Filesystem, r.From, r.To, "") - if err != nil { - return nil, nil, err - } - var expSize int64 = 0 // protocol says 0 means no estimate - if si.SizeEstimate != -1 { // but si returns -1 for no size estimate - expSize = si.SizeEstimate - } - return &pdu.SendRes{ExpectedSize: expSize}, nil, nil - } else { - stream, err := zfs.ZFSSend(ctx, r.Filesystem, r.From, r.To, "") - if err != nil { - return nil, nil, err - } - return &pdu.SendRes{}, stream, nil + si, err := zfs.ZFSSendDry(r.Filesystem, r.From, r.To, "") + if err != nil { + return nil, nil, err } + var expSize int64 = 0 // protocol says 0 means no estimate + if si.SizeEstimate != -1 { // but si returns -1 for no size estimate + expSize = si.SizeEstimate + } + res := &pdu.SendRes{ExpectedSize: expSize} + + if r.DryRun { + return res, nil, nil + } + + streamCopier, err := zfs.ZFSSend(ctx, r.Filesystem, r.From, r.To, "") + if err != nil { + return nil, nil, err + } + return res, streamCopier, nil } func (p *Sender) DestroySnapshots(ctx context.Context, req *pdu.DestroySnapshotsReq) (*pdu.DestroySnapshotsRes, error) { @@ -132,6 +135,10 @@ func (p *Sender) ReplicationCursor(ctx context.Context, req *pdu.ReplicationCurs } } +func (p *Sender) Receive(ctx context.Context, r *pdu.ReceiveReq, receive zfs.StreamCopier) (*pdu.ReceiveRes, error) { + return nil, fmt.Errorf("sender does not implement Receive()") +} + type FSFilter interface { // FIXME unused Filter(path *zfs.DatasetPath) (pass bool, err error) } @@ -146,14 +153,50 @@ type FSMap interface { // FIXME unused // Receiver implements replication.ReplicationEndpoint for a receiving side type Receiver struct { - root *zfs.DatasetPath + rootWithoutClientComponent *zfs.DatasetPath + appendClientIdentity bool } -func NewReceiver(rootDataset *zfs.DatasetPath) (*Receiver, error) { +func NewReceiver(rootDataset *zfs.DatasetPath, appendClientIdentity bool) *Receiver { if rootDataset.Length() <= 0 { - return nil, errors.New("root dataset must not be an empty path") + panic(fmt.Sprintf("root dataset must not be an empty path: %v", rootDataset)) } - return &Receiver{root: rootDataset.Copy()}, nil + return &Receiver{rootWithoutClientComponent: rootDataset.Copy(), appendClientIdentity: appendClientIdentity} +} + +func TestClientIdentity(rootFS *zfs.DatasetPath, clientIdentity string) error { + _, err := clientRoot(rootFS, clientIdentity) + return err +} + +func clientRoot(rootFS *zfs.DatasetPath, clientIdentity string) (*zfs.DatasetPath, error) { + rootFSLen := rootFS.Length() + clientRootStr := path.Join(rootFS.ToString(), clientIdentity) + clientRoot, err := zfs.NewDatasetPath(clientRootStr) + if err != nil { + return nil, err + } + if rootFSLen+1 != clientRoot.Length() { + return nil, fmt.Errorf("client identity must be a single ZFS filesystem path component") + } + return clientRoot, nil +} + +func (s *Receiver) clientRootFromCtx(ctx context.Context) *zfs.DatasetPath { + if !s.appendClientIdentity { + return s.rootWithoutClientComponent.Copy() + } + + clientIdentity, ok := ctx.Value(ClientIdentityKey).(string) + if !ok { + panic(fmt.Sprintf("ClientIdentityKey context value must be set")) + } + + clientRoot, err := clientRoot(s.rootWithoutClientComponent, clientIdentity) + if err != nil { + panic(fmt.Sprintf("ClientIdentityContextKey must have been validated before invoking Receiver: %s", err)) + } + return clientRoot } type subroot struct { @@ -180,8 +223,9 @@ func (f subroot) MapToLocal(fs string) (*zfs.DatasetPath, error) { return c, nil } -func (e *Receiver) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, error) { - filtered, err := zfs.ZFSListMapping(subroot{e.root}) +func (s *Receiver) ListFilesystems(ctx context.Context, req *pdu.ListFilesystemReq) (*pdu.ListFilesystemRes, error) { + root := s.clientRootFromCtx(ctx) + filtered, err := zfs.ZFSListMapping(ctx, subroot{root}) if err != nil { return nil, err } @@ -194,19 +238,30 @@ func (e *Receiver) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, erro WithError(err). WithField("fs", a). Error("inconsistent placeholder property") - return nil, errors.New("server error, see logs") // don't leak path + return nil, errors.New("server error: inconsistent placeholder property") // don't leak path } if ph { + getLogger(ctx). + WithField("fs", a.ToString()). + Debug("ignoring placeholder filesystem") continue } - a.TrimPrefix(e.root) + getLogger(ctx). + WithField("fs", a.ToString()). + Debug("non-placeholder filesystem") + a.TrimPrefix(root) fss = append(fss, &pdu.Filesystem{Path: a.ToString()}) } - return fss, nil + if len(fss) == 0 { + getLogger(ctx).Debug("no non-placeholder filesystems") + return &pdu.ListFilesystemRes{Empty: true}, nil + } + return &pdu.ListFilesystemRes{Filesystems: fss}, nil } -func (e *Receiver) ListFilesystemVersions(ctx context.Context, fs string) ([]*pdu.FilesystemVersion, error) { - lp, err := subroot{e.root}.MapToLocal(fs) +func (s *Receiver) ListFilesystemVersions(ctx context.Context, req *pdu.ListFilesystemVersionsReq) (*pdu.ListFilesystemVersionsRes, error) { + root := s.clientRootFromCtx(ctx) + lp, err := subroot{root}.MapToLocal(req.GetFilesystem()) if err != nil { return nil, err } @@ -221,18 +276,26 @@ func (e *Receiver) ListFilesystemVersions(ctx context.Context, fs string) ([]*pd rfsvs[i] = pdu.FilesystemVersionFromZFS(&fsvs[i]) } - return rfsvs, nil + return &pdu.ListFilesystemVersionsRes{Versions: rfsvs}, nil } -func (e *Receiver) Receive(ctx context.Context, req *pdu.ReceiveReq, sendStream io.ReadCloser) error { - defer sendStream.Close() +func (s *Receiver) ReplicationCursor(context.Context, *pdu.ReplicationCursorReq) (*pdu.ReplicationCursorRes, error) { + return nil, fmt.Errorf("ReplicationCursor not implemented for Receiver") +} - lp, err := subroot{e.root}.MapToLocal(req.Filesystem) - if err != nil { - return err - } +func (s *Receiver) Send(ctx context.Context, req *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) { + return nil, nil, fmt.Errorf("receiver does not implement Send()") +} +func (s *Receiver) Receive(ctx context.Context, req *pdu.ReceiveReq, receive zfs.StreamCopier) (*pdu.ReceiveRes, error) { getLogger(ctx).Debug("incoming Receive") + defer receive.Close() + + root := s.clientRootFromCtx(ctx) + lp, err := subroot{root}.MapToLocal(req.Filesystem) + if err != nil { + return nil, err + } // create placeholder parent filesystems as appropriate var visitErr error @@ -261,7 +324,7 @@ func (e *Receiver) Receive(ctx context.Context, req *pdu.ReceiveReq, sendStream getLogger(ctx).WithField("visitErr", visitErr).Debug("complete tree-walk") if visitErr != nil { - return visitErr + return nil, err } needForceRecv := false @@ -279,19 +342,19 @@ func (e *Receiver) Receive(ctx context.Context, req *pdu.ReceiveReq, sendStream getLogger(ctx).Debug("start receive command") - if err := zfs.ZFSRecv(ctx, lp.ToString(), sendStream, args...); err != nil { + if err := zfs.ZFSRecv(ctx, lp.ToString(), receive, args...); err != nil { getLogger(ctx). WithError(err). WithField("args", args). Error("zfs receive failed") - sendStream.Close() - return err + return nil, err } - return nil + return &pdu.ReceiveRes{}, nil } -func (e *Receiver) DestroySnapshots(ctx context.Context, req *pdu.DestroySnapshotsReq) (*pdu.DestroySnapshotsRes, error) { - lp, err := subroot{e.root}.MapToLocal(req.Filesystem) +func (s *Receiver) DestroySnapshots(ctx context.Context, req *pdu.DestroySnapshotsReq) (*pdu.DestroySnapshotsRes, error) { + root := s.clientRootFromCtx(ctx) + lp, err := subroot{root}.MapToLocal(req.Filesystem) if err != nil { return nil, err } @@ -326,289 +389,3 @@ func doDestroySnapshots(ctx context.Context, lp *zfs.DatasetPath, snaps []*pdu.F } return res, nil } - -// =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= -// RPC STUBS -// =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= - -const ( - RPCListFilesystems = "ListFilesystems" - RPCListFilesystemVersions = "ListFilesystemVersions" - RPCReceive = "Receive" - RPCSend = "Send" - RPCSDestroySnapshots = "DestroySnapshots" - RPCReplicationCursor = "ReplicationCursor" -) - -// Remote implements an endpoint stub that uses streamrpc as a transport. -type Remote struct { - c *streamrpc.Client -} - -func NewRemote(c *streamrpc.Client) Remote { - return Remote{c} -} - -func (s Remote) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, error) { - req := pdu.ListFilesystemReq{} - b, err := proto.Marshal(&req) - if err != nil { - return nil, err - } - rb, rs, err := s.c.RequestReply(ctx, RPCListFilesystems, bytes.NewBuffer(b), nil) - if err != nil { - return nil, err - } - if rs != nil { - rs.Close() - return nil, errors.New("response contains unexpected stream") - } - var res pdu.ListFilesystemRes - if err := proto.Unmarshal(rb.Bytes(), &res); err != nil { - return nil, err - } - return res.Filesystems, nil -} - -func (s Remote) ListFilesystemVersions(ctx context.Context, fs string) ([]*pdu.FilesystemVersion, error) { - req := pdu.ListFilesystemVersionsReq{ - Filesystem: fs, - } - b, err := proto.Marshal(&req) - if err != nil { - return nil, err - } - rb, rs, err := s.c.RequestReply(ctx, RPCListFilesystemVersions, bytes.NewBuffer(b), nil) - if err != nil { - return nil, err - } - if rs != nil { - rs.Close() - return nil, errors.New("response contains unexpected stream") - } - var res pdu.ListFilesystemVersionsRes - if err := proto.Unmarshal(rb.Bytes(), &res); err != nil { - return nil, err - } - return res.Versions, nil -} - -func (s Remote) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error) { - b, err := proto.Marshal(r) - if err != nil { - return nil, nil, err - } - rb, rs, err := s.c.RequestReply(ctx, RPCSend, bytes.NewBuffer(b), nil) - if err != nil { - return nil, nil, err - } - if !r.DryRun && rs == nil { - return nil, nil, errors.New("response does not contain a stream") - } - if r.DryRun && rs != nil { - rs.Close() - return nil, nil, errors.New("response contains unexpected stream (was dry run)") - } - var res pdu.SendRes - if err := proto.Unmarshal(rb.Bytes(), &res); err != nil { - rs.Close() - return nil, nil, err - } - return &res, rs, nil -} - -func (s Remote) Receive(ctx context.Context, r *pdu.ReceiveReq, sendStream io.ReadCloser) error { - defer sendStream.Close() - b, err := proto.Marshal(r) - if err != nil { - return err - } - rb, rs, err := s.c.RequestReply(ctx, RPCReceive, bytes.NewBuffer(b), sendStream) - getLogger(ctx).WithField("err", err).Debug("Remote.Receive RequestReplyReturned") - if err != nil { - return err - } - if rs != nil { - rs.Close() - return errors.New("response contains unexpected stream") - } - var res pdu.ReceiveRes - if err := proto.Unmarshal(rb.Bytes(), &res); err != nil { - return err - } - return nil -} - -func (s Remote) DestroySnapshots(ctx context.Context, r *pdu.DestroySnapshotsReq) (*pdu.DestroySnapshotsRes, error) { - b, err := proto.Marshal(r) - if err != nil { - return nil, err - } - rb, rs, err := s.c.RequestReply(ctx, RPCSDestroySnapshots, bytes.NewBuffer(b), nil) - if err != nil { - return nil, err - } - if rs != nil { - rs.Close() - return nil, errors.New("response contains unexpected stream") - } - var res pdu.DestroySnapshotsRes - if err := proto.Unmarshal(rb.Bytes(), &res); err != nil { - return nil, err - } - return &res, nil -} - -func (s Remote) ReplicationCursor(ctx context.Context, req *pdu.ReplicationCursorReq) (*pdu.ReplicationCursorRes, error) { - b, err := proto.Marshal(req) - if err != nil { - return nil, err - } - rb, rs, err := s.c.RequestReply(ctx, RPCReplicationCursor, bytes.NewBuffer(b), nil) - if err != nil { - return nil, err - } - if rs != nil { - rs.Close() - return nil, errors.New("response contains unexpected stream") - } - var res pdu.ReplicationCursorRes - if err := proto.Unmarshal(rb.Bytes(), &res); err != nil { - return nil, err - } - return &res, nil -} - -// Handler implements the server-side streamrpc.HandlerFunc for a Remote endpoint stub. -type Handler struct { - ep replication.Endpoint -} - -func NewHandler(ep replication.Endpoint) Handler { - return Handler{ep} -} - -func (a *Handler) Handle(ctx context.Context, endpoint string, reqStructured *bytes.Buffer, reqStream io.ReadCloser) (resStructured *bytes.Buffer, resStream io.ReadCloser, err error) { - - switch endpoint { - case RPCListFilesystems: - var req pdu.ListFilesystemReq - if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil { - return nil, nil, err - } - fsses, err := a.ep.ListFilesystems(ctx) - if err != nil { - return nil, nil, err - } - res := &pdu.ListFilesystemRes{ - Filesystems: fsses, - } - b, err := proto.Marshal(res) - if err != nil { - return nil, nil, err - } - return bytes.NewBuffer(b), nil, nil - - case RPCListFilesystemVersions: - - var req pdu.ListFilesystemVersionsReq - if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil { - return nil, nil, err - } - fsvs, err := a.ep.ListFilesystemVersions(ctx, req.Filesystem) - if err != nil { - return nil, nil, err - } - res := &pdu.ListFilesystemVersionsRes{ - Versions: fsvs, - } - b, err := proto.Marshal(res) - if err != nil { - return nil, nil, err - } - return bytes.NewBuffer(b), nil, nil - - case RPCSend: - - sender, ok := a.ep.(replication.Sender) - if !ok { - goto Err - } - - var req pdu.SendReq - if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil { - return nil, nil, err - } - res, sendStream, err := sender.Send(ctx, &req) - if err != nil { - return nil, nil, err - } - b, err := proto.Marshal(res) - if err != nil { - return nil, nil, err - } - return bytes.NewBuffer(b), sendStream, err - - case RPCReceive: - - receiver, ok := a.ep.(replication.Receiver) - if !ok { - goto Err - } - - var req pdu.ReceiveReq - if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil { - return nil, nil, err - } - err := receiver.Receive(ctx, &req, reqStream) - if err != nil { - return nil, nil, err - } - b, err := proto.Marshal(&pdu.ReceiveRes{}) - if err != nil { - return nil, nil, err - } - return bytes.NewBuffer(b), nil, err - - case RPCSDestroySnapshots: - - var req pdu.DestroySnapshotsReq - if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil { - return nil, nil, err - } - - res, err := a.ep.DestroySnapshots(ctx, &req) - if err != nil { - return nil, nil, err - } - b, err := proto.Marshal(res) - if err != nil { - return nil, nil, err - } - return bytes.NewBuffer(b), nil, nil - - case RPCReplicationCursor: - - sender, ok := a.ep.(replication.Sender) - if !ok { - goto Err - } - - var req pdu.ReplicationCursorReq - if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil { - return nil, nil, err - } - res, err := sender.ReplicationCursor(ctx, &req) - if err != nil { - return nil, nil, err - } - b, err := proto.Marshal(res) - if err != nil { - return nil, nil, err - } - return bytes.NewBuffer(b), nil, nil - - } -Err: - return nil, nil, errors.New("no handler for given endpoint") -} diff --git a/logger/stderrlogger.go b/logger/stderrlogger.go new file mode 100644 index 0000000..6ca0aad --- /dev/null +++ b/logger/stderrlogger.go @@ -0,0 +1,27 @@ +package logger + +import ( + "fmt" + "os" +) + +type stderrLogger struct { + Logger +} + +type stderrLoggerOutlet struct {} + +func (stderrLoggerOutlet) WriteEntry(entry Entry) error { + fmt.Fprintf(os.Stderr, "%#v\n", entry) + return nil +} + +var _ Logger = testLogger{} + +func NewStderrDebugLogger() Logger { + outlets := NewOutlets() + outlets.Add(&stderrLoggerOutlet{}, Debug) + return &testLogger{ + Logger: NewLogger(outlets, 0), + } +} diff --git a/replication/fsrep/fsfsm.go b/replication/fsrep/fsfsm.go index 6265d01..cfc8a5e 100644 --- a/replication/fsrep/fsfsm.go +++ b/replication/fsrep/fsfsm.go @@ -6,16 +6,16 @@ import ( "context" "errors" "fmt" - "github.com/prometheus/client_golang/prometheus" - "github.com/zrepl/zrepl/util/watchdog" - "io" "net" "sync" "time" + "github.com/prometheus/client_golang/prometheus" "github.com/zrepl/zrepl/logger" "github.com/zrepl/zrepl/replication/pdu" - "github.com/zrepl/zrepl/util" + "github.com/zrepl/zrepl/util/bytecounter" + "github.com/zrepl/zrepl/util/watchdog" + "github.com/zrepl/zrepl/zfs" ) type contextKey int @@ -43,7 +43,7 @@ type Sender interface { // If a non-nil io.ReadCloser is returned, it is guaranteed to be closed before // any next call to the parent github.com/zrepl/zrepl/replication.Endpoint. // If the send request is for dry run the io.ReadCloser will be nil - Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error) + Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) ReplicationCursor(ctx context.Context, req *pdu.ReplicationCursorReq) (*pdu.ReplicationCursorRes, error) } @@ -51,9 +51,7 @@ type Sender interface { type Receiver interface { // Receive sends r and sendStream (the latter containing a ZFS send stream) // to the parent github.com/zrepl/zrepl/replication.Endpoint. - // Implementors must guarantee that Close was called on sendStream before - // the call to Receive returns. - Receive(ctx context.Context, r *pdu.ReceiveReq, sendStream io.ReadCloser) error + Receive(ctx context.Context, req *pdu.ReceiveReq, receive zfs.StreamCopier) (*pdu.ReceiveRes, error) } type StepReport struct { @@ -227,7 +225,7 @@ type ReplicationStep struct { // both retry and permanent error err error - byteCounter *util.ByteCounterReader + byteCounter bytecounter.StreamCopier expectedSize int64 // 0 means no size estimate present / possible } @@ -401,37 +399,54 @@ func (s *ReplicationStep) doReplication(ctx context.Context, ka *watchdog.KeepAl sr := s.buildSendRequest(false) log.Debug("initiate send request") - sres, sstream, err := sender.Send(ctx, sr) + sres, sstreamCopier, err := sender.Send(ctx, sr) if err != nil { log.WithError(err).Error("send request failed") return err } - if sstream == nil { + if sstreamCopier == nil { err := errors.New("send request did not return a stream, broken endpoint implementation") return err } + defer sstreamCopier.Close() - s.byteCounter = util.NewByteCounterReader(sstream) - s.byteCounter.SetCallback(1*time.Second, func(i int64) { - ka.MadeProgress() - }) - defer func() { - s.parent.promBytesReplicated.Add(float64(s.byteCounter.Bytes())) + // Install a byte counter to track progress + for status report + s.byteCounter = bytecounter.NewStreamCopier(sstreamCopier) + byteCounterStopProgress := make(chan struct{}) + defer close(byteCounterStopProgress) + go func() { + var lastCount int64 + t := time.NewTicker(1 * time.Second) + defer t.Stop() + for { + select { + case <-byteCounterStopProgress: + return + case <-t.C: + newCount := s.byteCounter.Count() + if lastCount != newCount { + ka.MadeProgress() + } else { + lastCount = newCount + } + } + } + }() + defer func() { + s.parent.promBytesReplicated.Add(float64(s.byteCounter.Count())) }() - sstream = s.byteCounter rr := &pdu.ReceiveReq{ Filesystem: fs, ClearResumeToken: !sres.UsedResumeToken, } log.Debug("initiate receive request") - err = receiver.Receive(ctx, rr, sstream) + _, err = receiver.Receive(ctx, rr, s.byteCounter) if err != nil { log. WithError(err). WithField("errType", fmt.Sprintf("%T", err)). Error("receive request failed (might also be error on sender)") - sstream.Close() // This failure could be due to // - an unexpected exit of ZFS on the sending side // - an unexpected exit of ZFS on the receiving side @@ -524,7 +539,7 @@ func (s *ReplicationStep) Report() *StepReport { } bytes := int64(0) if s.byteCounter != nil { - bytes = s.byteCounter.Bytes() + bytes = s.byteCounter.Count() } problem := "" if s.err != nil { diff --git a/replication/mainfsm.go b/replication/mainfsm.go index 2c8de45..5cf1d7b 100644 --- a/replication/mainfsm.go +++ b/replication/mainfsm.go @@ -10,7 +10,6 @@ import ( "github.com/zrepl/zrepl/daemon/job/wakeup" "github.com/zrepl/zrepl/util/envconst" "github.com/zrepl/zrepl/util/watchdog" - "github.com/problame/go-streamrpc" "math/bits" "net" "sort" @@ -106,9 +105,8 @@ func NewReplication(secsPerState *prometheus.HistogramVec, bytesReplicated *prom // named interfaces defined in this package. type Endpoint interface { // Does not include placeholder filesystems - ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, error) - // FIXME document FilteredError handling - ListFilesystemVersions(ctx context.Context, fs string) ([]*pdu.FilesystemVersion, error) // fix depS + ListFilesystems(ctx context.Context, req *pdu.ListFilesystemReq) (*pdu.ListFilesystemRes, error) + ListFilesystemVersions(ctx context.Context, req *pdu.ListFilesystemVersionsReq) (*pdu.ListFilesystemVersionsRes, error) DestroySnapshots(ctx context.Context, req *pdu.DestroySnapshotsReq) (*pdu.DestroySnapshotsRes, error) } @@ -203,7 +201,6 @@ type Error interface { var _ Error = fsrep.Error(nil) var _ Error = net.Error(nil) -var _ Error = streamrpc.Error(nil) func isPermanent(err error) bool { if e, ok := err.(Error); ok { @@ -232,19 +229,20 @@ func statePlanning(ctx context.Context, ka *watchdog.KeepAlive, sender Sender, r }).rsf() } - sfss, err := sender.ListFilesystems(ctx) + slfssres, err := sender.ListFilesystems(ctx, &pdu.ListFilesystemReq{}) if err != nil { - log.WithError(err).Error("error listing sender filesystems") + log.WithError(err).WithField("errType", fmt.Sprintf("%T", err)).Error("error listing sender filesystems") return handlePlanningError(err) } + sfss := slfssres.GetFilesystems() // no progress here since we could run in a live-lock on connectivity issues - rfss, err := receiver.ListFilesystems(ctx) + rlfssres, err := receiver.ListFilesystems(ctx, &pdu.ListFilesystemReq{}) if err != nil { - log.WithError(err).Error("error listing receiver filesystems") + log.WithError(err).WithField("errType", fmt.Sprintf("%T", err)).Error("error listing receiver filesystems") return handlePlanningError(err) } - + rfss := rlfssres.GetFilesystems() ka.MadeProgress() // for both sender and receiver q := make([]*fsrep.Replication, 0, len(sfss)) @@ -255,11 +253,12 @@ func statePlanning(ctx context.Context, ka *watchdog.KeepAlive, sender Sender, r log.Debug("assessing filesystem") - sfsvs, err := sender.ListFilesystemVersions(ctx, fs.Path) + sfsvsres, err := sender.ListFilesystemVersions(ctx, &pdu.ListFilesystemVersionsReq{Filesystem: fs.Path}) if err != nil { log.WithError(err).Error("cannot get remote filesystem versions") return handlePlanningError(err) } + sfsvs := sfsvsres.GetVersions() ka.MadeProgress() if len(sfsvs) < 1 { @@ -278,7 +277,7 @@ func statePlanning(ctx context.Context, ka *watchdog.KeepAlive, sender Sender, r var rfsvs []*pdu.FilesystemVersion if receiverFSExists { - rfsvs, err = receiver.ListFilesystemVersions(ctx, fs.Path) + rfsvsres, err := receiver.ListFilesystemVersions(ctx, &pdu.ListFilesystemVersionsReq{Filesystem: fs.Path}) if err != nil { if _, ok := err.(*FilteredError); ok { log.Info("receiver ignores filesystem") @@ -287,6 +286,7 @@ func statePlanning(ctx context.Context, ka *watchdog.KeepAlive, sender Sender, r log.WithError(err).Error("receiver error") return handlePlanningError(err) } + rfsvs = rfsvsres.GetVersions() } else { rfsvs = []*pdu.FilesystemVersion{} } diff --git a/replication/pdu/pdu.pb.go b/replication/pdu/pdu.pb.go index f2b7d37..6b7fd86 100644 --- a/replication/pdu/pdu.pb.go +++ b/replication/pdu/pdu.pb.go @@ -7,6 +7,11 @@ import proto "github.com/golang/protobuf/proto" import fmt "fmt" import math "math" +import ( + context "golang.org/x/net/context" + grpc "google.golang.org/grpc" +) + // Reference imports to suppress errors if they are not otherwise used. var _ = proto.Marshal var _ = fmt.Errorf @@ -38,7 +43,7 @@ func (x FilesystemVersion_VersionType) String() string { return proto.EnumName(FilesystemVersion_VersionType_name, int32(x)) } func (FilesystemVersion_VersionType) EnumDescriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{5, 0} + return fileDescriptor_pdu_89315d819a6e0938, []int{5, 0} } type ListFilesystemReq struct { @@ -51,7 +56,7 @@ func (m *ListFilesystemReq) Reset() { *m = ListFilesystemReq{} } func (m *ListFilesystemReq) String() string { return proto.CompactTextString(m) } func (*ListFilesystemReq) ProtoMessage() {} func (*ListFilesystemReq) Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{0} + return fileDescriptor_pdu_89315d819a6e0938, []int{0} } func (m *ListFilesystemReq) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_ListFilesystemReq.Unmarshal(m, b) @@ -73,6 +78,7 @@ var xxx_messageInfo_ListFilesystemReq proto.InternalMessageInfo type ListFilesystemRes struct { Filesystems []*Filesystem `protobuf:"bytes,1,rep,name=Filesystems,proto3" json:"Filesystems,omitempty"` + Empty bool `protobuf:"varint,2,opt,name=Empty,proto3" json:"Empty,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` @@ -82,7 +88,7 @@ func (m *ListFilesystemRes) Reset() { *m = ListFilesystemRes{} } func (m *ListFilesystemRes) String() string { return proto.CompactTextString(m) } func (*ListFilesystemRes) ProtoMessage() {} func (*ListFilesystemRes) Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{1} + return fileDescriptor_pdu_89315d819a6e0938, []int{1} } func (m *ListFilesystemRes) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_ListFilesystemRes.Unmarshal(m, b) @@ -109,6 +115,13 @@ func (m *ListFilesystemRes) GetFilesystems() []*Filesystem { return nil } +func (m *ListFilesystemRes) GetEmpty() bool { + if m != nil { + return m.Empty + } + return false +} + type Filesystem struct { Path string `protobuf:"bytes,1,opt,name=Path,proto3" json:"Path,omitempty"` ResumeToken string `protobuf:"bytes,2,opt,name=ResumeToken,proto3" json:"ResumeToken,omitempty"` @@ -121,7 +134,7 @@ func (m *Filesystem) Reset() { *m = Filesystem{} } func (m *Filesystem) String() string { return proto.CompactTextString(m) } func (*Filesystem) ProtoMessage() {} func (*Filesystem) Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{2} + return fileDescriptor_pdu_89315d819a6e0938, []int{2} } func (m *Filesystem) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_Filesystem.Unmarshal(m, b) @@ -166,7 +179,7 @@ func (m *ListFilesystemVersionsReq) Reset() { *m = ListFilesystemVersion func (m *ListFilesystemVersionsReq) String() string { return proto.CompactTextString(m) } func (*ListFilesystemVersionsReq) ProtoMessage() {} func (*ListFilesystemVersionsReq) Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{3} + return fileDescriptor_pdu_89315d819a6e0938, []int{3} } func (m *ListFilesystemVersionsReq) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_ListFilesystemVersionsReq.Unmarshal(m, b) @@ -204,7 +217,7 @@ func (m *ListFilesystemVersionsRes) Reset() { *m = ListFilesystemVersion func (m *ListFilesystemVersionsRes) String() string { return proto.CompactTextString(m) } func (*ListFilesystemVersionsRes) ProtoMessage() {} func (*ListFilesystemVersionsRes) Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{4} + return fileDescriptor_pdu_89315d819a6e0938, []int{4} } func (m *ListFilesystemVersionsRes) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_ListFilesystemVersionsRes.Unmarshal(m, b) @@ -232,7 +245,7 @@ func (m *ListFilesystemVersionsRes) GetVersions() []*FilesystemVersion { } type FilesystemVersion struct { - Type FilesystemVersion_VersionType `protobuf:"varint,1,opt,name=Type,proto3,enum=pdu.FilesystemVersion_VersionType" json:"Type,omitempty"` + Type FilesystemVersion_VersionType `protobuf:"varint,1,opt,name=Type,proto3,enum=FilesystemVersion_VersionType" json:"Type,omitempty"` Name string `protobuf:"bytes,2,opt,name=Name,proto3" json:"Name,omitempty"` Guid uint64 `protobuf:"varint,3,opt,name=Guid,proto3" json:"Guid,omitempty"` CreateTXG uint64 `protobuf:"varint,4,opt,name=CreateTXG,proto3" json:"CreateTXG,omitempty"` @@ -246,7 +259,7 @@ func (m *FilesystemVersion) Reset() { *m = FilesystemVersion{} } func (m *FilesystemVersion) String() string { return proto.CompactTextString(m) } func (*FilesystemVersion) ProtoMessage() {} func (*FilesystemVersion) Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{5} + return fileDescriptor_pdu_89315d819a6e0938, []int{5} } func (m *FilesystemVersion) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_FilesystemVersion.Unmarshal(m, b) @@ -326,7 +339,7 @@ func (m *SendReq) Reset() { *m = SendReq{} } func (m *SendReq) String() string { return proto.CompactTextString(m) } func (*SendReq) ProtoMessage() {} func (*SendReq) Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{6} + return fileDescriptor_pdu_89315d819a6e0938, []int{6} } func (m *SendReq) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_SendReq.Unmarshal(m, b) @@ -407,7 +420,7 @@ func (m *Property) Reset() { *m = Property{} } func (m *Property) String() string { return proto.CompactTextString(m) } func (*Property) ProtoMessage() {} func (*Property) Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{7} + return fileDescriptor_pdu_89315d819a6e0938, []int{7} } func (m *Property) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_Property.Unmarshal(m, b) @@ -443,11 +456,11 @@ func (m *Property) GetValue() string { type SendRes struct { // Whether the resume token provided in the request has been used or not. - UsedResumeToken bool `protobuf:"varint,1,opt,name=UsedResumeToken,proto3" json:"UsedResumeToken,omitempty"` + UsedResumeToken bool `protobuf:"varint,2,opt,name=UsedResumeToken,proto3" json:"UsedResumeToken,omitempty"` // Expected stream size determined by dry run, not exact. // 0 indicates that for the given SendReq, no size estimate could be made. - ExpectedSize int64 `protobuf:"varint,2,opt,name=ExpectedSize,proto3" json:"ExpectedSize,omitempty"` - Properties []*Property `protobuf:"bytes,3,rep,name=Properties,proto3" json:"Properties,omitempty"` + ExpectedSize int64 `protobuf:"varint,3,opt,name=ExpectedSize,proto3" json:"ExpectedSize,omitempty"` + Properties []*Property `protobuf:"bytes,4,rep,name=Properties,proto3" json:"Properties,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` @@ -457,7 +470,7 @@ func (m *SendRes) Reset() { *m = SendRes{} } func (m *SendRes) String() string { return proto.CompactTextString(m) } func (*SendRes) ProtoMessage() {} func (*SendRes) Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{8} + return fileDescriptor_pdu_89315d819a6e0938, []int{8} } func (m *SendRes) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_SendRes.Unmarshal(m, b) @@ -511,7 +524,7 @@ func (m *ReceiveReq) Reset() { *m = ReceiveReq{} } func (m *ReceiveReq) String() string { return proto.CompactTextString(m) } func (*ReceiveReq) ProtoMessage() {} func (*ReceiveReq) Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{9} + return fileDescriptor_pdu_89315d819a6e0938, []int{9} } func (m *ReceiveReq) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_ReceiveReq.Unmarshal(m, b) @@ -555,7 +568,7 @@ func (m *ReceiveRes) Reset() { *m = ReceiveRes{} } func (m *ReceiveRes) String() string { return proto.CompactTextString(m) } func (*ReceiveRes) ProtoMessage() {} func (*ReceiveRes) Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{10} + return fileDescriptor_pdu_89315d819a6e0938, []int{10} } func (m *ReceiveRes) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_ReceiveRes.Unmarshal(m, b) @@ -588,7 +601,7 @@ func (m *DestroySnapshotsReq) Reset() { *m = DestroySnapshotsReq{} } func (m *DestroySnapshotsReq) String() string { return proto.CompactTextString(m) } func (*DestroySnapshotsReq) ProtoMessage() {} func (*DestroySnapshotsReq) Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{11} + return fileDescriptor_pdu_89315d819a6e0938, []int{11} } func (m *DestroySnapshotsReq) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_DestroySnapshotsReq.Unmarshal(m, b) @@ -634,7 +647,7 @@ func (m *DestroySnapshotRes) Reset() { *m = DestroySnapshotRes{} } func (m *DestroySnapshotRes) String() string { return proto.CompactTextString(m) } func (*DestroySnapshotRes) ProtoMessage() {} func (*DestroySnapshotRes) Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{12} + return fileDescriptor_pdu_89315d819a6e0938, []int{12} } func (m *DestroySnapshotRes) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_DestroySnapshotRes.Unmarshal(m, b) @@ -679,7 +692,7 @@ func (m *DestroySnapshotsRes) Reset() { *m = DestroySnapshotsRes{} } func (m *DestroySnapshotsRes) String() string { return proto.CompactTextString(m) } func (*DestroySnapshotsRes) ProtoMessage() {} func (*DestroySnapshotsRes) Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{13} + return fileDescriptor_pdu_89315d819a6e0938, []int{13} } func (m *DestroySnapshotsRes) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_DestroySnapshotsRes.Unmarshal(m, b) @@ -721,7 +734,7 @@ func (m *ReplicationCursorReq) Reset() { *m = ReplicationCursorReq{} } func (m *ReplicationCursorReq) String() string { return proto.CompactTextString(m) } func (*ReplicationCursorReq) ProtoMessage() {} func (*ReplicationCursorReq) Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{14} + return fileDescriptor_pdu_89315d819a6e0938, []int{14} } func (m *ReplicationCursorReq) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_ReplicationCursorReq.Unmarshal(m, b) @@ -869,7 +882,7 @@ func (m *ReplicationCursorReq_GetOp) Reset() { *m = ReplicationCursorReq func (m *ReplicationCursorReq_GetOp) String() string { return proto.CompactTextString(m) } func (*ReplicationCursorReq_GetOp) ProtoMessage() {} func (*ReplicationCursorReq_GetOp) Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{14, 0} + return fileDescriptor_pdu_89315d819a6e0938, []int{14, 0} } func (m *ReplicationCursorReq_GetOp) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_ReplicationCursorReq_GetOp.Unmarshal(m, b) @@ -900,7 +913,7 @@ func (m *ReplicationCursorReq_SetOp) Reset() { *m = ReplicationCursorReq func (m *ReplicationCursorReq_SetOp) String() string { return proto.CompactTextString(m) } func (*ReplicationCursorReq_SetOp) ProtoMessage() {} func (*ReplicationCursorReq_SetOp) Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{14, 1} + return fileDescriptor_pdu_89315d819a6e0938, []int{14, 1} } func (m *ReplicationCursorReq_SetOp) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_ReplicationCursorReq_SetOp.Unmarshal(m, b) @@ -941,7 +954,7 @@ func (m *ReplicationCursorRes) Reset() { *m = ReplicationCursorRes{} } func (m *ReplicationCursorRes) String() string { return proto.CompactTextString(m) } func (*ReplicationCursorRes) ProtoMessage() {} func (*ReplicationCursorRes) Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{15} + return fileDescriptor_pdu_89315d819a6e0938, []int{15} } func (m *ReplicationCursorRes) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_ReplicationCursorRes.Unmarshal(m, b) @@ -1067,71 +1080,246 @@ func _ReplicationCursorRes_OneofSizer(msg proto.Message) (n int) { } func init() { - proto.RegisterType((*ListFilesystemReq)(nil), "pdu.ListFilesystemReq") - proto.RegisterType((*ListFilesystemRes)(nil), "pdu.ListFilesystemRes") - proto.RegisterType((*Filesystem)(nil), "pdu.Filesystem") - proto.RegisterType((*ListFilesystemVersionsReq)(nil), "pdu.ListFilesystemVersionsReq") - proto.RegisterType((*ListFilesystemVersionsRes)(nil), "pdu.ListFilesystemVersionsRes") - proto.RegisterType((*FilesystemVersion)(nil), "pdu.FilesystemVersion") - proto.RegisterType((*SendReq)(nil), "pdu.SendReq") - proto.RegisterType((*Property)(nil), "pdu.Property") - proto.RegisterType((*SendRes)(nil), "pdu.SendRes") - proto.RegisterType((*ReceiveReq)(nil), "pdu.ReceiveReq") - proto.RegisterType((*ReceiveRes)(nil), "pdu.ReceiveRes") - proto.RegisterType((*DestroySnapshotsReq)(nil), "pdu.DestroySnapshotsReq") - proto.RegisterType((*DestroySnapshotRes)(nil), "pdu.DestroySnapshotRes") - proto.RegisterType((*DestroySnapshotsRes)(nil), "pdu.DestroySnapshotsRes") - proto.RegisterType((*ReplicationCursorReq)(nil), "pdu.ReplicationCursorReq") - proto.RegisterType((*ReplicationCursorReq_GetOp)(nil), "pdu.ReplicationCursorReq.GetOp") - proto.RegisterType((*ReplicationCursorReq_SetOp)(nil), "pdu.ReplicationCursorReq.SetOp") - proto.RegisterType((*ReplicationCursorRes)(nil), "pdu.ReplicationCursorRes") - proto.RegisterEnum("pdu.FilesystemVersion_VersionType", FilesystemVersion_VersionType_name, FilesystemVersion_VersionType_value) + proto.RegisterType((*ListFilesystemReq)(nil), "ListFilesystemReq") + proto.RegisterType((*ListFilesystemRes)(nil), "ListFilesystemRes") + proto.RegisterType((*Filesystem)(nil), "Filesystem") + proto.RegisterType((*ListFilesystemVersionsReq)(nil), "ListFilesystemVersionsReq") + proto.RegisterType((*ListFilesystemVersionsRes)(nil), "ListFilesystemVersionsRes") + proto.RegisterType((*FilesystemVersion)(nil), "FilesystemVersion") + proto.RegisterType((*SendReq)(nil), "SendReq") + proto.RegisterType((*Property)(nil), "Property") + proto.RegisterType((*SendRes)(nil), "SendRes") + proto.RegisterType((*ReceiveReq)(nil), "ReceiveReq") + proto.RegisterType((*ReceiveRes)(nil), "ReceiveRes") + proto.RegisterType((*DestroySnapshotsReq)(nil), "DestroySnapshotsReq") + proto.RegisterType((*DestroySnapshotRes)(nil), "DestroySnapshotRes") + proto.RegisterType((*DestroySnapshotsRes)(nil), "DestroySnapshotsRes") + proto.RegisterType((*ReplicationCursorReq)(nil), "ReplicationCursorReq") + proto.RegisterType((*ReplicationCursorReq_GetOp)(nil), "ReplicationCursorReq.GetOp") + proto.RegisterType((*ReplicationCursorReq_SetOp)(nil), "ReplicationCursorReq.SetOp") + proto.RegisterType((*ReplicationCursorRes)(nil), "ReplicationCursorRes") + proto.RegisterEnum("FilesystemVersion_VersionType", FilesystemVersion_VersionType_name, FilesystemVersion_VersionType_value) } -func init() { proto.RegisterFile("pdu.proto", fileDescriptor_pdu_fe566e6b212fcf8d) } +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn -var fileDescriptor_pdu_fe566e6b212fcf8d = []byte{ - // 659 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x55, 0xdb, 0x6e, 0x13, 0x31, - 0x10, 0xcd, 0xe6, 0xba, 0x99, 0x94, 0x5e, 0xdc, 0xaa, 0x2c, 0x15, 0x82, 0xc8, 0xbc, 0x04, 0x24, - 0x22, 0x91, 0x56, 0xbc, 0xf0, 0x96, 0xde, 0xf2, 0x80, 0xda, 0xca, 0x09, 0x55, 0x9f, 0x90, 0x42, - 0x77, 0x44, 0x57, 0xb9, 0x78, 0x6b, 0x7b, 0x51, 0xc3, 0x07, 0xf0, 0x4f, 0xfc, 0x07, 0x0f, 0x7c, - 0x0e, 0xf2, 0xec, 0x25, 0xdb, 0x24, 0x54, 0x79, 0x8a, 0xcf, 0xf8, 0x78, 0xe6, 0xcc, 0xf1, 0x8e, - 0x03, 0xf5, 0xd0, 0x8f, 0xda, 0xa1, 0x92, 0x46, 0xb2, 0x52, 0xe8, 0x47, 0x7c, 0x17, 0x76, 0x3e, - 0x07, 0xda, 0x9c, 0x05, 0x63, 0xd4, 0x33, 0x6d, 0x70, 0x22, 0xf0, 0x9e, 0x9f, 0x2d, 0x07, 0x35, - 0xfb, 0x00, 0x8d, 0x79, 0x40, 0x7b, 0x4e, 0xb3, 0xd4, 0x6a, 0x74, 0xb6, 0xda, 0x36, 0x5f, 0x8e, - 0x98, 0xe7, 0xf0, 0x2e, 0xc0, 0x1c, 0x32, 0x06, 0xe5, 0xab, 0xa1, 0xb9, 0xf3, 0x9c, 0xa6, 0xd3, - 0xaa, 0x0b, 0x5a, 0xb3, 0x26, 0x34, 0x04, 0xea, 0x68, 0x82, 0x03, 0x39, 0xc2, 0xa9, 0x57, 0xa4, - 0xad, 0x7c, 0x88, 0x7f, 0x82, 0x17, 0x8f, 0xb5, 0x5c, 0xa3, 0xd2, 0x81, 0x9c, 0x6a, 0x81, 0xf7, - 0xec, 0x55, 0xbe, 0x40, 0x92, 0x38, 0x17, 0xe1, 0x97, 0xff, 0x3f, 0xac, 0x59, 0x07, 0xdc, 0x14, - 0x26, 0xdd, 0xec, 0x2f, 0x74, 0x93, 0x6c, 0x8b, 0x8c, 0xc7, 0xff, 0x3a, 0xb0, 0xb3, 0xb4, 0xcf, - 0x3e, 0x42, 0x79, 0x30, 0x0b, 0x91, 0x04, 0x6c, 0x76, 0xf8, 0xea, 0x2c, 0xed, 0xe4, 0xd7, 0x32, - 0x05, 0xf1, 0xad, 0x23, 0x17, 0xc3, 0x09, 0x26, 0x6d, 0xd3, 0xda, 0xc6, 0xce, 0xa3, 0xc0, 0xf7, - 0x4a, 0x4d, 0xa7, 0x55, 0x16, 0xb4, 0x66, 0x2f, 0xa1, 0x7e, 0xac, 0x70, 0x68, 0x70, 0x70, 0x73, - 0xee, 0x95, 0x69, 0x63, 0x1e, 0x60, 0x07, 0xe0, 0x12, 0x08, 0xe4, 0xd4, 0xab, 0x50, 0xa6, 0x0c, - 0xf3, 0xb7, 0xd0, 0xc8, 0x95, 0x65, 0x1b, 0xe0, 0xf6, 0xa7, 0xc3, 0x50, 0xdf, 0x49, 0xb3, 0x5d, - 0xb0, 0xa8, 0x2b, 0xe5, 0x68, 0x32, 0x54, 0xa3, 0x6d, 0x87, 0xff, 0x76, 0xa0, 0xd6, 0xc7, 0xa9, - 0xbf, 0x86, 0xaf, 0x56, 0xe4, 0x99, 0x92, 0x93, 0x54, 0xb8, 0x5d, 0xb3, 0x4d, 0x28, 0x0e, 0x24, - 0xc9, 0xae, 0x8b, 0xe2, 0x40, 0x2e, 0x5e, 0x6d, 0x79, 0xe9, 0x6a, 0x49, 0xb8, 0x9c, 0x84, 0x0a, - 0xb5, 0x26, 0xe1, 0xae, 0xc8, 0x30, 0xdb, 0x83, 0xca, 0x09, 0xfa, 0x51, 0xe8, 0x55, 0x69, 0x23, - 0x06, 0x6c, 0x1f, 0xaa, 0x27, 0x6a, 0x26, 0xa2, 0xa9, 0x57, 0xa3, 0x70, 0x82, 0xf8, 0x11, 0xb8, - 0x57, 0x4a, 0x86, 0xa8, 0xcc, 0x2c, 0x33, 0xd5, 0xc9, 0x99, 0xba, 0x07, 0x95, 0xeb, 0xe1, 0x38, - 0x4a, 0x9d, 0x8e, 0x01, 0xff, 0x95, 0x75, 0xac, 0x59, 0x0b, 0xb6, 0xbe, 0x68, 0xf4, 0xf3, 0x8a, - 0x1d, 0x2a, 0xb1, 0x18, 0x66, 0x1c, 0x36, 0x4e, 0x1f, 0x42, 0xbc, 0x35, 0xe8, 0xf7, 0x83, 0x9f, - 0x71, 0xca, 0x92, 0x78, 0x14, 0x63, 0xef, 0x01, 0x12, 0x3d, 0x01, 0x6a, 0xaf, 0x44, 0x1f, 0xd7, - 0x33, 0xfa, 0x2c, 0x52, 0x99, 0x22, 0x47, 0xe0, 0x37, 0x00, 0x02, 0x6f, 0x31, 0xf8, 0x81, 0xeb, - 0x98, 0xff, 0x0e, 0xb6, 0x8f, 0xc7, 0x38, 0x54, 0x8b, 0x83, 0xe3, 0x8a, 0xa5, 0x38, 0xdf, 0xc8, - 0x65, 0xd6, 0x7c, 0x04, 0xbb, 0x27, 0xa8, 0x8d, 0x92, 0xb3, 0xf4, 0x2b, 0x58, 0x67, 0x8a, 0xd8, - 0x11, 0xd4, 0x33, 0xbe, 0x57, 0x7c, 0x72, 0x52, 0xe6, 0x44, 0xfe, 0x15, 0xd8, 0x42, 0xb1, 0x64, - 0xe8, 0x52, 0x48, 0x95, 0x9e, 0x18, 0xba, 0x94, 0x67, 0x6f, 0xef, 0x54, 0x29, 0xa9, 0xd2, 0xdb, - 0x23, 0xc0, 0x7b, 0xab, 0x9a, 0xb1, 0xcf, 0x54, 0xcd, 0x1a, 0x30, 0x36, 0xe9, 0x50, 0x3f, 0xa7, - 0xfc, 0xcb, 0x52, 0x44, 0xca, 0xe3, 0x7f, 0x1c, 0xd8, 0x13, 0x18, 0x8e, 0x83, 0x5b, 0x1a, 0x9a, - 0xe3, 0x48, 0x69, 0xa9, 0xd6, 0x31, 0xe6, 0x10, 0x4a, 0xdf, 0xd1, 0x90, 0xac, 0x46, 0xe7, 0x35, - 0xd5, 0x59, 0x95, 0xa7, 0x7d, 0x8e, 0xe6, 0x32, 0xec, 0x15, 0x84, 0x65, 0xdb, 0x43, 0x1a, 0x0d, - 0x0d, 0xca, 0x93, 0x87, 0xfa, 0xe9, 0x21, 0x8d, 0xe6, 0xa0, 0x06, 0x15, 0x4a, 0x72, 0xf0, 0x06, - 0x2a, 0xb4, 0x61, 0x87, 0x27, 0x33, 0x32, 0xf6, 0x25, 0xc3, 0xdd, 0x32, 0x14, 0x65, 0xc8, 0x07, - 0x2b, 0xbb, 0xb2, 0xa3, 0x15, 0xbf, 0x30, 0xb6, 0x9f, 0x72, 0xaf, 0x90, 0xbd, 0x31, 0xee, 0x85, - 0x34, 0xf8, 0x10, 0xe8, 0x38, 0x9f, 0xdb, 0x2b, 0x88, 0x2c, 0xd2, 0x75, 0xa1, 0x1a, 0xbb, 0xf5, - 0xad, 0x4a, 0x7f, 0x1e, 0x87, 0xff, 0x02, 0x00, 0x00, 0xff, 0xff, 0x66, 0x74, 0x36, 0x3a, 0x49, - 0x06, 0x00, 0x00, +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion4 + +// ReplicationClient is the client API for Replication service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream. +type ReplicationClient interface { + ListFilesystems(ctx context.Context, in *ListFilesystemReq, opts ...grpc.CallOption) (*ListFilesystemRes, error) + ListFilesystemVersions(ctx context.Context, in *ListFilesystemVersionsReq, opts ...grpc.CallOption) (*ListFilesystemVersionsRes, error) + DestroySnapshots(ctx context.Context, in *DestroySnapshotsReq, opts ...grpc.CallOption) (*DestroySnapshotsRes, error) + ReplicationCursor(ctx context.Context, in *ReplicationCursorReq, opts ...grpc.CallOption) (*ReplicationCursorRes, error) +} + +type replicationClient struct { + cc *grpc.ClientConn +} + +func NewReplicationClient(cc *grpc.ClientConn) ReplicationClient { + return &replicationClient{cc} +} + +func (c *replicationClient) ListFilesystems(ctx context.Context, in *ListFilesystemReq, opts ...grpc.CallOption) (*ListFilesystemRes, error) { + out := new(ListFilesystemRes) + err := c.cc.Invoke(ctx, "/Replication/ListFilesystems", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *replicationClient) ListFilesystemVersions(ctx context.Context, in *ListFilesystemVersionsReq, opts ...grpc.CallOption) (*ListFilesystemVersionsRes, error) { + out := new(ListFilesystemVersionsRes) + err := c.cc.Invoke(ctx, "/Replication/ListFilesystemVersions", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *replicationClient) DestroySnapshots(ctx context.Context, in *DestroySnapshotsReq, opts ...grpc.CallOption) (*DestroySnapshotsRes, error) { + out := new(DestroySnapshotsRes) + err := c.cc.Invoke(ctx, "/Replication/DestroySnapshots", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *replicationClient) ReplicationCursor(ctx context.Context, in *ReplicationCursorReq, opts ...grpc.CallOption) (*ReplicationCursorRes, error) { + out := new(ReplicationCursorRes) + err := c.cc.Invoke(ctx, "/Replication/ReplicationCursor", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// ReplicationServer is the server API for Replication service. +type ReplicationServer interface { + ListFilesystems(context.Context, *ListFilesystemReq) (*ListFilesystemRes, error) + ListFilesystemVersions(context.Context, *ListFilesystemVersionsReq) (*ListFilesystemVersionsRes, error) + DestroySnapshots(context.Context, *DestroySnapshotsReq) (*DestroySnapshotsRes, error) + ReplicationCursor(context.Context, *ReplicationCursorReq) (*ReplicationCursorRes, error) +} + +func RegisterReplicationServer(s *grpc.Server, srv ReplicationServer) { + s.RegisterService(&_Replication_serviceDesc, srv) +} + +func _Replication_ListFilesystems_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ListFilesystemReq) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ReplicationServer).ListFilesystems(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/Replication/ListFilesystems", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ReplicationServer).ListFilesystems(ctx, req.(*ListFilesystemReq)) + } + return interceptor(ctx, in, info, handler) +} + +func _Replication_ListFilesystemVersions_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ListFilesystemVersionsReq) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ReplicationServer).ListFilesystemVersions(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/Replication/ListFilesystemVersions", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ReplicationServer).ListFilesystemVersions(ctx, req.(*ListFilesystemVersionsReq)) + } + return interceptor(ctx, in, info, handler) +} + +func _Replication_DestroySnapshots_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(DestroySnapshotsReq) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ReplicationServer).DestroySnapshots(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/Replication/DestroySnapshots", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ReplicationServer).DestroySnapshots(ctx, req.(*DestroySnapshotsReq)) + } + return interceptor(ctx, in, info, handler) +} + +func _Replication_ReplicationCursor_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ReplicationCursorReq) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ReplicationServer).ReplicationCursor(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/Replication/ReplicationCursor", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ReplicationServer).ReplicationCursor(ctx, req.(*ReplicationCursorReq)) + } + return interceptor(ctx, in, info, handler) +} + +var _Replication_serviceDesc = grpc.ServiceDesc{ + ServiceName: "Replication", + HandlerType: (*ReplicationServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "ListFilesystems", + Handler: _Replication_ListFilesystems_Handler, + }, + { + MethodName: "ListFilesystemVersions", + Handler: _Replication_ListFilesystemVersions_Handler, + }, + { + MethodName: "DestroySnapshots", + Handler: _Replication_DestroySnapshots_Handler, + }, + { + MethodName: "ReplicationCursor", + Handler: _Replication_ReplicationCursor_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "pdu.proto", +} + +func init() { proto.RegisterFile("pdu.proto", fileDescriptor_pdu_89315d819a6e0938) } + +var fileDescriptor_pdu_89315d819a6e0938 = []byte{ + // 735 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x55, 0xdd, 0x6e, 0xda, 0x4a, + 0x10, 0xc6, 0x60, 0xc0, 0x0c, 0x51, 0x42, 0x36, 0x9c, 0xc8, 0xc7, 0xe7, 0x28, 0x42, 0xdb, 0x1b, + 0x52, 0xa9, 0x6e, 0x45, 0x7b, 0x53, 0x55, 0xaa, 0x54, 0x42, 0x7e, 0xa4, 0x56, 0x69, 0xb4, 0xd0, + 0x28, 0xca, 0x1d, 0x0d, 0xa3, 0xc4, 0x0a, 0xb0, 0xce, 0xee, 0xba, 0x0a, 0xbd, 0xec, 0x7b, 0xf4, + 0x41, 0xfa, 0x0e, 0xbd, 0xec, 0x03, 0x55, 0xbb, 0x60, 0xe3, 0x60, 0x23, 0x71, 0xe5, 0xfd, 0xbe, + 0x9d, 0x9d, 0x9d, 0xf9, 0x76, 0x66, 0x0c, 0xb5, 0x70, 0x14, 0xf9, 0xa1, 0xe0, 0x8a, 0xd3, 0x3d, + 0xd8, 0xfd, 0x14, 0x48, 0x75, 0x12, 0x8c, 0x51, 0xce, 0xa4, 0xc2, 0x09, 0xc3, 0x07, 0x7a, 0x95, + 0x25, 0x25, 0x79, 0x01, 0xf5, 0x25, 0x21, 0x5d, 0xab, 0x55, 0x6a, 0xd7, 0x3b, 0x75, 0x3f, 0x65, + 0x94, 0xde, 0x27, 0x4d, 0x28, 0x1f, 0x4f, 0x42, 0x35, 0x73, 0x8b, 0x2d, 0xab, 0xed, 0xb0, 0x39, + 0xa0, 0x5d, 0x80, 0xa5, 0x11, 0x21, 0x60, 0x5f, 0x0c, 0xd5, 0x9d, 0x6b, 0xb5, 0xac, 0x76, 0x8d, + 0x99, 0x35, 0x69, 0x41, 0x9d, 0xa1, 0x8c, 0x26, 0x38, 0xe0, 0xf7, 0x38, 0x35, 0xa7, 0x6b, 0x2c, + 0x4d, 0xd1, 0x77, 0xf0, 0xef, 0xd3, 0xe8, 0x2e, 0x51, 0xc8, 0x80, 0x4f, 0x25, 0xc3, 0x07, 0x72, + 0x90, 0xbe, 0x60, 0xe1, 0x38, 0xc5, 0xd0, 0x8f, 0xeb, 0x0f, 0x4b, 0xe2, 0x83, 0x13, 0xc3, 0x45, + 0x7e, 0xc4, 0xcf, 0x58, 0xb2, 0xc4, 0x86, 0xfe, 0xb1, 0x60, 0x37, 0xb3, 0x4f, 0x3a, 0x60, 0x0f, + 0x66, 0x21, 0x9a, 0xcb, 0xb7, 0x3b, 0x07, 0x59, 0x0f, 0xfe, 0xe2, 0xab, 0xad, 0x98, 0xb1, 0xd5, + 0x4a, 0x9c, 0x0f, 0x27, 0xb8, 0x48, 0xd7, 0xac, 0x35, 0x77, 0x1a, 0x05, 0x23, 0xb7, 0xd4, 0xb2, + 0xda, 0x36, 0x33, 0x6b, 0xf2, 0x3f, 0xd4, 0x8e, 0x04, 0x0e, 0x15, 0x0e, 0xae, 0x4e, 0x5d, 0xdb, + 0x6c, 0x2c, 0x09, 0xe2, 0x81, 0x63, 0x40, 0xc0, 0xa7, 0x6e, 0xd9, 0x78, 0x4a, 0x30, 0x3d, 0x84, + 0x7a, 0xea, 0x5a, 0xb2, 0x05, 0x4e, 0x7f, 0x3a, 0x0c, 0xe5, 0x1d, 0x57, 0x8d, 0x82, 0x46, 0x5d, + 0xce, 0xef, 0x27, 0x43, 0x71, 0xdf, 0xb0, 0xe8, 0x2f, 0x0b, 0xaa, 0x7d, 0x9c, 0x8e, 0x36, 0xd0, + 0x53, 0x07, 0x79, 0x22, 0xf8, 0x24, 0x0e, 0x5c, 0xaf, 0xc9, 0x36, 0x14, 0x07, 0xdc, 0x84, 0x5d, + 0x63, 0xc5, 0x01, 0x5f, 0x7d, 0x52, 0x3b, 0xf3, 0xa4, 0x26, 0x70, 0x3e, 0x09, 0x05, 0x4a, 0x69, + 0x02, 0x77, 0x58, 0x82, 0x75, 0x21, 0xf5, 0x70, 0x14, 0x85, 0x6e, 0x65, 0x5e, 0x48, 0x06, 0x90, + 0x7d, 0xa8, 0xf4, 0xc4, 0x8c, 0x45, 0x53, 0xb7, 0x6a, 0xe8, 0x05, 0xa2, 0x6f, 0xc0, 0xb9, 0x10, + 0x3c, 0x44, 0xa1, 0x66, 0x89, 0xa8, 0x56, 0x4a, 0xd4, 0x26, 0x94, 0x2f, 0x87, 0xe3, 0x28, 0x56, + 0x7a, 0x0e, 0xe8, 0x8f, 0x24, 0x63, 0x49, 0xda, 0xb0, 0xf3, 0x45, 0xe2, 0x68, 0xb5, 0x08, 0x1d, + 0xb6, 0x4a, 0x13, 0x0a, 0x5b, 0xc7, 0x8f, 0x21, 0xde, 0x28, 0x1c, 0xf5, 0x83, 0xef, 0x68, 0x32, + 0x2e, 0xb1, 0x27, 0x1c, 0x39, 0x04, 0x58, 0xc4, 0x13, 0xa0, 0x74, 0x6d, 0x53, 0x54, 0x35, 0x3f, + 0x0e, 0x91, 0xa5, 0x36, 0xe9, 0x15, 0x00, 0xc3, 0x1b, 0x0c, 0xbe, 0xe1, 0x26, 0xc2, 0x3f, 0x87, + 0xc6, 0xd1, 0x18, 0x87, 0x22, 0x1b, 0x67, 0x86, 0xa7, 0x5b, 0x29, 0xcf, 0x92, 0xde, 0xc2, 0x5e, + 0x0f, 0xa5, 0x12, 0x7c, 0x16, 0x57, 0xc0, 0x26, 0x9d, 0x43, 0x5e, 0x41, 0x2d, 0xb1, 0x77, 0x8b, + 0x6b, 0xbb, 0x63, 0x69, 0x44, 0xaf, 0x81, 0xac, 0x5c, 0xb4, 0x68, 0xb2, 0x18, 0x9a, 0x5b, 0xd6, + 0x34, 0x59, 0x6c, 0x63, 0x06, 0x89, 0x10, 0x5c, 0xc4, 0x2f, 0x66, 0x00, 0xed, 0xe5, 0x25, 0xa1, + 0x87, 0x54, 0x55, 0x27, 0x3e, 0x56, 0x71, 0x03, 0xef, 0xf9, 0xd9, 0x10, 0x58, 0x6c, 0x43, 0x7f, + 0x5b, 0xd0, 0x64, 0x18, 0x8e, 0x83, 0x1b, 0xd3, 0x24, 0x47, 0x91, 0x90, 0x5c, 0x6c, 0x22, 0xc6, + 0x4b, 0x28, 0xdd, 0xa2, 0x32, 0x21, 0xd5, 0x3b, 0xff, 0xf9, 0x79, 0x3e, 0xfc, 0x53, 0x54, 0x9f, + 0xc3, 0xb3, 0x02, 0xd3, 0x96, 0xfa, 0x80, 0x44, 0x65, 0x4a, 0x64, 0xed, 0x81, 0x7e, 0x7c, 0x40, + 0xa2, 0xf2, 0xaa, 0x50, 0x36, 0x0e, 0xbc, 0x67, 0x50, 0x36, 0x1b, 0xba, 0x49, 0x12, 0xe1, 0xe6, + 0x5a, 0x24, 0xb8, 0x6b, 0x43, 0x91, 0x87, 0x74, 0x90, 0x9b, 0x8d, 0x6e, 0xa1, 0xf9, 0x24, 0xd1, + 0x79, 0xd8, 0x67, 0x85, 0x64, 0x96, 0x38, 0xe7, 0x5c, 0xe1, 0x63, 0x20, 0xe7, 0xfe, 0x9c, 0xb3, + 0x02, 0x4b, 0x98, 0xae, 0x03, 0x95, 0xb9, 0x4a, 0x9d, 0x9f, 0x45, 0xdd, 0xbf, 0x89, 0x5b, 0xf2, + 0x16, 0x76, 0x9e, 0x8e, 0x50, 0x49, 0x88, 0x9f, 0xf9, 0x89, 0x78, 0x59, 0x4e, 0x92, 0x0b, 0xd8, + 0xcf, 0x9f, 0xbe, 0xc4, 0xf3, 0xd7, 0xce, 0x74, 0x6f, 0xfd, 0x9e, 0x24, 0xef, 0xa1, 0xb1, 0x5a, + 0x07, 0xa4, 0xe9, 0xe7, 0xd4, 0xb7, 0x97, 0xc7, 0x4a, 0xf2, 0x01, 0x76, 0x33, 0x92, 0x91, 0x7f, + 0x72, 0xdf, 0xc7, 0xcb, 0xa5, 0x65, 0xb7, 0x7c, 0x5d, 0x0a, 0x47, 0xd1, 0xd7, 0x8a, 0xf9, 0xa1, + 0xbe, 0xfe, 0x1b, 0x00, 0x00, 0xff, 0xff, 0xa3, 0xba, 0x8e, 0x63, 0x5d, 0x07, 0x00, 0x00, } diff --git a/replication/pdu/pdu.proto b/replication/pdu/pdu.proto index 6d9430a..1b66916 100644 --- a/replication/pdu/pdu.proto +++ b/replication/pdu/pdu.proto @@ -1,11 +1,19 @@ syntax = "proto3"; +option go_package = "pdu"; -package pdu; +service Replication { + rpc ListFilesystems (ListFilesystemReq) returns (ListFilesystemRes); + rpc ListFilesystemVersions (ListFilesystemVersionsReq) returns (ListFilesystemVersionsRes); + rpc DestroySnapshots (DestroySnapshotsReq) returns (DestroySnapshotsRes); + rpc ReplicationCursor (ReplicationCursorReq) returns (ReplicationCursorRes); + // for Send and Recv, see package rpc +} message ListFilesystemReq {} message ListFilesystemRes { repeated Filesystem Filesystems = 1; + bool Empty = 2; } message Filesystem { @@ -60,22 +68,18 @@ message Property { } message SendRes { - // The actual stream is in the stream part of the streamrpc response - // Whether the resume token provided in the request has been used or not. - bool UsedResumeToken = 1; + bool UsedResumeToken = 2; // Expected stream size determined by dry run, not exact. // 0 indicates that for the given SendReq, no size estimate could be made. - int64 ExpectedSize = 2; + int64 ExpectedSize = 3; - repeated Property Properties = 3; + repeated Property Properties = 4; } message ReceiveReq { - // The stream part of the streamrpc request contains the zfs send stream - - string Filesystem = 1; + string Filesystem = 1; // FIXME should be snapshot name, we can enforce that on recv // If true, the receiver should clear the resume token before perfoming the zfs recv of the stream in the request bool ClearResumeToken = 2; diff --git a/rpc/dataconn/base2bufpool/base2bufpool.go b/rpc/dataconn/base2bufpool/base2bufpool.go new file mode 100644 index 0000000..73efecc --- /dev/null +++ b/rpc/dataconn/base2bufpool/base2bufpool.go @@ -0,0 +1,169 @@ +package base2bufpool + +import ( + "fmt" + "math/bits" + "sync" +) + +type pool struct { + mtx sync.Mutex + bufs [][]byte + shift uint +} + +func (p *pool) Put(buf []byte) { + p.mtx.Lock() + defer p.mtx.Unlock() + if len(buf) != 1< 10 { // FIXME constant + return + } + p.bufs = append(p.bufs, buf) +} + +func (p *pool) Get() []byte { + p.mtx.Lock() + defer p.mtx.Unlock() + if len(p.bufs) > 0 { + ret := p.bufs[len(p.bufs)-1] + p.bufs = p.bufs[0 : len(p.bufs)-1] + return ret + } + return make([]byte, 1< b.payloadLen { + panic(fmt.Sprintf("shrink is actually an expand, invalid: %v %v", newPayloadLen, b.payloadLen)) + } + b.payloadLen = newPayloadLen +} + +func (b Buffer) Free() { + if b.pool != nil { + b.pool.put(b) + } +} + +//go:generate enumer -type NoFitBehavior +type NoFitBehavior uint + +const ( + AllocateSmaller NoFitBehavior = 1 << iota + AllocateLarger + + Allocate NoFitBehavior = AllocateSmaller | AllocateLarger + Panic NoFitBehavior = 0 +) + +func New(minShift, maxShift uint, noFitBehavior NoFitBehavior) *Pool { + if minShift > 63 || maxShift > 63 { + panic(fmt.Sprintf("{min|max}Shift are the _exponent_, got minShift=%v maxShift=%v and limit of 63, which amounts to %v bits", minShift, maxShift, uint64(1)<<63)) + } + pools := make([]pool, maxShift-minShift+1) + for i := uint(0); i < uint(len(pools)); i++ { + i := i // the closure below must copy i + pools[i] = pool{ + shift: minShift + i, + bufs: make([][]byte, 0, 10), + } + } + return &Pool{ + minShift: minShift, + maxShift: maxShift, + pools: pools, + onNoFit: noFitBehavior, + } +} + +func fittingShift(x uint) uint { + if x == 0 { + return 0 + } + blen := uint(bits.Len(x)) + if 1<<(blen-1) == x { + return blen - 1 + } + return blen +} + +func (p *Pool) handlePotentialNoFit(reqShift uint) (buf Buffer, didHandle bool) { + if reqShift == 0 { + if p.onNoFit&AllocateSmaller != 0 { + return Buffer{[]byte{}, 0, nil}, true + } else { + goto doPanic + } + } + if reqShift < p.minShift { + if p.onNoFit&AllocateSmaller != 0 { + goto alloc + } else { + goto doPanic + } + } + if reqShift > p.maxShift { + if p.onNoFit&AllocateLarger != 0 { + goto alloc + } else { + goto doPanic + } + } + return Buffer{}, false +alloc: + return Buffer{make([]byte, 1< 1 { + panic(fmt.Sprintf("putting buffer that is not power of two len: %v", len(buf))) + } + if len(buf) == 0 { + return + } + shift := fittingShift(uint(len(buf))) + if shift < p.minShift || shift > p.maxShift { + return // drop it + } + p.pools[shift-p.minShift].Put(buf) +} diff --git a/rpc/dataconn/base2bufpool/base2bufpool_test.go b/rpc/dataconn/base2bufpool/base2bufpool_test.go new file mode 100644 index 0000000..d6ce361 --- /dev/null +++ b/rpc/dataconn/base2bufpool/base2bufpool_test.go @@ -0,0 +1,98 @@ +package base2bufpool + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPoolAllocBehavior(t *testing.T) { + + type testcase struct { + poolMinShift, poolMaxShift uint + behavior NoFitBehavior + get uint + expShiftBufLen int64 // -1 if panic expected + } + + tcs := []testcase{ + { + 15, 20, Allocate, + 1 << 14, 1 << 14, + }, + { + 15, 20, Allocate, + 1 << 22, 1 << 22, + }, + { + 15, 20, Panic, + 1 << 16, 1 << 16, + }, + { + 15, 20, Panic, + 1 << 14, -1, + }, + { + 15, 20, Panic, + 1 << 22, -1, + }, + { + 15, 20, Panic, + (1 << 15) + 23, 1 << 16, + }, + { + 15, 20, Panic, + 0, -1, // yep, 0 always works, even + }, + { + 15, 20, Allocate, + 0, 0, + }, + { + 15, 20, AllocateSmaller, + 1 << 14, 1 << 14, + }, + { + 15, 20, AllocateSmaller, + 1 << 22, -1, + }, + } + + for i := range tcs { + tc := tcs[i] + t.Run(fmt.Sprintf("[%d,%d] behav=%s Get(%d) exp=%d", tc.poolMinShift, tc.poolMaxShift, tc.behavior, tc.get, tc.expShiftBufLen), func(t *testing.T) { + pool := New(tc.poolMinShift, tc.poolMaxShift, tc.behavior) + if tc.expShiftBufLen == -1 { + assert.Panics(t, func() { + pool.Get(tc.get) + }) + return + } + buf := pool.Get(tc.get) + assert.True(t, uint(len(buf.Bytes())) == tc.get) + assert.True(t, int64(len(buf.shiftBuf)) == tc.expShiftBufLen) + }) + } +} + +func TestFittingShift(t *testing.T) { + assert.Equal(t, uint(16), fittingShift(1+1<<15)) + assert.Equal(t, uint(15), fittingShift(1<<15)) +} + +func TestFreeFromPoolRangeDoesNotPanic(t *testing.T) { + pool := New(15, 20, Allocate) + buf := pool.Get(1 << 16) + assert.NotPanics(t, func() { + buf.Free() + }) +} + +func TestFreeFromOutOfPoolRangeDoesNotPanic(t *testing.T) { + pool := New(15, 20, Allocate) + buf := pool.Get(1 << 23) + assert.NotPanics(t, func() { + buf.Free() + }) +} diff --git a/rpc/dataconn/base2bufpool/nofitbehavior_enumer.go b/rpc/dataconn/base2bufpool/nofitbehavior_enumer.go new file mode 100644 index 0000000..18f70b8 --- /dev/null +++ b/rpc/dataconn/base2bufpool/nofitbehavior_enumer.go @@ -0,0 +1,51 @@ +// Code generated by "enumer -type NoFitBehavior"; DO NOT EDIT. + +package base2bufpool + +import ( + "fmt" +) + +const _NoFitBehaviorName = "PanicAllocateSmallerAllocateLargerAllocate" + +var _NoFitBehaviorIndex = [...]uint8{0, 5, 20, 34, 42} + +func (i NoFitBehavior) String() string { + if i >= NoFitBehavior(len(_NoFitBehaviorIndex)-1) { + return fmt.Sprintf("NoFitBehavior(%d)", i) + } + return _NoFitBehaviorName[_NoFitBehaviorIndex[i]:_NoFitBehaviorIndex[i+1]] +} + +var _NoFitBehaviorValues = []NoFitBehavior{0, 1, 2, 3} + +var _NoFitBehaviorNameToValueMap = map[string]NoFitBehavior{ + _NoFitBehaviorName[0:5]: 0, + _NoFitBehaviorName[5:20]: 1, + _NoFitBehaviorName[20:34]: 2, + _NoFitBehaviorName[34:42]: 3, +} + +// NoFitBehaviorString retrieves an enum value from the enum constants string name. +// Throws an error if the param is not part of the enum. +func NoFitBehaviorString(s string) (NoFitBehavior, error) { + if val, ok := _NoFitBehaviorNameToValueMap[s]; ok { + return val, nil + } + return 0, fmt.Errorf("%s does not belong to NoFitBehavior values", s) +} + +// NoFitBehaviorValues returns all values of the enum +func NoFitBehaviorValues() []NoFitBehavior { + return _NoFitBehaviorValues +} + +// IsANoFitBehavior returns "true" if the value is listed in the enum definition. "false" otherwise +func (i NoFitBehavior) IsANoFitBehavior() bool { + for _, v := range _NoFitBehaviorValues { + if i == v { + return true + } + } + return false +} diff --git a/rpc/dataconn/dataconn_client.go b/rpc/dataconn/dataconn_client.go new file mode 100644 index 0000000..a12292b --- /dev/null +++ b/rpc/dataconn/dataconn_client.go @@ -0,0 +1,215 @@ +package dataconn + +import ( + "bytes" + "context" + "fmt" + "strings" + + "github.com/golang/protobuf/proto" + "github.com/zrepl/zrepl/replication/pdu" + "github.com/zrepl/zrepl/rpc/dataconn/stream" + "github.com/zrepl/zrepl/transport" + "github.com/zrepl/zrepl/zfs" +) + +type Client struct { + log Logger + cn transport.Connecter +} + +func NewClient(connecter transport.Connecter, log Logger) *Client { + return &Client{ + log: log, + cn: connecter, + } +} + +func (c *Client) send(ctx context.Context, conn *stream.Conn, endpoint string, req proto.Message, streamCopier zfs.StreamCopier) error { + + var buf bytes.Buffer + _, memErr := buf.WriteString(endpoint) + if memErr != nil { + panic(memErr) + } + if err := conn.WriteStreamedMessage(ctx, &buf, ReqHeader); err != nil { + return err + } + + protobufBytes, err := proto.Marshal(req) + if err != nil { + return err + } + protobuf := bytes.NewBuffer(protobufBytes) + if err := conn.WriteStreamedMessage(ctx, protobuf, ReqStructured); err != nil { + return err + } + + if streamCopier != nil { + return conn.SendStream(ctx, streamCopier, ZFSStream) + } else { + return nil + } +} + +type RemoteHandlerError struct { + msg string +} + +func (e *RemoteHandlerError) Error() string { + return fmt.Sprintf("server error: %s", e.msg) +} + +type ProtocolError struct { + cause error +} + +func (e *ProtocolError) Error() string { + return fmt.Sprintf("protocol error: %s", e) +} + +func (c *Client) recv(ctx context.Context, conn *stream.Conn, res proto.Message) error { + + headerBuf, err := conn.ReadStreamedMessage(ctx, ResponseHeaderMaxSize, ResHeader) + if err != nil { + return err + } + header := string(headerBuf) + if strings.HasPrefix(header, responseHeaderHandlerErrorPrefix) { + // FIXME distinguishable error type + return &RemoteHandlerError{strings.TrimPrefix(header, responseHeaderHandlerErrorPrefix)} + } + if !strings.HasPrefix(header, responseHeaderHandlerOk) { + return &ProtocolError{fmt.Errorf("invalid header: %q", header)} + } + + protobuf, err := conn.ReadStreamedMessage(ctx, ResponseStructuredMaxSize, ResStructured) + if err != nil { + return err + } + if err := proto.Unmarshal(protobuf, res); err != nil { + return &ProtocolError{fmt.Errorf("cannot unmarshal structured part of response: %s", err)} + } + return nil +} + +func (c *Client) getWire(ctx context.Context) (*stream.Conn, error) { + nc, err := c.cn.Connect(ctx) + if err != nil { + return nil, err + } + conn := stream.Wrap(nc, HeartbeatInterval, HeartbeatPeerTimeout) + return conn, nil +} + +func (c *Client) putWire(conn *stream.Conn) { + if err := conn.Close(); err != nil { + c.log.WithError(err).Error("error closing connection") + } +} + +func (c *Client) ReqSend(ctx context.Context, req *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) { + conn, err := c.getWire(ctx) + if err != nil { + return nil, nil, err + } + putWireOnReturn := true + defer func() { + if putWireOnReturn { + c.putWire(conn) + } + }() + + if err := c.send(ctx, conn, EndpointSend, req, nil); err != nil { + return nil, nil, err + } + + var res pdu.SendRes + if err := c.recv(ctx, conn, &res); err != nil { + return nil, nil, err + } + + var copier zfs.StreamCopier = nil + if !req.DryRun { + putWireOnReturn = false + copier = &streamCopier{streamConn: conn, closeStreamOnClose: true} + } + + return &res, copier, nil +} + +func (c *Client) ReqRecv(ctx context.Context, req *pdu.ReceiveReq, streamCopier zfs.StreamCopier) (*pdu.ReceiveRes, error) { + + defer c.log.Info("ReqRecv returns") + conn, err := c.getWire(ctx) + if err != nil { + return nil, err + } + + // send and recv response concurrently to catch early exists of remote handler + // (e.g. disk full, permission error, etc) + + type recvRes struct { + res *pdu.ReceiveRes + err error + } + recvErrChan := make(chan recvRes) + go func() { + res := &pdu.ReceiveRes{} + if err := c.recv(ctx, conn, res); err != nil { + recvErrChan <- recvRes{res, err} + } else { + recvErrChan <- recvRes{res, nil} + } + }() + + sendErrChan := make(chan error) + go func() { + if err := c.send(ctx, conn, EndpointRecv, req, streamCopier); err != nil { + sendErrChan <- err + } else { + sendErrChan <- nil + } + }() + + var res recvRes + var sendErr error + var cause error // one of the above + didTryClose := false + for i := 0; i < 2; i++ { + select { + case res = <-recvErrChan: + c.log.WithField("errType", fmt.Sprintf("%T", res.err)).WithError(res.err).Debug("recv goroutine returned") + if res.err != nil && cause == nil { + cause = res.err + } + case sendErr = <-sendErrChan: + c.log.WithField("errType", fmt.Sprintf("%T", sendErr)).WithError(sendErr).Debug("send goroutine returned") + if sendErr != nil && cause == nil { + cause = sendErr + } + } + if !didTryClose && (res.err != nil || sendErr != nil) { + didTryClose = true + if err := conn.Close(); err != nil { + c.log.WithError(err).Error("ReqRecv: cannot close connection, will likely block indefinitely") + } + c.log.WithError(err).Debug("ReqRecv: closed connection, should trigger other goroutine error") + } + } + + if !didTryClose { + // didn't close it in above loop, so we can give it back + c.putWire(conn) + } + + // if receive failed with a RemoteHandlerError, we know the transport was not broken + // => take the remote error as cause for the operation to fail + // TODO combine errors if send also failed + // (after all, send could have crashed on our side, rendering res.err a mere symptom of the cause) + if _, ok := res.err.(*RemoteHandlerError); ok { + cause = res.err + } + + return res.res, cause +} diff --git a/rpc/dataconn/dataconn_debug.go b/rpc/dataconn/dataconn_debug.go new file mode 100644 index 0000000..3c20701 --- /dev/null +++ b/rpc/dataconn/dataconn_debug.go @@ -0,0 +1,20 @@ +package dataconn + +import ( + "fmt" + "os" +) + +var debugEnabled bool = false + +func init() { + if os.Getenv("ZREPL_RPC_DATACONN_DEBUG") != "" { + debugEnabled = true + } +} + +func debug(format string, args ...interface{}) { + if debugEnabled { + fmt.Fprintf(os.Stderr, "rpc/dataconn: %s\n", fmt.Sprintf(format, args...)) + } +} diff --git a/rpc/dataconn/dataconn_server.go b/rpc/dataconn/dataconn_server.go new file mode 100644 index 0000000..41f5781 --- /dev/null +++ b/rpc/dataconn/dataconn_server.go @@ -0,0 +1,178 @@ +package dataconn + +import ( + "bytes" + "context" + "fmt" + + "github.com/golang/protobuf/proto" + "github.com/zrepl/zrepl/logger" + "github.com/zrepl/zrepl/replication/pdu" + "github.com/zrepl/zrepl/rpc/dataconn/stream" + "github.com/zrepl/zrepl/transport" + "github.com/zrepl/zrepl/zfs" +) + +// WireInterceptor has a chance to exchange the context and connection on each client connection. +type WireInterceptor func(ctx context.Context, rawConn *transport.AuthConn) (context.Context, *transport.AuthConn) + +// Handler implements the functionality that is exposed by Server to the Client. +type Handler interface { + // Send handles a SendRequest. + // The returned io.ReadCloser is allowed to be nil, for example if the requested Send is a dry-run. + Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) + // Receive handles a ReceiveRequest. + // It is guaranteed that Server calls Receive with a stream that holds the IdleConnTimeout + // configured in ServerConfig.Shared.IdleConnTimeout. + Receive(ctx context.Context, r *pdu.ReceiveReq, receive zfs.StreamCopier) (*pdu.ReceiveRes, error) +} + +type Logger = logger.Logger + +type Server struct { + h Handler + wi WireInterceptor + log Logger +} + +func NewServer(wi WireInterceptor, logger Logger, handler Handler) *Server { + return &Server{ + h: handler, + wi: wi, + log: logger, + } +} + +// Serve consumes the listener, closes it as soon as ctx is closed. +// No accept errors are returned: they are logged to the Logger passed +// to the constructor. +func (s *Server) Serve(ctx context.Context, l transport.AuthenticatedListener) { + + go func() { + <-ctx.Done() + s.log.Debug("context done") + if err := l.Close(); err != nil { + s.log.WithError(err).Error("cannot close listener") + } + }() + conns := make(chan *transport.AuthConn) + go func() { + for { + conn, err := l.Accept(ctx) + if err != nil { + if ctx.Done() != nil { + s.log.Debug("stop accepting after context is done") + return + } + s.log.WithError(err).Error("accept error") + continue + } + conns <- conn + } + }() + for conn := range conns { + go s.serveConn(conn) + } +} + +func (s *Server) serveConn(nc *transport.AuthConn) { + s.log.Debug("serveConn begin") + defer s.log.Debug("serveConn done") + + ctx := context.Background() + if s.wi != nil { + ctx, nc = s.wi(ctx, nc) + } + + c := stream.Wrap(nc, HeartbeatInterval, HeartbeatPeerTimeout) + defer func() { + s.log.Debug("close client connection") + if err := c.Close(); err != nil { + s.log.WithError(err).Error("cannot close client connection") + } + }() + + header, err := c.ReadStreamedMessage(ctx, RequestHeaderMaxSize, ReqHeader) + if err != nil { + s.log.WithError(err).Error("error reading structured part") + return + } + endpoint := string(header) + + reqStructured, err := c.ReadStreamedMessage(ctx, RequestStructuredMaxSize, ReqStructured) + if err != nil { + s.log.WithError(err).Error("error reading structured part") + return + } + + s.log.WithField("endpoint", endpoint).Debug("calling handler") + + var res proto.Message + var sendStream zfs.StreamCopier + var handlerErr error + switch endpoint { + case EndpointSend: + var req pdu.SendReq + if err := proto.Unmarshal(reqStructured, &req); err != nil { + s.log.WithError(err).Error("cannot unmarshal send request") + return + } + res, sendStream, handlerErr = s.h.Send(ctx, &req) // SHADOWING + case EndpointRecv: + var req pdu.ReceiveReq + if err := proto.Unmarshal(reqStructured, &req); err != nil { + s.log.WithError(err).Error("cannot unmarshal receive request") + return + } + res, handlerErr = s.h.Receive(ctx, &req, &streamCopier{streamConn: c, closeStreamOnClose: false}) // SHADOWING + default: + s.log.WithField("endpoint", endpoint).Error("unknown endpoint") + handlerErr = fmt.Errorf("requested endpoint does not exist") + return + } + + s.log.WithField("endpoint", endpoint).WithField("errType", fmt.Sprintf("%T", handlerErr)).Debug("handler returned") + + // prepare protobuf now to return the protobuf error in the header + // if marshaling fails. We consider failed marshaling a handler error + var protobuf *bytes.Buffer + if handlerErr == nil { + protobufBytes, err := proto.Marshal(res) + if err != nil { + s.log.WithError(err).Error("cannot marshal handler protobuf") + handlerErr = err + } + protobuf = bytes.NewBuffer(protobufBytes) // SHADOWING + } + + var resHeaderBuf bytes.Buffer + if handlerErr == nil { + resHeaderBuf.WriteString(responseHeaderHandlerOk) + } else { + resHeaderBuf.WriteString(responseHeaderHandlerErrorPrefix) + resHeaderBuf.WriteString(handlerErr.Error()) + } + if err := c.WriteStreamedMessage(ctx, &resHeaderBuf, ResHeader); err != nil { + s.log.WithError(err).Error("cannot write response header") + return + } + + if handlerErr != nil { + s.log.Debug("early exit after handler error") + return + } + + if err := c.WriteStreamedMessage(ctx, protobuf, ResStructured); err != nil { + s.log.WithError(err).Error("cannot write structured part of response") + return + } + + if sendStream != nil { + err := c.SendStream(ctx, sendStream, ZFSStream) + if err != nil { + s.log.WithError(err).Error("cannot write send stream") + } + } + + return +} diff --git a/rpc/dataconn/dataconn_shared.go b/rpc/dataconn/dataconn_shared.go new file mode 100644 index 0000000..0ea5a34 --- /dev/null +++ b/rpc/dataconn/dataconn_shared.go @@ -0,0 +1,70 @@ +package dataconn + +import ( + "io" + "sync" + "time" + + "github.com/zrepl/zrepl/rpc/dataconn/stream" + "github.com/zrepl/zrepl/zfs" +) + +const ( + EndpointSend string = "/v1/send" + EndpointRecv string = "/v1/recv" +) + +const ( + ReqHeader uint32 = 1 + iota + ReqStructured + ResHeader + ResStructured + ZFSStream +) + +// Note that changing theses constants may break interop with other clients +// Aggressive with timing, conservative (future compatible) with buffer sizes +const ( + HeartbeatInterval = 5 * time.Second + HeartbeatPeerTimeout = 10 * time.Second + RequestHeaderMaxSize = 1 << 15 + RequestStructuredMaxSize = 1 << 22 + ResponseHeaderMaxSize = 1 << 15 + ResponseStructuredMaxSize = 1 << 23 +) + +// the following are protocol constants +const ( + responseHeaderHandlerOk = "HANDLER OK\n" + responseHeaderHandlerErrorPrefix = "HANDLER ERROR:\n" +) + +type streamCopier struct { + mtx sync.Mutex + used bool + streamConn *stream.Conn + closeStreamOnClose bool +} + +// WriteStreamTo implements zfs.StreamCopier +func (s *streamCopier) WriteStreamTo(w io.Writer) zfs.StreamCopierError { + s.mtx.Lock() + defer s.mtx.Unlock() + if s.used { + panic("streamCopier used mulitple times") + } + s.used = true + return s.streamConn.ReadStreamInto(w, ZFSStream) +} + +// Close implements zfs.StreamCopier +func (s *streamCopier) Close() error { + // only record the close here, what we do actually depends on whether + // the streamCopier is instantiated server-side or client-side + s.mtx.Lock() + defer s.mtx.Unlock() + if s.closeStreamOnClose { + return s.streamConn.Close() + } + return nil +} diff --git a/rpc/dataconn/dataconn_test.go b/rpc/dataconn/dataconn_test.go new file mode 100644 index 0000000..4d8820e --- /dev/null +++ b/rpc/dataconn/dataconn_test.go @@ -0,0 +1 @@ +package dataconn diff --git a/rpc/dataconn/frameconn/frameconn.go b/rpc/dataconn/frameconn/frameconn.go new file mode 100644 index 0000000..2736b98 --- /dev/null +++ b/rpc/dataconn/frameconn/frameconn.go @@ -0,0 +1,346 @@ +package frameconn + +import ( + "bufio" + "encoding/binary" + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "sync" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/zrepl/zrepl/rpc/dataconn/base2bufpool" + "github.com/zrepl/zrepl/rpc/dataconn/timeoutconn" +) + +type FrameHeader struct { + Type uint32 + PayloadLen uint32 +} + +// The 4 MSBs of ft are reserved for frameconn. +func IsPublicFrameType(ft uint32) bool { + return (0xf<<28)&ft == 0 +} + +const ( + rstFrameType uint32 = 1<<28 + iota +) + +func assertPublicFrameType(frameType uint32) { + if !IsPublicFrameType(frameType) { + panic(fmt.Sprintf("frameconn: frame type %v cannot be used by consumers of this package", frameType)) + } +} + +func (f *FrameHeader) Unmarshal(buf []byte) { + if len(buf) != 8 { + panic(fmt.Sprintf("frame header is 8 bytes long")) + } + f.Type = binary.BigEndian.Uint32(buf[0:4]) + f.PayloadLen = binary.BigEndian.Uint32(buf[4:8]) +} + +type Conn struct { + readMtx, writeMtx sync.Mutex + nc timeoutconn.Conn + ncBuf *bufio.ReadWriter + readNextValid bool + readNext FrameHeader + nextReadErr error + bufPool *base2bufpool.Pool // no need for sync around it + shutdown shutdownFSM +} + +func Wrap(nc timeoutconn.Conn) *Conn { + return &Conn{ + nc: nc, + // ncBuf: bufio.NewReadWriter(bufio.NewReaderSize(nc, 1<<23), bufio.NewWriterSize(nc, 1<<23)), + bufPool: base2bufpool.New(15, 22, base2bufpool.Allocate), // FIXME switch to Panic, but need to enforce the limits in recv for that. => need frameconn config + readNext: FrameHeader{}, + readNextValid: false, + } +} + +var ErrReadFrameLengthShort = errors.New("read frame length too short") +var ErrFixedFrameLengthMismatch = errors.New("read frame length mismatch") + +type Buffer struct { + bufpoolBuffer base2bufpool.Buffer + payloadLen uint32 +} + +func (b *Buffer) Free() { + b.bufpoolBuffer.Free() +} + +func (b *Buffer) Bytes() []byte { + return b.bufpoolBuffer.Bytes()[0:b.payloadLen] +} + +type Frame struct { + Header FrameHeader + Buffer Buffer +} + +var ErrShutdown = fmt.Errorf("frameconn: shutting down") + +// ReadFrame reads a frame from the connection. +// +// Due to an internal optimization (Readv, specifically), it is not guaranteed that a single call to +// WriteFrame unblocks a pending ReadFrame on an otherwise idle (empty) connection. +// The only way to guarantee that all previously written frames can reach the peer's layers on top +// of frameconn is to send an empty frame (no payload) and to ignore empty frames on the receiving side. +func (c *Conn) ReadFrame() (Frame, error) { + + if c.shutdown.IsShuttingDown() { + return Frame{}, ErrShutdown + } + + // only aquire readMtx now to prioritize the draining in Shutdown() + // over external callers (= drain public callers) + + c.readMtx.Lock() + defer c.readMtx.Unlock() + f, err := c.readFrame() + if f.Header.Type == rstFrameType { + c.shutdown.Begin() + return Frame{}, ErrShutdown + } + return f, err +} + +// callers must have readMtx locked +func (c *Conn) readFrame() (Frame, error) { + + if c.nextReadErr != nil { + ret := c.nextReadErr + c.nextReadErr = nil + return Frame{}, ret + } + + if !c.readNextValid { + var buf [8]byte + if _, err := io.ReadFull(c.nc, buf[:]); err != nil { + return Frame{}, err + } + c.readNext.Unmarshal(buf[:]) + c.readNextValid = true + } + + // read payload + next header + var nextHdrBuf [8]byte + buffer := c.bufPool.Get(uint(c.readNext.PayloadLen)) + bufferBytes := buffer.Bytes() + + if c.readNext.PayloadLen == 0 { + // This if statement implements the unlock-by-sending-empty-frame behavior + // documented in ReadFrame's public docs. + // + // It is crucial that we return this empty frame now: + // Consider the following plot with x-axis being time, + // P being a frame with payload, E one without, X either of P or E + // + // P P P P P P P E.....................X + // | | | | + // | | | F3 + // | | | + // | F2 |signficant time between frames because + // F1 the peer has nothing to say to us + // + // Assume we're at the point were F2's header is in c.readNext. + // That means F2 has not yet been returned. + // But because it is empty (no payload), we're already done reading it. + // If we omitted this if statement, the following would happen: + // Readv below would read [][]byte{[len(0)], [len(8)]). + + c.readNextValid = false + frame := Frame{ + Header: c.readNext, + Buffer: Buffer{ + bufpoolBuffer: buffer, + payloadLen: c.readNext.PayloadLen, // 0 + }, + } + return frame, nil + } + + noNextHeader := false + if n, err := c.nc.ReadvFull([][]byte{bufferBytes, nextHdrBuf[:]}); err != nil { + noNextHeader = true + zeroPayloadAndPeerClosed := n == 0 && c.readNext.PayloadLen == 0 && err == io.EOF + zeroPayloadAndNextFrameHeaderThenPeerClosed := err == io.EOF && c.readNext.PayloadLen == 0 && n == int64(len(nextHdrBuf)) + nonzeroPayloadRecvdButNextHeaderMissing := n > 0 && uint32(n) == c.readNext.PayloadLen + if zeroPayloadAndPeerClosed || zeroPayloadAndNextFrameHeaderThenPeerClosed || nonzeroPayloadRecvdButNextHeaderMissing { + // This is the last frame on the conn. + // Store the error to be returned on the next invocation of ReadFrame. + c.nextReadErr = err + // NORETURN, this frame is still valid + } else { + return Frame{}, err + } + } + + frame := Frame{ + Header: c.readNext, + Buffer: Buffer{ + bufpoolBuffer: buffer, + payloadLen: c.readNext.PayloadLen, + }, + } + + if !noNextHeader { + c.readNext.Unmarshal(nextHdrBuf[:]) + c.readNextValid = true + } else { + c.readNextValid = false + } + + return frame, nil +} + +func (c *Conn) WriteFrame(payload []byte, frameType uint32) error { + assertPublicFrameType(frameType) + if c.shutdown.IsShuttingDown() { + return ErrShutdown + } + c.writeMtx.Lock() + defer c.writeMtx.Unlock() + return c.writeFrame(payload, frameType) +} + +func (c *Conn) writeFrame(payload []byte, frameType uint32) error { + var hdrBuf [8]byte + binary.BigEndian.PutUint32(hdrBuf[0:4], frameType) + binary.BigEndian.PutUint32(hdrBuf[4:8], uint32(len(payload))) + bufs := net.Buffers([][]byte{hdrBuf[:], payload}) + if _, err := c.nc.WritevFull(bufs); err != nil { + return err + } + return nil +} + +func (c *Conn) Shutdown(deadline time.Time) error { + // TCP connection teardown is a bit wonky if we are in a situation + // where there is still data in flight (DIF) to our side: + // If we just close the connection, our kernel will send RSTs + // in response to the DIF, and those RSTs may reach the client's + // kernel faster than the client app is able to pull the + // last bytes from its kernel TCP receive buffer. + // + // Therefore, we send a frame with type rstFrameType to indicate + // that the connection is to be closed immediately, and further + // use CloseWrite instead of Close. + // As per definition of the wire interface, CloseWrite guarantees + // delivery of the data in our kernel TCP send buffer. + // Therefore, the client always receives the RST frame. + // + // Now what are we going to do after that? + // + // 1. Naive Option: We just call Close() right after CloseWrite. + // This yields the same race condition as explained above (DIF, first + // paragraph): The situation just becomae a little more unlikely because + // our rstFrameType + CloseWrite dance gave the client a full RTT worth of + // time to read the data from its TCP recv buffer. + // + // 2. Correct Option: Drain the read side until io.EOF + // We can read from the unclosed read-side of the connection until we get + // the io.EOF caused by the (well behaved) client closing the connection + // in response to it reading the rstFrameType frame we sent. + // However, this wastes resources on our side (we don't care about the + // pending data anymore), and has potential for (D)DoS through CPU-time + // exhaustion if the client just keeps sending data. + // Then again, this option has the advantage with well-behaved clients + // that we do not waste precious kernel-memory on the stale receive buffer + // on our side (which is still full of data that we do not intend to read). + // + // 2.1 DoS Mitigation: Bound the number of bytes to drain, then close + // At the time of writing, this technique is practiced by the Go http server + // implementation, and actually SHOULDed in the HTTP 1.1 RFC. It is + // important to disable the idle timeout of the underlying timeoutconn in + // that case and set an absolute deadline by which the socket must have + // been fully drained. Not too hard, though ;) + // + // 2.2: Client sends RST, not FIN when it receives an rstFrameTyp frame. + // We can use wire.(*net.TCPConn).SetLinger(0) to force an RST to be sent + // on a subsequent close (instead of a FIN + wait for FIN+ACK). + // TODO put this into Wire interface as an abstract method. + // + // 2.3 Only start draining after N*RTT + // We have an RTT approximation from Wire.CloseWrite, which by definition + // must not return before all to-be-sent-data has been acknowledged by the + // client. Give the client a fair chance to react, and only start draining + // after a multiple of the RTT has elapsed. + // We waste the recv buffer memory a little longer than necessary, iff the + // client reacts faster than expected. But we don't wast CPU time. + // If we apply 2.2, we'll also have the benefit that our kernel will have + // dropped the recv buffer memory as soon as it receives the client's RST. + // + // 3. TCP-only: OOB-messaging + // We can use TCP's 'urgent' flag in the client to acknowledge the receipt + // of the rstFrameType to us. + // We can thus wait for that signal while leaving the kernel buffer as is. + + // TODO: For now, we just drain the connection (Option 2), + // but we enforce deadlines so the _time_ we drain the connection + // is bounded, although we do _that_ at full speed + + defer prometheus.NewTimer(prom.ShutdownSeconds).ObserveDuration() + + closeWire := func(step string) error { + // TODO SetLinger(0) or similiar (we want RST frames here, not FINS) + if closeErr := c.nc.Close(); closeErr != nil { + prom.ShutdownCloseErrors.WithLabelValues("close").Inc() + return closeErr + } + return nil + } + + hardclose := func(err error, step string) error { + prom.ShutdownHardCloses.WithLabelValues(step).Inc() + return closeWire(step) + } + + c.shutdown.Begin() + // new calls to c.ReadFrame and c.WriteFrame will now return ErrShutdown + // Aquiring writeMtx and readMtx ensures that the last calls exit successfully + + // disable renewing timeouts now, enforce the requested deadline instead + // we need to do this before aquiring locks to enforce the timeout on slow + // clients / if something hangs (DoS mitigation) + if err := c.nc.DisableTimeouts(); err != nil { + return hardclose(err, "disable_timeouts") + } + if err := c.nc.SetDeadline(deadline); err != nil { + return hardclose(err, "set_deadline") + } + + c.writeMtx.Lock() + defer c.writeMtx.Unlock() + + if err := c.writeFrame([]byte{}, rstFrameType); err != nil { + return hardclose(err, "write_frame") + } + + if err := c.nc.CloseWrite(); err != nil { + return hardclose(err, "close_write") + } + + c.readMtx.Lock() + defer c.readMtx.Unlock() + + // TODO DoS mitigation: wait for client acknowledgement that they initiated Shutdown, + // then perform abortive close on our side. As explained above, probably requires + // OOB signaling such as TCP's urgent flag => transport-specific? + + // TODO DoS mitigation by reading limited number of bytes + // see discussion above why this is non-trivial + defer prometheus.NewTimer(prom.ShutdownDrainSeconds).ObserveDuration() + n, _ := io.Copy(ioutil.Discard, c.nc) + prom.ShutdownDrainBytesRead.Observe(float64(n)) + + return closeWire("close") +} diff --git a/rpc/dataconn/frameconn/frameconn_prometheus.go b/rpc/dataconn/frameconn/frameconn_prometheus.go new file mode 100644 index 0000000..d5c2c50 --- /dev/null +++ b/rpc/dataconn/frameconn/frameconn_prometheus.go @@ -0,0 +1,63 @@ +package frameconn + +import "github.com/prometheus/client_golang/prometheus" + +var prom struct { + ShutdownDrainBytesRead prometheus.Summary + ShutdownSeconds prometheus.Summary + ShutdownDrainSeconds prometheus.Summary + ShutdownHardCloses *prometheus.CounterVec + ShutdownCloseErrors *prometheus.CounterVec +} + +func init() { + prom.ShutdownDrainBytesRead = prometheus.NewSummary(prometheus.SummaryOpts{ + Namespace: "zrepl", + Subsystem: "frameconn", + Name: "shutdown_drain_bytes_read", + Help: "Number of bytes read during the drain phase of connection shutdown", + }) + prom.ShutdownSeconds = prometheus.NewSummary(prometheus.SummaryOpts{ + Namespace: "zrepl", + Subsystem: "frameconn", + Name: "shutdown_seconds", + Help: "Seconds it took for connection shutdown to complete", + }) + prom.ShutdownDrainSeconds = prometheus.NewSummary(prometheus.SummaryOpts{ + Namespace: "zrepl", + Subsystem: "frameconn", + Name: "shutdown_drain_seconds", + Help: "Seconds it took from read-side-drain until shutdown completion", + }) + prom.ShutdownHardCloses = prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: "zrepl", + Subsystem: "frameconn", + Name: "shutdown_hard_closes", + Help: "Number of hard connection closes during shutdown (abortive close)", + }, []string{"step"}) + prom.ShutdownCloseErrors = prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: "zrepl", + Subsystem: "frameconn", + Name: "shutdown_close_errors", + Help: "Number of errors closing the underlying network connection. Should alert on this", + }, []string{"step"}) +} + +func PrometheusRegister(registry prometheus.Registerer) error { + if err := registry.Register(prom.ShutdownDrainBytesRead); err != nil { + return err + } + if err := registry.Register(prom.ShutdownSeconds); err != nil { + return err + } + if err := registry.Register(prom.ShutdownDrainSeconds); err != nil { + return err + } + if err := registry.Register(prom.ShutdownHardCloses); err != nil { + return err + } + if err := registry.Register(prom.ShutdownCloseErrors); err != nil { + return err + } + return nil +} diff --git a/rpc/dataconn/frameconn/frameconn_shutdown_fsm.go b/rpc/dataconn/frameconn/frameconn_shutdown_fsm.go new file mode 100644 index 0000000..980d980 --- /dev/null +++ b/rpc/dataconn/frameconn/frameconn_shutdown_fsm.go @@ -0,0 +1,37 @@ +package frameconn + +import "sync" + +type shutdownFSM struct { + mtx sync.Mutex + state shutdownFSMState +} + +type shutdownFSMState uint32 + +const ( + shutdownStateOpen shutdownFSMState = iota + shutdownStateBegin +) + +func newShutdownFSM() *shutdownFSM { + fsm := &shutdownFSM{ + state: shutdownStateOpen, + } + return fsm +} + +func (f *shutdownFSM) Begin() (thisCallStartedShutdown bool) { + f.mtx.Lock() + defer f.mtx.Unlock() + thisCallStartedShutdown = f.state != shutdownStateOpen + f.state = shutdownStateBegin + return thisCallStartedShutdown +} + +func (f *shutdownFSM) IsShuttingDown() bool { + f.mtx.Lock() + defer f.mtx.Unlock() + return f.state != shutdownStateOpen +} + diff --git a/rpc/dataconn/frameconn/frameconn_test.go b/rpc/dataconn/frameconn/frameconn_test.go new file mode 100644 index 0000000..b070c79 --- /dev/null +++ b/rpc/dataconn/frameconn/frameconn_test.go @@ -0,0 +1,22 @@ +package frameconn + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsPublicFrameType(t *testing.T) { + for i := uint32(0); i < 256; i++ { + i := i + t.Run(fmt.Sprintf("^%d", i), func(t *testing.T) { + assert.False(t, IsPublicFrameType(^i)) + }) + } + assert.True(t, IsPublicFrameType(0)) + assert.True(t, IsPublicFrameType(1)) + assert.True(t, IsPublicFrameType(255)) + assert.False(t, IsPublicFrameType(rstFrameType)) +} + diff --git a/rpc/dataconn/heartbeatconn/heartbeatconn.go b/rpc/dataconn/heartbeatconn/heartbeatconn.go new file mode 100644 index 0000000..2924fdc --- /dev/null +++ b/rpc/dataconn/heartbeatconn/heartbeatconn.go @@ -0,0 +1,137 @@ +package heartbeatconn + +import ( + "fmt" + "net" + "sync/atomic" + "time" + + "github.com/zrepl/zrepl/rpc/dataconn/frameconn" + "github.com/zrepl/zrepl/rpc/dataconn/timeoutconn" +) + +type Conn struct { + state state + // if not nil, opErr is returned for ReadFrame and WriteFrame (not for Close, though) + opErr atomic.Value // error + fc *frameconn.Conn + sendInterval, timeout time.Duration + stopSend chan struct{} + lastFrameSent atomic.Value // time.Time +} + +type HeartbeatTimeout struct{} + +func (e HeartbeatTimeout) Error() string { + return "heartbeat timeout" +} + +func (e HeartbeatTimeout) Temporary() bool { return true } + +func (e HeartbeatTimeout) Timeout() bool { return true } + +var _ net.Error = HeartbeatTimeout{} + +type state = int32 + +const ( + stateInitial state = 0 + stateClosed state = 2 +) + +const ( + heartbeat uint32 = 1 << 24 +) + +// The 4 MSBs of ft are reserved for frameconn, we reserve the next 4 MSB for us. +func IsPublicFrameType(ft uint32) bool { + return frameconn.IsPublicFrameType(ft) && (0xf<<24)&ft == 0 +} + +func assertPublicFrameType(frameType uint32) { + if !IsPublicFrameType(frameType) { + panic(fmt.Sprintf("heartbeatconn: frame type %v cannot be used by consumers of this package", frameType)) + } +} + +func Wrap(nc timeoutconn.Wire, sendInterval, timeout time.Duration) *Conn { + c := &Conn{ + fc: frameconn.Wrap(timeoutconn.Wrap(nc, timeout)), + stopSend: make(chan struct{}), + sendInterval: sendInterval, + timeout: timeout, + } + c.lastFrameSent.Store(time.Now()) + go c.sendHeartbeats() + return c +} + +func (c *Conn) Shutdown() error { + normalClose := atomic.CompareAndSwapInt32(&c.state, stateInitial, stateClosed) + if normalClose { + close(c.stopSend) + } + return c.fc.Shutdown(time.Now().Add(c.timeout)) +} + +// started as a goroutine in constructor +func (c *Conn) sendHeartbeats() { + sleepTime := func(now time.Time) time.Duration { + lastSend := c.lastFrameSent.Load().(time.Time) + return lastSend.Add(c.sendInterval).Sub(now) + } + timer := time.NewTimer(sleepTime(time.Now())) + defer timer.Stop() + for { + select { + case <-c.stopSend: + return + case now := <-timer.C: + func() { + defer func() { + timer.Reset(sleepTime(time.Now())) + }() + if sleepTime(now) > 0 { + return + } + debug("send heartbeat") + // if the connection is in zombie mode (aka iptables DROP inbetween peers) + // this call or one of its successors will block after filling up the kernel tx buffer + c.fc.WriteFrame([]byte{}, heartbeat) + // ignore errors from WriteFrame to rate-limit SendHeartbeat retries + c.lastFrameSent.Store(time.Now()) + }() + } + } +} + +func (c *Conn) ReadFrame() (frameconn.Frame, error) { + return c.readFrameFiltered() +} + +func (c *Conn) readFrameFiltered() (frameconn.Frame, error) { + for { + f, err := c.fc.ReadFrame() + if err != nil { + return frameconn.Frame{}, err + } + if IsPublicFrameType(f.Header.Type) { + return f, nil + } + if f.Header.Type != heartbeat { + return frameconn.Frame{}, fmt.Errorf("unknown frame type %x", f.Header.Type) + } + // drop heartbeat frame + debug("received heartbeat") + continue + } +} + +func (c *Conn) WriteFrame(payload []byte, frameType uint32) error { + assertPublicFrameType(frameType) + err := c.fc.WriteFrame(payload, frameType) + if err == nil { + c.lastFrameSent.Store(time.Now()) + } + return err +} diff --git a/rpc/dataconn/heartbeatconn/heartbeatconn_debug.go b/rpc/dataconn/heartbeatconn/heartbeatconn_debug.go new file mode 100644 index 0000000..6bdea8d --- /dev/null +++ b/rpc/dataconn/heartbeatconn/heartbeatconn_debug.go @@ -0,0 +1,20 @@ +package heartbeatconn + +import ( + "fmt" + "os" +) + +var debugEnabled bool = false + +func init() { + if os.Getenv("ZREPL_RPC_DATACONN_HEARTBEATCONN_DEBUG") != "" { + debugEnabled = true + } +} + +func debug(format string, args ...interface{}) { + if debugEnabled { + fmt.Fprintf(os.Stderr, "rpc/dataconn/heartbeatconn: %s\n", fmt.Sprintf(format, args...)) + } +} diff --git a/rpc/dataconn/heartbeatconn/heartbeatconn_test.go b/rpc/dataconn/heartbeatconn/heartbeatconn_test.go new file mode 100644 index 0000000..8b02459 --- /dev/null +++ b/rpc/dataconn/heartbeatconn/heartbeatconn_test.go @@ -0,0 +1,26 @@ +package heartbeatconn + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/zrepl/zrepl/rpc/dataconn/frameconn" +) + +func TestFrameTypes(t *testing.T) { + assert.True(t, frameconn.IsPublicFrameType(heartbeat)) +} + +func TestNegativeTimer(t *testing.T) { + + timer := time.NewTimer(-1 * time.Second) + defer timer.Stop() + time.Sleep(100 * time.Millisecond) + select { + case <-timer.C: + t.Log("timer with negative time fired, that's what we want") + default: + t.Fail() + } +} diff --git a/rpc/dataconn/microbenchmark/microbenchmark.go b/rpc/dataconn/microbenchmark/microbenchmark.go new file mode 100644 index 0000000..287f7e1 --- /dev/null +++ b/rpc/dataconn/microbenchmark/microbenchmark.go @@ -0,0 +1,184 @@ +// microbenchmark to manually test rpc/dataconn perforamnce +// +// With stdin / stdout on client and server, simulating zfs send|recv piping +// +// ./microbenchmark -appmode server | pv -r > /dev/null +// ./microbenchmark -appmode client -direction recv < /dev/zero +// +// +// Without the overhead of pipes (just protocol perforamnce, mostly useful with perf bc no bw measurement) +// +// ./microbenchmark -appmode client -direction recv -devnoopWriter -devnoopReader +// ./microbenchmark -appmode server -devnoopReader -devnoopWriter +// +package main + +import ( + "context" + "flag" + "fmt" + "io" + "net" + "os" + + "github.com/pkg/profile" + + "github.com/zrepl/zrepl/logger" + "github.com/zrepl/zrepl/replication/pdu" + "github.com/zrepl/zrepl/rpc/dataconn" + "github.com/zrepl/zrepl/rpc/dataconn/timeoutconn" + "github.com/zrepl/zrepl/transport" + "github.com/zrepl/zrepl/util/devnoop" + "github.com/zrepl/zrepl/zfs" +) + +func orDie(err error) { + if err != nil { + panic(err) + } +} + +type readerStreamCopier struct{ io.Reader } + +func (readerStreamCopier) Close() error { return nil } + +type readerStreamCopierErr struct { + error +} + +func (readerStreamCopierErr) IsReadError() bool { return false } +func (readerStreamCopierErr) IsWriteError() bool { return true } + +func (c readerStreamCopier) WriteStreamTo(w io.Writer) zfs.StreamCopierError { + var buf [1 << 21]byte + _, err := io.CopyBuffer(w, c.Reader, buf[:]) + // always assume write error + return readerStreamCopierErr{err} +} + +type devNullHandler struct{} + +func (devNullHandler) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) { + var res pdu.SendRes + if args.devnoopReader { + return &res, readerStreamCopier{devnoop.Get()}, nil + } else { + return &res, readerStreamCopier{os.Stdin}, nil + } +} + +func (devNullHandler) Receive(ctx context.Context, r *pdu.ReceiveReq, stream zfs.StreamCopier) (*pdu.ReceiveRes, error) { + var out io.Writer = os.Stdout + if args.devnoopWriter { + out = devnoop.Get() + } + err := stream.WriteStreamTo(out) + var res pdu.ReceiveRes + return &res, err +} + +type tcpConnecter struct { + addr string +} + +func (c tcpConnecter) Connect(ctx context.Context) (timeoutconn.Wire, error) { + conn, err := net.Dial("tcp", c.addr) + if err != nil { + return nil, err + } + return conn.(*net.TCPConn), nil +} + +type tcpListener struct { + nl *net.TCPListener + clientIdent string +} + +func (l tcpListener) Accept(ctx context.Context) (*transport.AuthConn, error) { + tcpconn, err := l.nl.AcceptTCP() + orDie(err) + return transport.NewAuthConn(tcpconn, l.clientIdent), nil +} + +func (l tcpListener) Addr() net.Addr { return l.nl.Addr() } + +func (l tcpListener) Close() error { return l.nl.Close() } + +var args struct { + addr string + appmode string + direction string + profile bool + devnoopReader bool + devnoopWriter bool +} + +func server() { + + log := logger.NewStderrDebugLogger() + log.Debug("starting server") + nl, err := net.Listen("tcp", args.addr) + orDie(err) + l := tcpListener{nl.(*net.TCPListener), "fakeclientidentity"} + + srv := dataconn.NewServer(nil, logger.NewStderrDebugLogger(), devNullHandler{}) + + ctx := context.Background() + + srv.Serve(ctx, l) + +} + +func main() { + + flag.BoolVar(&args.profile, "profile", false, "") + flag.BoolVar(&args.devnoopReader, "devnoopReader", false, "") + flag.BoolVar(&args.devnoopWriter, "devnoopWriter", false, "") + flag.StringVar(&args.addr, "address", ":8888", "") + flag.StringVar(&args.appmode, "appmode", "client|server", "") + flag.StringVar(&args.direction, "direction", "", "send|recv") + flag.Parse() + + if args.profile { + defer profile.Start(profile.CPUProfile).Stop() + } + + switch args.appmode { + case "client": + client() + case "server": + server() + default: + orDie(fmt.Errorf("unknown appmode %q", args.appmode)) + } +} + +func client() { + + logger := logger.NewStderrDebugLogger() + ctx := context.Background() + + connecter := tcpConnecter{args.addr} + client := dataconn.NewClient(connecter, logger) + + switch args.direction { + case "send": + req := pdu.SendReq{} + _, stream, err := client.ReqSend(ctx, &req) + orDie(err) + err = stream.WriteStreamTo(os.Stdout) + orDie(err) + case "recv": + var r io.Reader = os.Stdin + if args.devnoopReader { + r = devnoop.Get() + } + s := readerStreamCopier{r} + req := pdu.ReceiveReq{} + _, err := client.ReqRecv(ctx, &req, &s) + orDie(err) + default: + orDie(fmt.Errorf("unknown direction%q", args.direction)) + } + +} diff --git a/rpc/dataconn/stream/stream.go b/rpc/dataconn/stream/stream.go new file mode 100644 index 0000000..10e7ff0 --- /dev/null +++ b/rpc/dataconn/stream/stream.go @@ -0,0 +1,269 @@ +package stream + +import ( + "bytes" + "context" + "fmt" + "io" + "net" + "strings" + "unicode/utf8" + + "github.com/zrepl/zrepl/logger" + "github.com/zrepl/zrepl/rpc/dataconn/base2bufpool" + "github.com/zrepl/zrepl/rpc/dataconn/frameconn" + "github.com/zrepl/zrepl/rpc/dataconn/heartbeatconn" + "github.com/zrepl/zrepl/zfs" +) + +type Logger = logger.Logger + +type contextKey int + +const ( + contextKeyLogger contextKey = 1 + iota +) + +func WithLogger(ctx context.Context, log Logger) context.Context { + return context.WithValue(ctx, contextKeyLogger, log) +} + +func getLog(ctx context.Context) Logger { + log, ok := ctx.Value(contextKeyLogger).(Logger) + if !ok { + log = logger.NewNullLogger() + } + return log +} + +// Frame types used by this package. +// 4 MSBs are reserved for frameconn, next 4 MSB for heartbeatconn, next 4 MSB for us. +const ( + StreamErrTrailer uint32 = 1 << (16 + iota) + End + // max 16 +) + +// NOTE: make sure to add a tests for each frame type that checks +// whether it is heartbeatconn.IsPublicFrameType() + +// Check whether the given frame type is allowed to be used by +// consumers of this package. Intended for use in unit tests. +func IsPublicFrameType(ft uint32) bool { + return frameconn.IsPublicFrameType(ft) && heartbeatconn.IsPublicFrameType(ft) && ((0xf<<16)&ft == 0) +} + +const FramePayloadShift = 19 + +var bufpool = base2bufpool.New(FramePayloadShift, FramePayloadShift, base2bufpool.Panic) + +// if sendStream returns an error, that error will be sent as a trailer to the client +// ok will return nil, though. +func writeStream(ctx context.Context, c *heartbeatconn.Conn, stream io.Reader, stype uint32) (errStream, errConn error) { + debug("writeStream: enter stype=%v", stype) + defer debug("writeStream: return") + if stype == 0 { + panic("stype must be non-zero") + } + if !IsPublicFrameType(stype) { + panic(fmt.Sprintf("stype %v is not public", stype)) + } + return doWriteStream(ctx, c, stream, stype) +} + +func doWriteStream(ctx context.Context, c *heartbeatconn.Conn, stream io.Reader, stype uint32) (errStream, errConn error) { + + // RULE1 (buf == ) XOR (err == nil) + type read struct { + buf base2bufpool.Buffer + err error + } + + reads := make(chan read, 5) + go func() { + for { + buffer := bufpool.Get(1 << FramePayloadShift) + bufferBytes := buffer.Bytes() + n, err := io.ReadFull(stream, bufferBytes) + buffer.Shrink(uint(n)) + // if we received anything, send one read without an error (RULE 1) + if n > 0 { + reads <- read{buffer, nil} + } + if err == io.ErrUnexpectedEOF { + // happens iff io.ReadFull read io.EOF from stream + err = io.EOF + } + if err != nil { + reads <- read{err: err} // RULE1 + close(reads) + return + } + } + }() + + for read := range reads { + if read.err == nil { + // RULE 1: read.buf is valid + // next line is the hot path... + writeErr := c.WriteFrame(read.buf.Bytes(), stype) + read.buf.Free() + if writeErr != nil { + return nil, writeErr + } + continue + } else if read.err == io.EOF { + if err := c.WriteFrame([]byte{}, End); err != nil { + return nil, err + } + break + } else { + errReader := strings.NewReader(read.err.Error()) + errReadErrReader, errConnWrite := doWriteStream(ctx, c, errReader, StreamErrTrailer) + if errReadErrReader != nil { + panic(errReadErrReader) // in-memory, cannot happen + } + return read.err, errConnWrite + } + } + + return nil, nil +} + +type ReadStreamErrorKind int + +const ( + ReadStreamErrorKindConn ReadStreamErrorKind = 1 + iota + ReadStreamErrorKindWrite + ReadStreamErrorKindSource + ReadStreamErrorKindStreamErrTrailerEncoding + ReadStreamErrorKindUnexpectedFrameType +) + +type ReadStreamError struct { + Kind ReadStreamErrorKind + Err error +} + +func (e *ReadStreamError) Error() string { + kindStr := "" + switch e.Kind { + case ReadStreamErrorKindConn: + kindStr = " read error: " + case ReadStreamErrorKindWrite: + kindStr = " write error: " + case ReadStreamErrorKindSource: + kindStr = " source error: " + case ReadStreamErrorKindStreamErrTrailerEncoding: + kindStr = " source implementation error: " + case ReadStreamErrorKindUnexpectedFrameType: + kindStr = " protocol error: " + } + return fmt.Sprintf("stream:%s%s", kindStr, e.Err) +} + +var _ net.Error = &ReadStreamError{} + +func (e ReadStreamError) netErr() net.Error { + if netErr, ok := e.Err.(net.Error); ok { + return netErr + } + return nil +} + +func (e ReadStreamError) Timeout() bool { + if netErr := e.netErr(); netErr != nil { + return netErr.Timeout() + } + return false +} + +func (e ReadStreamError) Temporary() bool { + if netErr := e.netErr(); netErr != nil { + return netErr.Temporary() + } + return false +} + +var _ zfs.StreamCopierError = &ReadStreamError{} + +func (e ReadStreamError) IsReadError() bool { + return e.Kind != ReadStreamErrorKindWrite +} + +func (e ReadStreamError) IsWriteError() bool { + return e.Kind == ReadStreamErrorKindWrite +} + +type readFrameResult struct { + f frameconn.Frame + err error +} + +func readFrames(reads chan<- readFrameResult, c *heartbeatconn.Conn) { + for { + var r readFrameResult + r.f, r.err = c.ReadFrame() + reads <- r + if r.err != nil { + return + } + } +} + +// ReadStream will close c if an error reading from c or writing to receiver occurs +// +// readStream calls itself recursively to read multi-frame error trailers +// Thus, the reads channel needs to be a parameter. +func readStream(reads <-chan readFrameResult, c *heartbeatconn.Conn, receiver io.Writer, stype uint32) *ReadStreamError { + + var f frameconn.Frame + for read := range reads { + debug("readStream: read frame %v %v", read.f.Header, read.err) + f = read.f + if read.err != nil { + return &ReadStreamError{ReadStreamErrorKindConn, read.err} + } + if f.Header.Type != stype { + break + } + + n, err := receiver.Write(f.Buffer.Bytes()) + if err != nil { + f.Buffer.Free() + return &ReadStreamError{ReadStreamErrorKindWrite, err} // FIXME wrap as writer error + } + if n != len(f.Buffer.Bytes()) { + f.Buffer.Free() + return &ReadStreamError{ReadStreamErrorKindWrite, io.ErrShortWrite} + } + f.Buffer.Free() + } + + if f.Header.Type == End { + debug("readStream: End reached") + return nil + } + + if f.Header.Type == StreamErrTrailer { + debug("readStream: begin of StreamErrTrailer") + var errBuf bytes.Buffer + if n, err := errBuf.Write(f.Buffer.Bytes()); n != len(f.Buffer.Bytes()) || err != nil { + panic(fmt.Sprintf("unexpected bytes.Buffer write error: %v %v", n, err)) + } + // recursion ftw! we won't enter this if stmt because stype == StreamErrTrailer in the following call + rserr := readStream(reads, c, &errBuf, StreamErrTrailer) + if rserr != nil && rserr.Kind == ReadStreamErrorKindWrite { + panic(fmt.Sprintf("unexpected bytes.Buffer write error: %s", rserr)) + } else if rserr != nil { + debug("readStream: rserr != nil && != ReadStreamErrorKindWrite: %v %v\n", rserr.Kind, rserr.Err) + return rserr + } + if !utf8.Valid(errBuf.Bytes()) { + return &ReadStreamError{ReadStreamErrorKindStreamErrTrailerEncoding, fmt.Errorf("source error, but not encoded as UTF-8")} + } + return &ReadStreamError{ReadStreamErrorKindSource, fmt.Errorf("%s", errBuf.String())} + } + + return &ReadStreamError{ReadStreamErrorKindUnexpectedFrameType, fmt.Errorf("unexpected frame type %v (expected %v)", f.Header.Type, stype)} +} diff --git a/rpc/dataconn/stream/stream_conn.go b/rpc/dataconn/stream/stream_conn.go new file mode 100644 index 0000000..ce0c052 --- /dev/null +++ b/rpc/dataconn/stream/stream_conn.go @@ -0,0 +1,194 @@ +package stream + +import ( + "bytes" + "context" + "fmt" + "io" + "sync" + "time" + + "github.com/zrepl/zrepl/rpc/dataconn/heartbeatconn" + "github.com/zrepl/zrepl/rpc/dataconn/timeoutconn" + "github.com/zrepl/zrepl/zfs" +) + +type Conn struct { + hc *heartbeatconn.Conn + + // whether the per-conn readFrames goroutine completed + waitReadFramesDone chan struct{} + // filled by per-conn readFrames goroutine + frameReads chan readFrameResult + + // readMtx serializes read stream operations because we inherently only + // support a single stream at a time over hc. + readMtx sync.Mutex + readClean bool + allowWriteStreamTo bool + + // writeMtx serializes write stream operations because we inherently only + // support a single stream at a time over hc. + writeMtx sync.Mutex + writeClean bool +} + +var readMessageSentinel = fmt.Errorf("read stream complete") + +type writeStreamToErrorUnknownState struct{} + +func (e writeStreamToErrorUnknownState) Error() string { + return "dataconn read stream: connection is in unknown state" +} + +func (e writeStreamToErrorUnknownState) IsReadError() bool { return true } + +func (e writeStreamToErrorUnknownState) IsWriteError() bool { return false } + +func Wrap(nc timeoutconn.Wire, sendHeartbeatInterval, peerTimeout time.Duration) *Conn { + hc := heartbeatconn.Wrap(nc, sendHeartbeatInterval, peerTimeout) + conn := &Conn{ + hc: hc, readClean: true, writeClean: true, + waitReadFramesDone: make(chan struct{}), + frameReads: make(chan readFrameResult, 5), // FIXME constant + } + go conn.readFrames() + return conn +} + +func isConnCleanAfterRead(res *ReadStreamError) bool { + return res == nil || res.Kind == ReadStreamErrorKindSource || res.Kind == ReadStreamErrorKindStreamErrTrailerEncoding +} + +func isConnCleanAfterWrite(err error) bool { + return err == nil +} + +var ErrReadFramesStopped = fmt.Errorf("stream: reading frames stopped") + +func (c *Conn) readFrames() { + defer close(c.waitReadFramesDone) + defer close(c.frameReads) + readFrames(c.frameReads, c.hc) +} + +func (c *Conn) ReadStreamedMessage(ctx context.Context, maxSize uint32, frameType uint32) ([]byte, *ReadStreamError) { + c.readMtx.Lock() + defer c.readMtx.Unlock() + if !c.readClean { + return nil, &ReadStreamError{ + Kind: ReadStreamErrorKindConn, + Err: fmt.Errorf("dataconn read message: connection is in unknown state"), + } + } + + r, w := io.Pipe() + var buf bytes.Buffer + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + lr := io.LimitReader(r, int64(maxSize)) + if _, err := io.Copy(&buf, lr); err != nil && err != readMessageSentinel { + panic(err) + } + }() + err := readStream(c.frameReads, c.hc, w, frameType) + c.readClean = isConnCleanAfterRead(err) + w.CloseWithError(readMessageSentinel) + wg.Wait() + if err != nil { + return nil, err + } else { + return buf.Bytes(), nil + } +} + +// WriteStreamTo reads a stream from Conn and writes it to w. +func (c *Conn) ReadStreamInto(w io.Writer, frameType uint32) zfs.StreamCopierError { + c.readMtx.Lock() + defer c.readMtx.Unlock() + if !c.readClean { + return writeStreamToErrorUnknownState{} + } + var err *ReadStreamError = readStream(c.frameReads, c.hc, w, frameType) + c.readClean = isConnCleanAfterRead(err) + + // https://golang.org/doc/faq#nil_error + if err == nil { + return nil + } + return err +} + +func (c *Conn) WriteStreamedMessage(ctx context.Context, buf io.Reader, frameType uint32) error { + c.writeMtx.Lock() + defer c.writeMtx.Unlock() + if !c.writeClean { + return fmt.Errorf("dataconn write message: connection is in unknown state") + } + errBuf, errConn := writeStream(ctx, c.hc, buf, frameType) + if errBuf != nil { + panic(errBuf) + } + c.writeClean = isConnCleanAfterWrite(errConn) + return errConn +} + +func (c *Conn) SendStream(ctx context.Context, src zfs.StreamCopier, frameType uint32) error { + c.writeMtx.Lock() + defer c.writeMtx.Unlock() + if !c.writeClean { + return fmt.Errorf("dataconn send stream: connection is in unknown state") + } + + // avoid io.Pipe if zfs.StreamCopier is an io.Reader + var r io.Reader + var w *io.PipeWriter + streamCopierErrChan := make(chan zfs.StreamCopierError, 1) + if reader, ok := src.(io.Reader); ok { + r = reader + streamCopierErrChan <- nil + close(streamCopierErrChan) + } else { + r, w = io.Pipe() + go func() { + streamCopierErrChan <- src.WriteStreamTo(w) + w.Close() + }() + } + + type writeStreamRes struct { + errStream, errConn error + } + writeStreamErrChan := make(chan writeStreamRes, 1) + go func() { + var res writeStreamRes + res.errStream, res.errConn = writeStream(ctx, c.hc, r, frameType) + if w != nil { + w.CloseWithError(res.errStream) + } + writeStreamErrChan <- res + }() + + writeRes := <-writeStreamErrChan + streamCopierErr := <-streamCopierErrChan + c.writeClean = isConnCleanAfterWrite(writeRes.errConn) // TODO correct? + if streamCopierErr != nil && streamCopierErr.IsReadError() { + return streamCopierErr // something on our side is bad + } else { + if writeRes.errStream != nil { + return writeRes.errStream + } else if writeRes.errConn != nil { + return writeRes.errConn + } + // TODO combined error? + return streamCopierErr + } +} + +func (c *Conn) Close() error { + err := c.hc.Shutdown() + <-c.waitReadFramesDone + return err +} diff --git a/rpc/dataconn/stream/stream_debug.go b/rpc/dataconn/stream/stream_debug.go new file mode 100644 index 0000000..c1e2ef9 --- /dev/null +++ b/rpc/dataconn/stream/stream_debug.go @@ -0,0 +1,20 @@ +package stream + +import ( + "fmt" + "os" +) + +var debugEnabled bool = false + +func init() { + if os.Getenv("ZREPL_RPC_DATACONN_STREAM_DEBUG") != "" { + debugEnabled = true + } +} + +func debug(format string, args ...interface{}) { + if debugEnabled { + fmt.Fprintf(os.Stderr, "rpc/dataconn/stream: %s\n", fmt.Sprintf(format, args...)) + } +} diff --git a/rpc/dataconn/stream/stream_test.go b/rpc/dataconn/stream/stream_test.go new file mode 100644 index 0000000..7da4cfa --- /dev/null +++ b/rpc/dataconn/stream/stream_test.go @@ -0,0 +1,131 @@ +package stream + +import ( + "bytes" + "context" + "fmt" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zrepl/zrepl/logger" + "github.com/zrepl/zrepl/rpc/dataconn/heartbeatconn" + "github.com/zrepl/zrepl/util/socketpair" +) + +func TestFrameTypesOk(t *testing.T) { + t.Logf("%v", End) + assert.True(t, heartbeatconn.IsPublicFrameType(End)) + assert.True(t, heartbeatconn.IsPublicFrameType(StreamErrTrailer)) +} + +func TestStreamer(t *testing.T) { + + anc, bnc, err := socketpair.SocketPair() + require.NoError(t, err) + + hto := 1 * time.Hour + a := heartbeatconn.Wrap(anc, hto, hto) + b := heartbeatconn.Wrap(bnc, hto, hto) + + log := logger.NewStderrDebugLogger() + ctx := WithLogger(context.Background(), log) + + stype := uint32(0x23) + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + var buf bytes.Buffer + buf.Write( + bytes.Repeat([]byte{1, 2}, 1<<25), + ) + writeStream(ctx, a, &buf, stype) + log.Debug("WriteStream returned") + a.Shutdown() + }() + + go func() { + defer wg.Done() + var buf bytes.Buffer + ch := make(chan readFrameResult, 5) + wg.Add(1) + go func() { + defer wg.Done() + readFrames(ch, b) + }() + err := readStream(ch, b, &buf, stype) + log.WithField("errType", fmt.Sprintf("%T %v", err, err)).Debug("ReadStream returned") + assert.Nil(t, err) + expected := bytes.Repeat([]byte{1, 2}, 1<<25) + assert.True(t, bytes.Equal(expected, buf.Bytes())) + b.Shutdown() + }() + + wg.Wait() + +} + +type errReader struct { + t *testing.T + readErr error +} + +func (er errReader) Read(p []byte) (n int, err error) { + er.t.Logf("errReader.Read called") + return 0, er.readErr +} + +func TestMultiFrameStreamErrTraileror(t *testing.T) { + anc, bnc, err := socketpair.SocketPair() + require.NoError(t, err) + + hto := 1 * time.Hour + a := heartbeatconn.Wrap(anc, hto, hto) + b := heartbeatconn.Wrap(bnc, hto, hto) + + log := logger.NewStderrDebugLogger() + ctx := WithLogger(context.Background(), log) + + longErr := fmt.Errorf("an error that definitley spans more than one frame:\n%s", strings.Repeat("a\n", 1<<4)) + + stype := uint32(0x23) + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + r := errReader{t, longErr} + writeStream(ctx, a, &r, stype) + a.Shutdown() + }() + + go func() { + defer wg.Done() + defer b.Shutdown() + var buf bytes.Buffer + ch := make(chan readFrameResult, 5) + wg.Add(1) + go func() { + defer wg.Done() + readFrames(ch, b) + }() + err := readStream(ch, b, &buf, stype) + t.Logf("%s", err) + require.NotNil(t, err) + assert.True(t, buf.Len() == 0) + assert.Equal(t, err.Kind, ReadStreamErrorKindSource) + receivedErr := err.Err.Error() + expectedErr := longErr.Error() + assert.True(t, receivedErr == expectedErr) // builtin Equals is too slow + if receivedErr != expectedErr { + t.Logf("lengths: %v %v", len(receivedErr), len(expectedErr)) + } + }() + + wg.Wait() +} diff --git a/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/.gitignore b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/.gitignore new file mode 100644 index 0000000..650f629 --- /dev/null +++ b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/.gitignore @@ -0,0 +1,12 @@ +# setup-specific +inventory +*.retry + +# generated by gen_files.sh +files/*ssh_client_identity +files/*ssh_client_identity.pub +files/*.tls.*.key +files/*.tls.*.csr +files/*.tls.*.crt +files/wireevaluator + diff --git a/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/README.md b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/README.md new file mode 100644 index 0000000..958c539 --- /dev/null +++ b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/README.md @@ -0,0 +1,15 @@ +This directory contains very hacky test automation for wireevaluator based on nested Ansible playbooks. + +* Copy `inventory.example` to `inventory` +* Adjust `inventory` IP addresses as needed +* Make sure there's an OpenSSH server running on the serve host +* Make sure there's no firewalling whatsoever between the hosts +* Run `GENKEYS=1 ./gen_files.sh` to re-generate self-signed TLS certs +* Run the following command, adjusting the `wireevaluator_repeat` value to the number of times you want to repeat each test + +``` +ansible-playbook -i inventory all.yml -e `wireevaluator_repeat=3` +``` + +Generally, things are fine if the playbook doesn't show any panics from wireevaluator. + diff --git a/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/all.yml b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/all.yml new file mode 100644 index 0000000..9abe411 --- /dev/null +++ b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/all.yml @@ -0,0 +1,17 @@ +- hosts: connect,serve + tasks: + + - name: "run test" + include: internal_prepare_and_run_repeated.yml + wireevaluator_transport: "{{config.0}}" + wireevaluator_case: "{{config.1}}" + wireevaluator_repeat: "{{wireevaluator_repeat}}" + with_cartesian: + - [ tls, ssh, tcp ] + - + - closewrite_server + - closewrite_client + - readdeadline_server + - readdeadline_client + loop_control: + loop_var: config diff --git a/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/gen_files.sh b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/gen_files.sh new file mode 100755 index 0000000..e7a72e9 --- /dev/null +++ b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/gen_files.sh @@ -0,0 +1,55 @@ +#!/usr/bin/env bash +set -e + +cd "$( dirname "${BASH_SOURCE[0]}")" + +FILESDIR="$(pwd)"/files + +echo "[INFO] compile binary" +pushd .. >/dev/null +go build -o $FILESDIR/wireevaluator +popd >/dev/null + +if [ "$GENKEYS" == "" ]; then + echo "[INFO] GENKEYS environment variable not set, assumed to be valid" + exit 0 +fi + +echo "[INFO] gen ssh key" +ssh-keygen -f "$FILESDIR/wireevaluator.ssh_client_identity" -t ed25519 + +echo "[INFO] gen tls keys" + +cakey="$FILESDIR/wireevaluator.tls.ca.key" +cacrt="$FILESDIR/wireevaluator.tls.ca.crt" +hostprefix="$FILESDIR/wireevaluator.tls" + +openssl genrsa -out "$cakey" 4096 +openssl req -x509 -new -nodes -key "$cakey" -sha256 -days 1 -out "$cacrt" + +declare -a HOSTS +HOSTS+=("theserver") +HOSTS+=("theclient") + +for host in "${HOSTS[@]}"; do + key="${hostprefix}.${host}.key" + csr="${hostprefix}.${host}.csr" + crt="${hostprefix}.${host}.crt" + openssl genrsa -out "$key" 2048 + + ( + echo "." + echo "." + echo "." + echo "." + echo "." + echo $host + echo "." + echo "." + echo "." + echo "." + ) | openssl req -new -key "$key" -out "$csr" + + openssl x509 -req -in "$csr" -CA "$cacrt" -CAkey "$cakey" -CAcreateserial -out "$crt" -days 1 -sha256 + +done diff --git a/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/internal_prepare_and_run_repeated.yml b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/internal_prepare_and_run_repeated.yml new file mode 100644 index 0000000..71d9074 --- /dev/null +++ b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/internal_prepare_and_run_repeated.yml @@ -0,0 +1,54 @@ +--- + +- name: compile binary and any key files required + local_action: command ./gen_files.sh + +- name: Kill test binary + shell: "killall -9 wireevaluator || true" +- name: Deploy new binary + copy: + src: "files/wireevaluator" + dest: "/opt/wireevaluator" + mode: 0755 + +- set_fact: + wireevaluator_connect_ip: "{{hostvars['connect'].ansible_host}}" + wireevaluator_serve_ip: "{{hostvars['serve'].ansible_host}}" + +- name: Deploy config + template: + src: "templates/{{wireevaluator_transport}}.yml.j2" + dest: "/opt/wireevaluator.yml" + +- name: Deploy client identity + copy: + src: "files/wireevaluator.{{item}}" + dest: "/opt/wireevaluator.{{item}}" + mode: 0400 + with_items: + - ssh_client_identity + - ssh_client_identity.pub + - tls.ca.key + - tls.ca.crt + - tls.theserver.key + - tls.theserver.crt + - tls.theclient.key + - tls.theclient.crt + +- name: Setup server ssh client identity access + when: inventory_hostname == "serve" + block: + - authorized_key: + user: root + state: present + key: "{{ lookup('file', 'files/wireevaluator.ssh_client_identity.pub') }}" + key_options: 'command="/opt/wireevaluator -mode stdinserver -config /opt/wireevaluator.yml client1"' + - file: + state: directory + mode: 0700 + path: /tmp/wireevaluator_stdinserver + +- name: repeated test + include: internal_run_test_prepared_single.yml + with_sequence: start=1 end={{wireevaluator_repeat}} + diff --git a/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/internal_run_test_prepared_single.yml b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/internal_run_test_prepared_single.yml new file mode 100644 index 0000000..4c535d9 --- /dev/null +++ b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/internal_run_test_prepared_single.yml @@ -0,0 +1,38 @@ +--- + +- debug: + msg: "run test transport={{wireevaluator_transport}} case={{wireevaluator_case}} repeatedly" + +- name: Run Server + when: inventory_hostname == "serve" + command: /opt/wireevaluator -config /opt/wireevaluator.yml -mode serve -testcase {{wireevaluator_case}} + register: spawn_servers + async: 60 + poll: 0 + +- name: Run Client + when: inventory_hostname == "connect" + command: /opt/wireevaluator -config /opt/wireevaluator.yml -mode connect -testcase {{wireevaluator_case}} + register: spawn_clients + async: 60 + poll: 0 + +- name: Wait for server shutdown + when: inventory_hostname == "serve" + async_status: + jid: "{{ spawn_servers.ansible_job_id}}" + delay: 0.5 + retries: 10 + +- name: Wait for client shutdown + when: inventory_hostname == "connect" + async_status: + jid: "{{ spawn_clients.ansible_job_id}}" + delay: 0.5 + retries: 10 + +- name: Wait for connections to die (TIME_WAIT conns) + command: sleep 4 + changed_when: false + + diff --git a/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/inventory.example b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/inventory.example new file mode 100644 index 0000000..70cf1c7 --- /dev/null +++ b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/inventory.example @@ -0,0 +1,2 @@ +connect ansible_user=root ansible_host=192.168.122.128 wireevaluator_mode="connect" +serve ansible_user=root ansible_host=192.168.122.129 wireevaluator_mode="serve" diff --git a/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/templates/ssh.yml.j2 b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/templates/ssh.yml.j2 new file mode 100644 index 0000000..559f0d0 --- /dev/null +++ b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/templates/ssh.yml.j2 @@ -0,0 +1,13 @@ +connect: + type: ssh+stdinserver + host: {{wireevaluator_serve_ip}} + user: root + port: 22 + identity_file: /opt/wireevaluator.ssh_client_identity + options: # optional, default [], `-o` arguments passed to ssh + - "Compression=yes" +serve: + type: stdinserver + client_identities: + - "client1" + diff --git a/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/templates/tcp.yml.j2 b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/templates/tcp.yml.j2 new file mode 100644 index 0000000..3aaa29e --- /dev/null +++ b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/templates/tcp.yml.j2 @@ -0,0 +1,10 @@ +connect: + type: tcp + address: "{{wireevaluator_serve_ip}}:8888" +serve: + type: tcp + listen: ":8888" + clients: { + "{{wireevaluator_connect_ip}}" : "client1" + } + diff --git a/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/templates/tls.yml.j2 b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/templates/tls.yml.j2 new file mode 100644 index 0000000..6fdaa94 --- /dev/null +++ b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/templates/tls.yml.j2 @@ -0,0 +1,16 @@ +connect: + type: tls + address: "{{wireevaluator_serve_ip}}:8888" + ca: "/opt/wireevaluator.tls.ca.crt" + cert: "/opt/wireevaluator.tls.theclient.crt" + key: "/opt/wireevaluator.tls.theclient.key" + server_cn: "theserver" + +serve: + type: tls + listen: ":8888" + ca: "/opt/wireevaluator.tls.ca.crt" + cert: "/opt/wireevaluator.tls.theserver.crt" + key: "/opt/wireevaluator.tls.theserver.key" + client_cns: + - "theclient" diff --git a/rpc/dataconn/timeoutconn/internal/wireevaluator/wireevaluator.go b/rpc/dataconn/timeoutconn/internal/wireevaluator/wireevaluator.go new file mode 100644 index 0000000..65ea9cd --- /dev/null +++ b/rpc/dataconn/timeoutconn/internal/wireevaluator/wireevaluator.go @@ -0,0 +1,111 @@ +// a tool to test whether a given transport implements the timeoutconn.Wire interface +package main + +import ( + "context" + "flag" + "fmt" + "io/ioutil" + "os" + "path" + + netssh "github.com/problame/go-netssh" + "github.com/zrepl/yaml-config" + "github.com/zrepl/zrepl/config" + "github.com/zrepl/zrepl/transport" + transportconfig "github.com/zrepl/zrepl/transport/fromconfig" +) + +func noerror(err error) { + if err != nil { + panic(err) + } +} + +type Config struct { + Connect config.ConnectEnum + Serve config.ServeEnum +} + +var args struct { + mode string + configPath string + testCase string +} + +var conf Config + +type TestCase interface { + Client(wire transport.Wire) + Server(wire transport.Wire) +} + +func main() { + flag.StringVar(&args.mode, "mode", "", "connect|serve") + flag.StringVar(&args.configPath, "config", "", "config file path") + flag.StringVar(&args.testCase, "testcase", "", "") + flag.Parse() + + bytes, err := ioutil.ReadFile(args.configPath) + noerror(err) + err = yaml.UnmarshalStrict(bytes, &conf) + noerror(err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + global := &config.Global{ + Serve: &config.GlobalServe{ + StdinServer: &config.GlobalStdinServer{ + SockDir: "/tmp/wireevaluator_stdinserver", + }, + }, + } + + switch args.mode { + case "connect": + tc, err := getTestCase(args.testCase) + noerror(err) + connecter, err := transportconfig.ConnecterFromConfig(global, conf.Connect) + noerror(err) + wire, err := connecter.Connect(ctx) + noerror(err) + tc.Client(wire) + case "serve": + tc, err := getTestCase(args.testCase) + noerror(err) + lf, err := transportconfig.ListenerFactoryFromConfig(global, conf.Serve) + noerror(err) + l, err := lf() + noerror(err) + conn, err := l.Accept(ctx) + noerror(err) + tc.Server(conn) + case "stdinserver": + identity := flag.Arg(0) + unixaddr := path.Join(global.Serve.StdinServer.SockDir, identity) + err := netssh.Proxy(ctx, unixaddr) + if err == nil { + os.Exit(0) + } + panic(err) + default: + panic(fmt.Sprintf("unknown mode %q", args.mode)) + } + +} + +func getTestCase(tcName string) (TestCase, error) { + switch tcName { + case "closewrite_server": + return &CloseWrite{mode: CloseWriteServerSide}, nil + case "closewrite_client": + return &CloseWrite{mode: CloseWriteClientSide}, nil + case "readdeadline_client": + return &Deadlines{mode: DeadlineModeClientTimeout}, nil + case "readdeadline_server": + return &Deadlines{mode: DeadlineModeServerTimeout}, nil + default: + return nil, fmt.Errorf("unknown test case %q", tcName) + } +} diff --git a/rpc/dataconn/timeoutconn/internal/wireevaluator/wireevaluator_closewrite.go b/rpc/dataconn/timeoutconn/internal/wireevaluator/wireevaluator_closewrite.go new file mode 100644 index 0000000..cc1907e --- /dev/null +++ b/rpc/dataconn/timeoutconn/internal/wireevaluator/wireevaluator_closewrite.go @@ -0,0 +1,110 @@ +package main + +import ( + "bytes" + "io" + "io/ioutil" + "log" + + "github.com/zrepl/zrepl/transport" +) + +type CloseWriteMode uint + +const ( + CloseWriteClientSide CloseWriteMode = 1 + iota + CloseWriteServerSide +) + +type CloseWrite struct { + mode CloseWriteMode +} + +// sent repeatedly +var closeWriteTestSendData = bytes.Repeat([]byte{0x23, 0x42}, 1<<24) +var closeWriteErrorMsg = []byte{0xb, 0xa, 0xd, 0xf, 0x0, 0x0, 0xd} + +func (m CloseWrite) Client(wire transport.Wire) { + switch m.mode { + case CloseWriteClientSide: + m.receiver(wire) + case CloseWriteServerSide: + m.sender(wire) + default: + panic(m.mode) + } +} + +func (m CloseWrite) Server(wire transport.Wire) { + switch m.mode { + case CloseWriteClientSide: + m.sender(wire) + case CloseWriteServerSide: + m.receiver(wire) + default: + panic(m.mode) + } +} + +func (CloseWrite) sender(wire transport.Wire) { + defer func() { + closeErr := wire.Close() + log.Printf("closeErr=%T %s", closeErr, closeErr) + }() + + type opResult struct { + err error + } + writeDone := make(chan struct{}, 1) + go func() { + close(writeDone) + for { + _, err := wire.Write(closeWriteTestSendData) + if err != nil { + return + } + } + }() + + defer func() { + <-writeDone + }() + + var respBuf bytes.Buffer + _, err := io.Copy(&respBuf, wire) + if err != nil { + log.Fatalf("should have received io.EOF, which is masked by io.Copy, got: %s", err) + } + if !bytes.Equal(respBuf.Bytes(), closeWriteErrorMsg) { + log.Fatalf("did not receive error message, got response with len %v:\n%v", respBuf.Len(), respBuf.Bytes()) + } + +} + +func (CloseWrite) receiver(wire transport.Wire) { + + // consume half the test data, then detect an error, send it and CloseWrite + + r := io.LimitReader(wire, int64(5 * len(closeWriteTestSendData)/3)) + _, err := io.Copy(ioutil.Discard, r) + noerror(err) + + var errBuf bytes.Buffer + errBuf.Write(closeWriteErrorMsg) + _, err = io.Copy(wire, &errBuf) + noerror(err) + + err = wire.CloseWrite() + noerror(err) + + // drain wire, as documented in transport.Wire, this is the only way we know the client closed the conn + _, err = io.Copy(ioutil.Discard, wire) + if err != nil { + // io.Copy masks io.EOF to nil, and we expect io.EOF from the client's Close() call + log.Panicf("unexpected error returned from reading conn: %s", err) + } + + closeErr := wire.Close() + log.Printf("closeErr=%T %s", closeErr, closeErr) + +} diff --git a/rpc/dataconn/timeoutconn/internal/wireevaluator/wireevaluator_deadlines.go b/rpc/dataconn/timeoutconn/internal/wireevaluator/wireevaluator_deadlines.go new file mode 100644 index 0000000..397e9d5 --- /dev/null +++ b/rpc/dataconn/timeoutconn/internal/wireevaluator/wireevaluator_deadlines.go @@ -0,0 +1,138 @@ +package main + +import ( + "bytes" + "fmt" + "io" + "log" + "net" + "time" + + "github.com/zrepl/zrepl/transport" +) + +type DeadlineMode uint + +const ( + DeadlineModeClientTimeout DeadlineMode = 1 + iota + DeadlineModeServerTimeout +) + +type Deadlines struct { + mode DeadlineMode +} + +func (d Deadlines) Client(wire transport.Wire) { + switch d.mode { + case DeadlineModeClientTimeout: + d.sleepThenSend(wire) + case DeadlineModeServerTimeout: + d.sendThenRead(wire) + default: + panic(d.mode) + } +} + +func (d Deadlines) Server(wire transport.Wire) { + switch d.mode { + case DeadlineModeClientTimeout: + d.sendThenRead(wire) + case DeadlineModeServerTimeout: + d.sleepThenSend(wire) + default: + panic(d.mode) + } +} + +var deadlinesTimeout = 1 * time.Second + +func (d Deadlines) sleepThenSend(wire transport.Wire) { + defer wire.Close() + + log.Print("sleepThenSend") + + // exceed timeout of peer (do not respond to their hi msg) + time.Sleep(3 * deadlinesTimeout) + // expect that the client has hung up on us by now + err := d.sendMsg(wire, "hi") + log.Printf("err=%s", err) + log.Printf("err=%#v", err) + if err == nil { + log.Panic("no error") + } + if _, ok := err.(net.Error); !ok { + log.Panic("not a net error") + } + +} + +func (d Deadlines) sendThenRead(wire transport.Wire) { + + log.Print("sendThenRead") + + err := d.sendMsg(wire, "hi") + noerror(err) + + err = wire.SetReadDeadline(time.Now().Add(deadlinesTimeout)) + noerror(err) + + m, err := d.recvMsg(wire) + log.Printf("m=%q", m) + log.Printf("err=%s", err) + log.Printf("err=%#v", err) + + // close asap so that the peer get's a 'connection reset by peer' error or similar + closeErr := wire.Close() + if closeErr != nil { + panic(closeErr) + } + + var neterr net.Error + var ok bool + if err == nil { + goto unexpErr // works for nil, too + } + neterr, ok = err.(net.Error) + if !ok { + log.Println("not a net error") + goto unexpErr + } + if !neterr.Timeout() { + log.Println("not a timeout") + } + + return + +unexpErr: + panic(fmt.Sprintf("sendThenRead: client should have hung up but got error %T %s", err, err)) +} + +const deadlinesMsgLen = 40 + +func (d Deadlines) sendMsg(wire transport.Wire, msg string) error { + if len(msg) > deadlinesMsgLen { + panic(len(msg)) + } + var buf [deadlinesMsgLen]byte + copy(buf[:], []byte(msg)) + n, err := wire.Write(buf[:]) + if err != nil { + return err + } + if n != len(buf) { + panic("short write not allowed") + } + return nil +} + +func (d Deadlines) recvMsg(wire transport.Wire) (string, error) { + + var buf bytes.Buffer + r := io.LimitReader(wire, deadlinesMsgLen) + _, err := io.Copy(&buf, r) + if err != nil { + return "", err + } + return buf.String(), nil + +} diff --git a/rpc/dataconn/timeoutconn/timeoutconn.go b/rpc/dataconn/timeoutconn/timeoutconn.go new file mode 100644 index 0000000..9e0a3bf --- /dev/null +++ b/rpc/dataconn/timeoutconn/timeoutconn.go @@ -0,0 +1,288 @@ +// package timeoutconn wraps a Wire to provide idle timeouts +// based on Set{Read,Write}Deadline. +// Additionally, it exports abstractions for vectored I/O. +package timeoutconn + +// NOTE +// Readv and Writev are not split-off into a separate package +// because we use raw syscalls, bypassing Conn's Read / Write methods. + +import ( + "errors" + "io" + "net" + "sync/atomic" + "syscall" + "time" + "unsafe" +) + +type Wire interface { + net.Conn + // A call to CloseWrite indicates that no further Write calls will be made to Wire. + // The implementation must return an error in case of Write calls after CloseWrite. + // On the peer's side, after it read all data written to Wire prior to the call to + // CloseWrite on our side, the peer's Read calls must return io.EOF. + // CloseWrite must not affect the read-direction of Wire: specifically, the + // peer must continue to be able to send, and our side must continue be + // able to receive data over Wire. + // + // Note that CloseWrite may (and most likely will) return sooner than the + // peer having received all data written to Wire prior to CloseWrite. + // Note further that buffering happening in the network stacks on either side + // mandates an explicit acknowledgement from the peer that the connection may + // be fully shut down: If we call Close without such acknowledgement, any data + // from peer to us that was already in flight may cause connection resets to + // be sent from us to the peer via the specific transport protocol. Those + // resets (e.g. RST frames) may erase all connection context on the peer, + // including data in its receive buffers. Thus, those resets are in race with + // a) transmission of data written prior to CloseWrite and + // b) the peer application reading from those buffers. + // + // The WaitForPeerClose method can be used to wait for connection termination, + // iff the implementation supports it. If it does not, the only reliable way + // to wait for a peer to have read all data from Wire (until io.EOF), is to + // expect it to close the wire at that point as well, and to drain Wire until + // we also read io.EOF. + CloseWrite() error + + // Wait for the peer to close the connection. + // No data that could otherwise be Read is lost as a consequence of this call. + // The use case for this API is abortive connection shutdown. + // To provide any value over draining Wire using io.Read, an implementation + // will likely use out-of-bounds messaging mechanisms. + // TODO WaitForPeerClose() (supported bool, err error) +} + +type Conn struct { + Wire + renewDeadlinesDisabled int32 + idleTimeout time.Duration +} + +func Wrap(conn Wire, idleTimeout time.Duration) Conn { + return Conn{Wire: conn, idleTimeout: idleTimeout} +} + +// DisableTimeouts disables the idle timeout behavior provided by this package. +// Existing deadlines are cleared iff the call is the first call to this method. +func (c *Conn) DisableTimeouts() error { + if atomic.CompareAndSwapInt32(&c.renewDeadlinesDisabled, 0, 1) { + return c.SetDeadline(time.Time{}) + } + return nil +} + +func (c *Conn) renewReadDeadline() error { + if atomic.LoadInt32(&c.renewDeadlinesDisabled) != 0 { + return nil + } + return c.SetReadDeadline(time.Now().Add(c.idleTimeout)) +} + +func (c *Conn) renewWriteDeadline() error { + if atomic.LoadInt32(&c.renewDeadlinesDisabled) != 0 { + return nil + } + return c.SetWriteDeadline(time.Now().Add(c.idleTimeout)) +} + +func (c Conn) Read(p []byte) (n int, err error) { + n = 0 + err = nil +restart: + if err := c.renewReadDeadline(); err != nil { + return n, err + } + var nCurRead int + nCurRead, err = c.Wire.Read(p[n:len(p)]) + n += nCurRead + if netErr, ok := err.(net.Error); ok && netErr.Timeout() && nCurRead > 0 { + err = nil + goto restart + } + return n, err +} + +func (c Conn) Write(p []byte) (n int, err error) { + n = 0 +restart: + if err := c.renewWriteDeadline(); err != nil { + return n, err + } + var nCurWrite int + nCurWrite, err = c.Wire.Write(p[n:len(p)]) + n += nCurWrite + if netErr, ok := err.(net.Error); ok && netErr.Timeout() && nCurWrite > 0 { + err = nil + goto restart + } + return n, err +} + +// Writes the given buffers to Conn, following the sematincs of io.Copy, +// but is guaranteed to use the writev system call if the wrapped Wire +// support it. +// Note the Conn does not support writev through io.Copy(aConn, aNetBuffers). +func (c Conn) WritevFull(bufs net.Buffers) (n int64, err error) { + n = 0 +restart: + if err := c.renewWriteDeadline(); err != nil { + return n, err + } + var nCurWrite int64 + nCurWrite, err = io.Copy(c.Wire, &bufs) + n += nCurWrite + if netErr, ok := err.(net.Error); ok && netErr.Timeout() && nCurWrite > 0 { + err = nil + goto restart + } + return n, err +} + +var SyscallConnNotSupported = errors.New("SyscallConn not supported") + +// The interface that must be implemented for vectored I/O support. +// If the wrapped Wire does not implement it, a less efficient +// fallback implementation is used. +// Rest assured that Go's *net.TCPConn implements this interface. +type SyscallConner interface { + // The sentinel error value SyscallConnNotSupported can be returned + // if the support for SyscallConn depends on runtime conditions and + // that runtime condition is not met. + SyscallConn() (syscall.RawConn, error) +} + +var _ SyscallConner = (*net.TCPConn)(nil) + +func buildIovecs(buffers net.Buffers) (totalLen int64, vecs []syscall.Iovec) { + vecs = make([]syscall.Iovec, 0, len(buffers)) + for i := range buffers { + totalLen += int64(len(buffers[i])) + if len(buffers[i]) == 0 { + continue + } + vecs = append(vecs, syscall.Iovec{ + Base: &buffers[i][0], + Len: uint64(len(buffers[i])), + }) + } + return totalLen, vecs +} + +// Reads the given buffers full: +// Think of io.ReadvFull, but for net.Buffers + using the readv syscall. +// +// If the underlying Wire is not a SyscallConner, a fallback +// ipmlementation based on repeated Conn.Read invocations is used. +// +// If the connection returned io.EOF, the number of bytes up ritten until +// then + io.EOF is returned. This behavior is different to io.ReadFull +// which returns io.ErrUnexpectedEOF. +func (c Conn) ReadvFull(buffers net.Buffers) (n int64, err error) { + totalLen, iovecs := buildIovecs(buffers) + if debugReadvNoShortReadsAssertEnable { + defer debugReadvNoShortReadsAssert(totalLen, n, err) + } + scc, ok := c.Wire.(SyscallConner) + if !ok { + return c.readvFallback(buffers) + } + raw, err := scc.SyscallConn() + if err == SyscallConnNotSupported { + return c.readvFallback(buffers) + } + if err != nil { + return 0, err + } + n, err = c.readv(raw, iovecs) + return +} + +func (c Conn) readvFallback(nbuffers net.Buffers) (n int64, err error) { + buffers := [][]byte(nbuffers) + for i := range buffers { + curBuf := buffers[i] + inner: + for len(curBuf) > 0 { + if err := c.renewReadDeadline(); err != nil { + return n, err + } + var oneN int + oneN, err = c.Read(curBuf[:]) // WE WANT NO SHADOWING + curBuf = curBuf[oneN:] + n += int64(oneN) + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() && oneN > 0 { + continue inner + } + return n, err + } + } + } + return n, nil +} + +func (c Conn) readv(rawConn syscall.RawConn, iovecs []syscall.Iovec) (n int64, err error) { + for len(iovecs) > 0 { + if err := c.renewReadDeadline(); err != nil { + return n, err + } + oneN, oneErr := c.doOneReadv(rawConn, &iovecs) + n += oneN + if netErr, ok := oneErr.(net.Error); ok && netErr.Timeout() && oneN > 0 { // TODO likely not working + continue + } else if oneErr == nil && oneN > 0 { + continue + } else { + return n, oneErr + } + } + return n, nil +} + +func (c Conn) doOneReadv(rawConn syscall.RawConn, iovecs *[]syscall.Iovec) (n int64, err error) { + rawReadErr := rawConn.Read(func(fd uintptr) (done bool) { + // iovecs, n and err must not be shadowed! + thisReadN, _, errno := syscall.Syscall( + syscall.SYS_READV, + fd, + uintptr(unsafe.Pointer(&(*iovecs)[0])), + uintptr(len(*iovecs)), + ) + if thisReadN == ^uintptr(0) { + if errno == syscall.EAGAIN { + return false + } + err = syscall.Errno(errno) + return true + } + if int(thisReadN) < 0 { + panic("unexpected return value") + } + n += int64(thisReadN) + // shift iovecs forward + for left := int64(thisReadN); left > 0; { + curVecNewLength := int64((*iovecs)[0].Len) - left // TODO assert conversion + if curVecNewLength <= 0 { + left -= int64((*iovecs)[0].Len) + *iovecs = (*iovecs)[1:] + } else { + (*iovecs)[0].Base = (*byte)(unsafe.Pointer(uintptr(unsafe.Pointer((*iovecs)[0].Base)) + uintptr(left))) + (*iovecs)[0].Len = uint64(curVecNewLength) + break // inner + } + } + if thisReadN == 0 { + err = io.EOF + return true + } + return true + }) + + if rawReadErr != nil { + err = rawReadErr + } + + return n, err +} diff --git a/rpc/dataconn/timeoutconn/timeoutconn_debug.go b/rpc/dataconn/timeoutconn/timeoutconn_debug.go new file mode 100644 index 0000000..d5a7cb3 --- /dev/null +++ b/rpc/dataconn/timeoutconn/timeoutconn_debug.go @@ -0,0 +1,19 @@ +package timeoutconn + +import ( + "fmt" + "io" +) + +const debugReadvNoShortReadsAssertEnable = false + +func debugReadvNoShortReadsAssert(expectedLen, returnedLen int64, returnedErr error) { + readShort := expectedLen != returnedLen + if !readShort { + return + } + if returnedErr != io.EOF { + return + } + panic(fmt.Sprintf("ReadvFull short and error is not EOF%v\n", returnedErr)) +} diff --git a/rpc/dataconn/timeoutconn/timeoutconn_test.go b/rpc/dataconn/timeoutconn/timeoutconn_test.go new file mode 100644 index 0000000..708bba2 --- /dev/null +++ b/rpc/dataconn/timeoutconn/timeoutconn_test.go @@ -0,0 +1,177 @@ +package timeoutconn + +import ( + "bytes" + "io" + "net" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zrepl/zrepl/util/socketpair" +) + +func TestReadTimeout(t *testing.T) { + + a, b, err := socketpair.SocketPair() + require.NoError(t, err) + defer a.Close() + defer b.Close() + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + var buf bytes.Buffer + buf.WriteString("tooktoolong") + time.Sleep(500 * time.Millisecond) + _, err := io.Copy(a, &buf) + require.NoError(t, err) + }() + + go func() { + defer wg.Done() + conn := Wrap(b, 100*time.Millisecond) + buf := [4]byte{} // shorter than message put on wire + n, err := conn.Read(buf[:]) + assert.Equal(t, 0, n) + assert.Error(t, err) + netErr, ok := err.(net.Error) + require.True(t, ok) + assert.True(t, netErr.Timeout()) + }() + + wg.Wait() +} + +type writeBlockConn struct { + net.Conn + blockTime time.Duration +} + +func (c writeBlockConn) Write(p []byte) (int, error) { + time.Sleep(c.blockTime) + return c.Conn.Write(p) +} + +func (c writeBlockConn) CloseWrite() error { + return c.Conn.Close() +} + +func TestWriteTimeout(t *testing.T) { + a, b, err := socketpair.SocketPair() + require.NoError(t, err) + defer a.Close() + defer b.Close() + var buf bytes.Buffer + buf.WriteString("message") + blockConn := writeBlockConn{a, 500 * time.Millisecond} + conn := Wrap(blockConn, 100*time.Millisecond) + n, err := conn.Write(buf.Bytes()) + assert.Equal(t, 0, n) + assert.Error(t, err) + netErr, ok := err.(net.Error) + require.True(t, ok) + assert.True(t, netErr.Timeout()) +} + +func TestNoPartialReadsDueToDeadline(t *testing.T) { + a, b, err := socketpair.SocketPair() + require.NoError(t, err) + defer a.Close() + defer b.Close() + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + a.Write([]byte{1, 2, 3, 4, 5}) + // sleep to provoke a partial read in the consumer goroutine + time.Sleep(50 * time.Millisecond) + a.Write([]byte{6, 7, 8, 9, 10}) + }() + + go func() { + defer wg.Done() + bc := Wrap(b, 100*time.Millisecond) + var buf bytes.Buffer + beginRead := time.Now() + // io.Copy will encounter a partial read, then wait ~50ms until the other 5 bytes are written + // It is still going to fail with deadline err because it expects EOF + n, err := io.Copy(&buf, bc) + readDuration := time.Now().Sub(beginRead) + t.Logf("read duration=%s", readDuration) + t.Logf("recv done n=%v err=%v", n, err) + t.Logf("buf=%v", buf.Bytes()) + neterr, ok := err.(net.Error) + require.True(t, ok) + assert.True(t, neterr.Timeout()) + + assert.Equal(t, int64(10), n) + assert.Equal(t, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, buf.Bytes()) + // 50ms for the second read, 100ms after that one for the deadline + // allow for some jitter + assert.True(t, readDuration > 140*time.Millisecond) + assert.True(t, readDuration < 160*time.Millisecond) + }() + + wg.Wait() +} + +type partialWriteMockConn struct { + net.Conn // to satisfy interface + buf bytes.Buffer + writeDuration time.Duration + returnAfterBytesWritten int +} + +func newPartialWriteMockConn(writeDuration time.Duration, returnAfterBytesWritten int) *partialWriteMockConn { + return &partialWriteMockConn{ + writeDuration: writeDuration, + returnAfterBytesWritten: returnAfterBytesWritten, + } +} + +func (c *partialWriteMockConn) Write(p []byte) (int, error) { + time.Sleep(c.writeDuration) + consumeBytes := len(p) + if consumeBytes > c.returnAfterBytesWritten { + consumeBytes = c.returnAfterBytesWritten + } + n, err := c.buf.Write(p[0:consumeBytes]) + if err != nil || n != consumeBytes { + panic("bytes.Buffer behaves unexpectedly") + } + return n, nil +} + +func TestPartialWriteMockConn(t *testing.T) { + mc := newPartialWriteMockConn(100*time.Millisecond, 5) + buf := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} + begin := time.Now() + n, err := mc.Write(buf[:]) + duration := time.Now().Sub(begin) + assert.NoError(t, err) + assert.Equal(t, 5, n) + assert.True(t, duration > 100*time.Millisecond) + assert.True(t, duration < 150*time.Millisecond) +} + +func TestNoPartialWritesDueToDeadline(t *testing.T) { + a, b, err := socketpair.SocketPair() + require.NoError(t, err) + defer a.Close() + defer b.Close() + var buf bytes.Buffer + buf.WriteString("message") + blockConn := writeBlockConn{a, 150 * time.Millisecond} + conn := Wrap(blockConn, 100*time.Millisecond) + n, err := conn.Write(buf.Bytes()) + assert.Equal(t, 0, n) + assert.Error(t, err) + netErr, ok := err.(net.Error) + require.True(t, ok) + assert.True(t, netErr.Timeout()) +} diff --git a/rpc/grpcclientidentity/authlistener_grpc_adaptor.go b/rpc/grpcclientidentity/authlistener_grpc_adaptor.go new file mode 100644 index 0000000..accd124 --- /dev/null +++ b/rpc/grpcclientidentity/authlistener_grpc_adaptor.go @@ -0,0 +1,118 @@ +// Package grpcclientidentity makes the client identity +// provided by github.com/zrepl/zrepl/daemon/transport/serve.{AuthenticatedListener,AuthConn} +// available to gRPC service handlers. +// +// This goal is achieved through the combination of custom gRPC transport credentials and two interceptors +// (i.e. middleware). +// +// For gRPC clients, the TransportCredentials + Dialer can be used to construct a gRPC client (grpc.ClientConn) +// that uses a github.com/zrepl/zrepl/daemon/transport/connect.Connecter to connect to a server. +// +// The adaptors exposed by this package must be used together, and panic if they are not. +// See package grpchelper for a more restrictive but safe example on how the adaptors should be composed. +package grpcclientidentity + +import ( + "context" + "fmt" + "net" + "time" + + "github.com/zrepl/zrepl/logger" + "github.com/zrepl/zrepl/transport" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/peer" +) + +type Logger = logger.Logger + +type GRPCDialFunction = func(string, time.Duration) (net.Conn, error) + +func NewDialer(logger Logger, connecter transport.Connecter) GRPCDialFunction { + return func(s string, duration time.Duration) (conn net.Conn, e error) { + ctx, cancel := context.WithTimeout(context.Background(), duration) + defer cancel() + nc, err := connecter.Connect(ctx) + // TODO find better place (callback from gRPC?) where to log errors + // we want the users to know, though + if err != nil { + logger.WithError(err).Error("cannot connect") + } + return nc, err + } +} + +type authConnAuthType struct { + clientIdentity string +} + +func (authConnAuthType) AuthType() string { + return "AuthConn" +} + +type connecterAuthType struct{} + +func (connecterAuthType) AuthType() string { + return "connecter" +} + +type transportCredentials struct { + logger Logger +} + +// Use on both sides as ServerOption or ClientOption. +func NewTransportCredentials(log Logger) credentials.TransportCredentials { + if log == nil { + log = logger.NewNullLogger() + } + return &transportCredentials{log} +} + +func (c *transportCredentials) ClientHandshake(ctx context.Context, s string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + c.logger.WithField("url", s).WithField("connType", fmt.Sprintf("%T", rawConn)).Debug("ClientHandshake") + // do nothing, client credential is only for WithInsecure warning to go away + // the authentication is done by the connecter + return rawConn, &connecterAuthType{}, nil +} + +func (c *transportCredentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + c.logger.WithField("connType", fmt.Sprintf("%T", rawConn)).Debug("ServerHandshake") + authConn, ok := rawConn.(*transport.AuthConn) + if !ok { + panic(fmt.Sprintf("NewTransportCredentials must be used with a listener that returns *transport.AuthConn, got %T", rawConn)) + } + return rawConn, &authConnAuthType{authConn.ClientIdentity()}, nil +} + +func (*transportCredentials) Info() credentials.ProtocolInfo { + return credentials.ProtocolInfo{} // TODO +} + +func (t *transportCredentials) Clone() credentials.TransportCredentials { + var x = *t + return &x +} + +func (*transportCredentials) OverrideServerName(string) error { + panic("not implemented") +} + +func NewInterceptors(logger Logger, clientIdentityKey interface{}) (unary grpc.UnaryServerInterceptor, stream grpc.StreamServerInterceptor) { + unary = func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + logger.WithField("fullMethod", info.FullMethod).Debug("request") + p, ok := peer.FromContext(ctx) + if !ok { + panic("peer.FromContext expected to return a peer in grpc.UnaryServerInterceptor") + } + logger.WithField("peer", fmt.Sprintf("%v", p)).Debug("peer") + a, ok := p.AuthInfo.(*authConnAuthType) + if !ok { + panic(fmt.Sprintf("NewInterceptors must be used in combination with grpc.NewTransportCredentials, but got auth type %T", p.AuthInfo)) + } + ctx = context.WithValue(ctx, clientIdentityKey, a.clientIdentity) + return handler(ctx, req) + } + stream = nil + return +} diff --git a/rpc/grpcclientidentity/example/grpcauth.proto b/rpc/grpcclientidentity/example/grpcauth.proto new file mode 100644 index 0000000..5adf606 --- /dev/null +++ b/rpc/grpcclientidentity/example/grpcauth.proto @@ -0,0 +1,16 @@ +syntax = "proto3"; + +package pdu; + + +service Greeter { + rpc Greet(GreetRequest) returns (GreetResponse) {} +} + +message GreetRequest { + string name = 1; +} + +message GreetResponse { + string msg = 1; +} \ No newline at end of file diff --git a/rpc/grpcclientidentity/example/main.go b/rpc/grpcclientidentity/example/main.go new file mode 100644 index 0000000..3813683 --- /dev/null +++ b/rpc/grpcclientidentity/example/main.go @@ -0,0 +1,107 @@ +// This package demonstrates how the grpcclientidentity package can be used +// to set up a gRPC greeter service. +package main + +import ( + "context" + "flag" + "fmt" + "os" + "time" + + "github.com/zrepl/zrepl/config" + "github.com/zrepl/zrepl/logger" + "github.com/zrepl/zrepl/rpc/grpcclientidentity/example/pdu" + "github.com/zrepl/zrepl/rpc/grpcclientidentity/grpchelper" + "github.com/zrepl/zrepl/transport/tcp" +) + +var args struct { + mode string +} + +var log = logger.NewStderrDebugLogger() + +func main() { + flag.StringVar(&args.mode, "mode", "", "client|server") + flag.Parse() + + switch args.mode { + case "client": + client() + case "server": + server() + default: + log.Printf("unknown mode %q") + os.Exit(1) + } +} + +func onErr(err error, format string, args ...interface{}) { + log.WithError(err).Error(fmt.Sprintf("%s: %s", fmt.Sprintf(format, args...), err)) + os.Exit(1) +} + +func client() { + cn, err := tcp.TCPConnecterFromConfig(&config.TCPConnect{ + ConnectCommon: config.ConnectCommon{ + Type: "tcp", + }, + Address: "127.0.0.1:8080", + DialTimeout: 10 * time.Second, + }) + if err != nil { + onErr(err, "build connecter error") + } + + clientConn := grpchelper.ClientConn(cn, log) + defer clientConn.Close() + + // normal usage from here on + + client := pdu.NewGreeterClient(clientConn) + resp, err := client.Greet(context.Background(), &pdu.GreetRequest{Name: "somethingimadeup"}) + if err != nil { + onErr(err, "RPC error") + } + + fmt.Printf("got response:\n\t%s\n", resp.GetMsg()) +} + +const clientIdentityKey = "clientIdentity" + +func server() { + authListenerFactory, err := tcp.TCPListenerFactoryFromConfig(nil, &config.TCPServe{ + ServeCommon: config.ServeCommon{ + Type: "tcp", + }, + Listen: "127.0.0.1:8080", + Clients: map[string]string{ + "127.0.0.1": "localclient", + "::1": "localclient", + }, + }) + if err != nil { + onErr(err, "cannot build listener factory") + } + + log := logger.NewStderrDebugLogger() + + srv, serve, err := grpchelper.NewServer(authListenerFactory, clientIdentityKey, log) + svc := &greeter{"hello "} + pdu.RegisterGreeterServer(srv, svc) + + if err := serve(); err != nil { + onErr(err, "error serving") + } +} + +type greeter struct { + prepend string +} + +func (g *greeter) Greet(ctx context.Context, r *pdu.GreetRequest) (*pdu.GreetResponse, error) { + ci, _ := ctx.Value(clientIdentityKey).(string) + log.WithField("clientIdentity", ci).Info("Greet() request") // show that we got the client identity + return &pdu.GreetResponse{Msg: fmt.Sprintf("%s%s (clientIdentity=%q)", g.prepend, r.GetName(), ci)}, nil +} diff --git a/rpc/grpcclientidentity/example/pdu/grpcauth.pb.go b/rpc/grpcclientidentity/example/pdu/grpcauth.pb.go new file mode 100644 index 0000000..0fbf2cb --- /dev/null +++ b/rpc/grpcclientidentity/example/pdu/grpcauth.pb.go @@ -0,0 +1,193 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: grpcauth.proto + +package pdu + +import ( + fmt "fmt" + proto "github.com/golang/protobuf/proto" + context "golang.org/x/net/context" + grpc "google.golang.org/grpc" + math "math" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +type GreetRequest struct { + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *GreetRequest) Reset() { *m = GreetRequest{} } +func (m *GreetRequest) String() string { return proto.CompactTextString(m) } +func (*GreetRequest) ProtoMessage() {} +func (*GreetRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_1dfba7be0cf69353, []int{0} +} + +func (m *GreetRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_GreetRequest.Unmarshal(m, b) +} +func (m *GreetRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_GreetRequest.Marshal(b, m, deterministic) +} +func (m *GreetRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_GreetRequest.Merge(m, src) +} +func (m *GreetRequest) XXX_Size() int { + return xxx_messageInfo_GreetRequest.Size(m) +} +func (m *GreetRequest) XXX_DiscardUnknown() { + xxx_messageInfo_GreetRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_GreetRequest proto.InternalMessageInfo + +func (m *GreetRequest) GetName() string { + if m != nil { + return m.Name + } + return "" +} + +type GreetResponse struct { + Msg string `protobuf:"bytes,1,opt,name=msg,proto3" json:"msg,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *GreetResponse) Reset() { *m = GreetResponse{} } +func (m *GreetResponse) String() string { return proto.CompactTextString(m) } +func (*GreetResponse) ProtoMessage() {} +func (*GreetResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_1dfba7be0cf69353, []int{1} +} + +func (m *GreetResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_GreetResponse.Unmarshal(m, b) +} +func (m *GreetResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_GreetResponse.Marshal(b, m, deterministic) +} +func (m *GreetResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_GreetResponse.Merge(m, src) +} +func (m *GreetResponse) XXX_Size() int { + return xxx_messageInfo_GreetResponse.Size(m) +} +func (m *GreetResponse) XXX_DiscardUnknown() { + xxx_messageInfo_GreetResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_GreetResponse proto.InternalMessageInfo + +func (m *GreetResponse) GetMsg() string { + if m != nil { + return m.Msg + } + return "" +} + +func init() { + proto.RegisterType((*GreetRequest)(nil), "pdu.GreetRequest") + proto.RegisterType((*GreetResponse)(nil), "pdu.GreetResponse") +} + +func init() { proto.RegisterFile("grpcauth.proto", fileDescriptor_1dfba7be0cf69353) } + +var fileDescriptor_1dfba7be0cf69353 = []byte{ + // 137 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x4b, 0x2f, 0x2a, 0x48, + 0x4e, 0x2c, 0x2d, 0xc9, 0xd0, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2e, 0x48, 0x29, 0x55, + 0x52, 0xe2, 0xe2, 0x71, 0x2f, 0x4a, 0x4d, 0x2d, 0x09, 0x4a, 0x2d, 0x2c, 0x4d, 0x2d, 0x2e, 0x11, + 0x12, 0xe2, 0x62, 0xc9, 0x4b, 0xcc, 0x4d, 0x95, 0x60, 0x54, 0x60, 0xd4, 0xe0, 0x0c, 0x02, 0xb3, + 0x95, 0x14, 0xb9, 0x78, 0xa1, 0x6a, 0x8a, 0x0b, 0xf2, 0xf3, 0x8a, 0x53, 0x85, 0x04, 0xb8, 0x98, + 0x73, 0x8b, 0xd3, 0xa1, 0x6a, 0x40, 0x4c, 0x23, 0x6b, 0x2e, 0x76, 0xb0, 0x92, 0xd4, 0x22, 0x21, + 0x03, 0x2e, 0x56, 0x30, 0x53, 0x48, 0x50, 0xaf, 0x20, 0xa5, 0x54, 0x0f, 0xd9, 0x74, 0x29, 0x21, + 0x64, 0x21, 0x88, 0x61, 0x4a, 0x0c, 0x49, 0x6c, 0x60, 0xf7, 0x18, 0x03, 0x02, 0x00, 0x00, 0xff, + 0xff, 0xa8, 0x53, 0x2f, 0x4c, 0xa1, 0x00, 0x00, 0x00, +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion4 + +// GreeterClient is the client API for Greeter service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream. +type GreeterClient interface { + Greet(ctx context.Context, in *GreetRequest, opts ...grpc.CallOption) (*GreetResponse, error) +} + +type greeterClient struct { + cc *grpc.ClientConn +} + +func NewGreeterClient(cc *grpc.ClientConn) GreeterClient { + return &greeterClient{cc} +} + +func (c *greeterClient) Greet(ctx context.Context, in *GreetRequest, opts ...grpc.CallOption) (*GreetResponse, error) { + out := new(GreetResponse) + err := c.cc.Invoke(ctx, "/pdu.Greeter/Greet", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// GreeterServer is the server API for Greeter service. +type GreeterServer interface { + Greet(context.Context, *GreetRequest) (*GreetResponse, error) +} + +func RegisterGreeterServer(s *grpc.Server, srv GreeterServer) { + s.RegisterService(&_Greeter_serviceDesc, srv) +} + +func _Greeter_Greet_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GreetRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(GreeterServer).Greet(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/pdu.Greeter/Greet", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(GreeterServer).Greet(ctx, req.(*GreetRequest)) + } + return interceptor(ctx, in, info, handler) +} + +var _Greeter_serviceDesc = grpc.ServiceDesc{ + ServiceName: "pdu.Greeter", + HandlerType: (*GreeterServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Greet", + Handler: _Greeter_Greet_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "grpcauth.proto", +} diff --git a/rpc/grpcclientidentity/grpchelper/authlistener_grpc_adaptor_wrapper.go b/rpc/grpcclientidentity/grpchelper/authlistener_grpc_adaptor_wrapper.go new file mode 100644 index 0000000..7e46681 --- /dev/null +++ b/rpc/grpcclientidentity/grpchelper/authlistener_grpc_adaptor_wrapper.go @@ -0,0 +1,76 @@ +// Package grpchelper wraps the adaptors implemented by package grpcclientidentity into a less flexible API +// which, however, ensures that the individual adaptor primitive's expectations are met and hence do not panic. +package grpchelper + +import ( + "context" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/keepalive" + + "github.com/zrepl/zrepl/logger" + "github.com/zrepl/zrepl/rpc/grpcclientidentity" + "github.com/zrepl/zrepl/rpc/netadaptor" + "github.com/zrepl/zrepl/transport" +) + +// The following constants are relevant for interoperability. +// We use the same values for client & server, because zrepl is more +// symmetrical ("one source, one sink") instead of the typical +// gRPC scenario ("many clients, single server") +const ( + StartKeepalivesAfterInactivityDuration = 5 * time.Second + KeepalivePeerTimeout = 10 * time.Second +) + +type Logger = logger.Logger + +// ClientConn is an easy-to-use wrapper around the Dialer and TransportCredentials interface +// to produce a grpc.ClientConn +func ClientConn(cn transport.Connecter, log Logger) *grpc.ClientConn { + ka := grpc.WithKeepaliveParams(keepalive.ClientParameters{ + Time: StartKeepalivesAfterInactivityDuration, + Timeout: KeepalivePeerTimeout, + PermitWithoutStream: true, + }) + dialerOption := grpc.WithDialer(grpcclientidentity.NewDialer(log, cn)) + cred := grpc.WithTransportCredentials(grpcclientidentity.NewTransportCredentials(log)) + ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second) + defer cancel() + cc, err := grpc.DialContext(ctx, "doesn't matter done by dialer", dialerOption, cred, ka) + if err != nil { + log.WithError(err).Error("cannot create gRPC client conn (non-blocking)") + // It's ok to panic here: the we call grpc.DialContext without the + // (grpc.WithBlock) dial option, and at the time of writing, the grpc + // docs state that no connection attempt is made in that case. + // Hence, any error that occurs is due to DialOptions or similar, + // and thus indicative of an implementation error. + panic(err) + } + return cc +} + +// NewServer is a convenience interface around the TransportCredentials and Interceptors interface. +func NewServer(authListenerFactory transport.AuthenticatedListenerFactory, clientIdentityKey interface{}, logger grpcclientidentity.Logger) (srv *grpc.Server, serve func() error, err error) { + ka := grpc.KeepaliveParams(keepalive.ServerParameters{ + Time: StartKeepalivesAfterInactivityDuration, + Timeout: KeepalivePeerTimeout, + }) + tcs := grpcclientidentity.NewTransportCredentials(logger) + unary, stream := grpcclientidentity.NewInterceptors(logger, clientIdentityKey) + srv = grpc.NewServer(grpc.Creds(tcs), grpc.UnaryInterceptor(unary), grpc.StreamInterceptor(stream), ka) + + serve = func() error { + l, err := authListenerFactory() + if err != nil { + return err + } + if err := srv.Serve(netadaptor.New(l, logger)); err != nil { + return err + } + return nil + } + + return srv, serve, nil +} diff --git a/rpc/netadaptor/authlistener_netlistener_adaptor.go b/rpc/netadaptor/authlistener_netlistener_adaptor.go new file mode 100644 index 0000000..cc9ae70 --- /dev/null +++ b/rpc/netadaptor/authlistener_netlistener_adaptor.go @@ -0,0 +1,102 @@ +// Package netadaptor implements an adaptor from +// transport.AuthenticatedListener to net.Listener. +// +// In contrast to transport.AuthenticatedListener, +// net.Listener is commonly expected (e.g. by net/http.Server.Serve), +// to return errors that fulfill the Temporary interface: +// interface Temporary { Temporary() bool } +// Common behavior of packages consuming net.Listener is to return +// from the serve-loop if an error returned by Accept is not Temporary, +// i.e., does not implement the interface or is !net.Error.Temporary(). +// +// The zrepl transport infrastructure was written with the +// idea that Accept() may return any kind of error, and the consumer +// would just log the error and continue calling Accept(). +// We have to adapt these listeners' behavior to the expectations +// of net/http.Server. +// +// Hence, Listener does not return an error at all but blocks the +// caller of Accept() until we get a (successfully authenticated) +// connection without errors from the transport. +// Accept errors returned from the transport are logged as errors +// to the logger passed on initialization. +package netadaptor + +import ( + "context" + "fmt" + "github.com/zrepl/zrepl/logger" + "net" + "github.com/zrepl/zrepl/transport" +) + +type Logger = logger.Logger + +type acceptReq struct { + callback chan net.Conn +} + +type Listener struct { + al transport.AuthenticatedListener + logger Logger + accepts chan acceptReq + stop chan struct{} +} + +// Consume the given authListener and wrap it as a *Listener, which implements net.Listener. +// The returned net.Listener must be closed. +// The wrapped authListener is closed when the returned net.Listener is closed. +func New(authListener transport.AuthenticatedListener, l Logger) *Listener { + if l == nil { + l = logger.NewNullLogger() + } + a := &Listener{ + al: authListener, + logger: l, + accepts: make(chan acceptReq), + stop: make(chan struct{}), + } + go a.handleAccept() + return a +} + +// The returned net.Conn is guaranteed to be *transport.AuthConn, i.e., the type of connection +// returned by the wrapped transport.AuthenticatedListener. +func (a Listener) Accept() (net.Conn, error) { + req := acceptReq{make(chan net.Conn, 1)} + a.accepts <- req + conn := <-req.callback + return conn, nil +} + +func (a Listener) handleAccept() { + for { + select { + case <-a.stop: + a.logger.Debug("handleAccept stop accepting") + return + case req := <-a.accepts: + for { + a.logger.Debug("accept authListener") + authConn, err := a.al.Accept(context.Background()) + if err != nil { + a.logger.WithError(err).Error("accept error") + continue + } + a.logger.WithField("type", fmt.Sprintf("%T", authConn)). + Debug("accept complete") + req.callback <- authConn + break + } + } + } +} + +func (a Listener) Addr() net.Addr { + return a.al.Addr() +} + +func (a Listener) Close() error { + close(a.stop) + return a.al.Close() +} diff --git a/rpc/rpc_client.go b/rpc/rpc_client.go new file mode 100644 index 0000000..efcaf98 --- /dev/null +++ b/rpc/rpc_client.go @@ -0,0 +1,106 @@ +package rpc + +import ( + "context" + "net" + "time" + + "google.golang.org/grpc" + + "github.com/zrepl/zrepl/replication" + "github.com/zrepl/zrepl/replication/pdu" + "github.com/zrepl/zrepl/rpc/dataconn" + "github.com/zrepl/zrepl/rpc/grpcclientidentity/grpchelper" + "github.com/zrepl/zrepl/rpc/versionhandshake" + "github.com/zrepl/zrepl/transport" + "github.com/zrepl/zrepl/util/envconst" + "github.com/zrepl/zrepl/zfs" +) + +// Client implements the active side of a replication setup. +// It satisfies the Endpoint, Sender and Receiver interface defined by package replication. +type Client struct { + dataClient *dataconn.Client + controlClient pdu.ReplicationClient // this the grpc client instance, see constructor + controlConn *grpc.ClientConn + loggers Loggers +} + +var _ replication.Endpoint = &Client{} +var _ replication.Sender = &Client{} +var _ replication.Receiver = &Client{} + +type DialContextFunc = func(ctx context.Context, network string, addr string) (net.Conn, error) + +// config must be validated, NewClient will panic if it is not valid +func NewClient(cn transport.Connecter, loggers Loggers) *Client { + + cn = versionhandshake.Connecter(cn, envconst.Duration("ZREPL_RPC_CLIENT_VERSIONHANDSHAKE_TIMEOUT", 10*time.Second)) + + muxedConnecter := mux(cn) + + c := &Client{ + loggers: loggers, + } + grpcConn := grpchelper.ClientConn(muxedConnecter.control, loggers.Control) + + go func() { + for { + state := grpcConn.GetState() + loggers.General.WithField("grpc_state", state.String()).Debug("grpc state change") + grpcConn.WaitForStateChange(context.TODO(), state) + } + }() + c.controlClient = pdu.NewReplicationClient(grpcConn) + c.controlConn = grpcConn + + c.dataClient = dataconn.NewClient(muxedConnecter.data, loggers.Data) + return c +} + +func (c *Client) Close() { + if err := c.controlConn.Close(); err != nil { + c.loggers.General.WithError(err).Error("cannot cloe control connection") + } + // TODO c.dataClient should have Close() +} + +// callers must ensure that the returned io.ReadCloser is closed +// TODO expose dataClient interface to the outside world +func (c *Client) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) { + // TODO the returned sendStream may return a read error created by the remote side + res, streamCopier, err := c.dataClient.ReqSend(ctx, r) + if err != nil { + return nil, nil, nil + } + if streamCopier == nil { + return res, nil, nil + } + + return res, streamCopier, nil + +} + +func (c *Client) Receive(ctx context.Context, req *pdu.ReceiveReq, streamCopier zfs.StreamCopier) (*pdu.ReceiveRes, error) { + return c.dataClient.ReqRecv(ctx, req, streamCopier) +} + +func (c *Client) ListFilesystems(ctx context.Context, in *pdu.ListFilesystemReq) (*pdu.ListFilesystemRes, error) { + return c.controlClient.ListFilesystems(ctx, in) +} + +func (c *Client) ListFilesystemVersions(ctx context.Context, in *pdu.ListFilesystemVersionsReq) (*pdu.ListFilesystemVersionsRes, error) { + return c.controlClient.ListFilesystemVersions(ctx, in) +} + +func (c *Client) DestroySnapshots(ctx context.Context, in *pdu.DestroySnapshotsReq) (*pdu.DestroySnapshotsRes, error) { + return c.controlClient.DestroySnapshots(ctx, in) +} + +func (c *Client) ReplicationCursor(ctx context.Context, in *pdu.ReplicationCursorReq) (*pdu.ReplicationCursorRes, error) { + return c.controlClient.ReplicationCursor(ctx, in) +} + +func (c *Client) ResetConnectBackoff() { + c.controlConn.ResetConnectBackoff() +} diff --git a/rpc/rpc_debug.go b/rpc/rpc_debug.go new file mode 100644 index 0000000..a31e1f2 --- /dev/null +++ b/rpc/rpc_debug.go @@ -0,0 +1,20 @@ +package rpc + +import ( + "fmt" + "os" +) + +var debugEnabled bool = false + +func init() { + if os.Getenv("ZREPL_RPC_DEBUG") != "" { + debugEnabled = true + } +} + +func debug(format string, args ...interface{}) { + if debugEnabled { + fmt.Fprintf(os.Stderr, "rpc: %s\n", fmt.Sprintf(format, args...)) + } +} diff --git a/rpc/rpc_doc.go b/rpc/rpc_doc.go new file mode 100644 index 0000000..dbe7dda --- /dev/null +++ b/rpc/rpc_doc.go @@ -0,0 +1,118 @@ +// Package rpc implements zrepl daemon-to-daemon RPC protocol +// on top of a transport provided by package transport. +// The zrepl documentation refers to the client as the +// `active side` and to the server as the `passive side`. +// +// Design Considerations +// +// zrepl has non-standard requirements to remote procedure calls (RPC): +// whereas the coordination of replication (the planning phase) mostly +// consists of regular unary calls, the actual data transfer requires +// a high-throughput, low-overhead solution. +// +// Specifically, the requirements for data transfer is to perform +// a copy of an io.Reader over the wire, such that an io.EOF of the original +// reader corresponds to an io.EOF on the receiving side. +// If any other error occurs while reading from the original io.Reader +// on the sender, the receiver should receive the contents of that error +// in some form (usually as a trailer message) +// A common implementation technique for above data transfer is chunking, +// for example in HTTP: +// https://tools.ietf.org/html/rfc2616#section-3.6.1 +// +// With regards to regular RPCs, gRPC is a popular implementation based +// on protocol buffers and thus code generation. +// gRPC also supports bi-directional streaming RPC, and it is possible +// to implement chunked transfer through the use of streams. +// +// For zrepl however, both HTTP and manually implemented chunked transfer +// using gRPC were found to have significant CPU overhead at transfer +// speeds to be expected even with hobbyist users. +// +// However, it is nice to piggyback on the code generation provided +// by protobuf / gRPC, in particular since the majority of call types +// are regular unary RPCs for which the higher overhead of gRPC is acceptable. +// +// Hence, this package attempts to combine the best of both worlds: +// +// GRPC for Coordination and Dataconn for Bulk Data Transfer +// +// This package's Client uses its transport.Connecter to maintain +// separate control and data connections to the Server. +// The control connection is used by an instance of pdu.ReplicationClient +// whose 'regular' unary RPC calls are re-exported. +// The data connection is used by an instance of dataconn.Client and +// is used for bulk data transfer, namely `Send` and `Receive`. +// +// The following ASCII diagram gives an overview of how the individual +// building blocks are glued together: +// +// +------------+ +// | rpc.Client | +// +------------+ +// | | +// +--------+ +------------+ +// | | +// +---------v-----------+ +--------v------+ +// |pdu.ReplicationClient| |dataconn.Client| +// +---------------------+ +--------v------+ +// | label: label: | +// | zrepl_control zrepl_data | +// +--------+ +------------+ +// | | +// +--v---------v---+ +// | transportmux | +// +-------+--------+ +// | uses +// +-------v--------+ +// |versionhandshake| +// +-------+--------+ +// | uses +// +------v------+ +// | transport | +// +------+------+ +// | +// NETWORK +// | +// +------+------+ +// | transport | +// +------^------+ +// | uses +// +-------+--------+ +// |versionhandshake| +// +-------^--------+ +// | uses +// +-------+--------+ +// | transportmux | +// +--^--------^----+ +// | | +// +--------+ --------------+ --- +// | | | +// | label: label: | | +// | zrepl_control zrepl_data | | +// +-----+----+ +-----------+---+ | +// |netadaptor| |dataconn.Server| | rpc.Server +// | + | +------+--------+ | +// |grpcclient| | | +// |identity | | | +// +-----+----+ | | +// | | | +// +---------v-----------+ | | +// |pdu.ReplicationServer| | | +// +---------+-----------+ | | +// | | --- +// +----------+ +------------+ +// | | +// +---v--v-----+ +// | Handler | +// +------------+ +// (usually endpoint.{Sender,Receiver}) +// +// +package rpc + +// edit trick for the ASCII art above: +// - remove the comments // +// - vim: set virtualedit+=all +// - vim: set ft=text + diff --git a/rpc/rpc_logging.go b/rpc/rpc_logging.go new file mode 100644 index 0000000..e842f6e --- /dev/null +++ b/rpc/rpc_logging.go @@ -0,0 +1,34 @@ +package rpc + +import ( + "context" + + "github.com/zrepl/zrepl/logger" +) + +type Logger = logger.Logger + +type contextKey int + +const ( + contextKeyLoggers contextKey = iota + contextKeyGeneralLogger + contextKeyControlLogger + contextKeyDataLogger +) + +/// All fields must be non-nil +type Loggers struct { + General Logger + Control Logger + Data Logger +} + +func WithLoggers(ctx context.Context, loggers Loggers) context.Context { + ctx = context.WithValue(ctx, contextKeyLoggers, loggers) + return ctx +} + +func GetLoggersOrPanic(ctx context.Context) Loggers { + return ctx.Value(contextKeyLoggers).(Loggers) +} diff --git a/rpc/rpc_mux.go b/rpc/rpc_mux.go new file mode 100644 index 0000000..f70dde7 --- /dev/null +++ b/rpc/rpc_mux.go @@ -0,0 +1,57 @@ +package rpc + +import ( + "context" + "time" + + "github.com/zrepl/zrepl/transport" + "github.com/zrepl/zrepl/rpc/transportmux" + "github.com/zrepl/zrepl/util/envconst" +) + +type demuxedListener struct { + control, data transport.AuthenticatedListener +} + +const ( + transportmuxLabelControl string = "zrepl_control" + transportmuxLabelData string = "zrepl_data" +) + +func demux(serveCtx context.Context, listener transport.AuthenticatedListener) demuxedListener { + listeners, err := transportmux.Demux( + serveCtx, listener, + []string{transportmuxLabelControl, transportmuxLabelData}, + envconst.Duration("ZREPL_TRANSPORT_DEMUX_TIMEOUT", 10*time.Second), + ) + if err != nil { + // transportmux API guarantees that the returned error can only be due + // to invalid API usage (i.e. labels too long) + panic(err) + } + return demuxedListener{ + control: listeners[transportmuxLabelControl], + data: listeners[transportmuxLabelData], + } +} + +type muxedConnecter struct { + control, data transport.Connecter +} + +func mux(rawConnecter transport.Connecter) muxedConnecter { + muxedConnecters, err := transportmux.MuxConnecter( + rawConnecter, + []string{transportmuxLabelControl, transportmuxLabelData}, + envconst.Duration("ZREPL_TRANSPORT_MUX_TIMEOUT", 10*time.Second), + ) + if err != nil { + // transportmux API guarantees that the returned error can only be due + // to invalid API usage (i.e. labels too long) + panic(err) + } + return muxedConnecter{ + control: muxedConnecters[transportmuxLabelControl], + data: muxedConnecters[transportmuxLabelData], + } +} diff --git a/rpc/rpc_server.go b/rpc/rpc_server.go new file mode 100644 index 0000000..3abbc18 --- /dev/null +++ b/rpc/rpc_server.go @@ -0,0 +1,119 @@ +package rpc + +import ( + "context" + "time" + + "google.golang.org/grpc" + + "github.com/zrepl/zrepl/endpoint" + "github.com/zrepl/zrepl/replication/pdu" + "github.com/zrepl/zrepl/rpc/dataconn" + "github.com/zrepl/zrepl/rpc/grpcclientidentity" + "github.com/zrepl/zrepl/rpc/netadaptor" + "github.com/zrepl/zrepl/rpc/versionhandshake" + "github.com/zrepl/zrepl/transport" + "github.com/zrepl/zrepl/util/envconst" +) + +type Handler interface { + pdu.ReplicationServer + dataconn.Handler +} + +type serveFunc func(ctx context.Context, demuxedListener transport.AuthenticatedListener, errOut chan<- error) + +// Server abstracts the accept and request routing infrastructure for the +// passive side of a replication setup. +type Server struct { + logger Logger + handler Handler + controlServer *grpc.Server + controlServerServe serveFunc + dataServer *dataconn.Server + dataServerServe serveFunc +} + +type serverContextKey int + +type HandlerContextInterceptor func(ctx context.Context) context.Context + +// config must be valid (use its Validate function). +func NewServer(handler Handler, loggers Loggers, ctxInterceptor HandlerContextInterceptor) *Server { + + // setup control server + tcs := grpcclientidentity.NewTransportCredentials(loggers.Control) // TODO different subsystem for log + unary, stream := grpcclientidentity.NewInterceptors(loggers.Control, endpoint.ClientIdentityKey) + controlServer := grpc.NewServer(grpc.Creds(tcs), grpc.UnaryInterceptor(unary), grpc.StreamInterceptor(stream)) + pdu.RegisterReplicationServer(controlServer, handler) + controlServerServe := func(ctx context.Context, controlListener transport.AuthenticatedListener, errOut chan<- error) { + // give time for graceful stop until deadline expires, then hard stop + go func() { + <-ctx.Done() + if dl, ok := ctx.Deadline(); ok { + go time.AfterFunc(dl.Sub(dl), controlServer.Stop) + } + loggers.Control.Debug("shutting down control server") + controlServer.GracefulStop() + }() + + errOut <- controlServer.Serve(netadaptor.New(controlListener, loggers.Control)) + } + + // setup data server + dataServerClientIdentitySetter := func(ctx context.Context, wire *transport.AuthConn) (context.Context, *transport.AuthConn) { + ci := wire.ClientIdentity() + ctx = context.WithValue(ctx, endpoint.ClientIdentityKey, ci) + if ctxInterceptor != nil { + ctx = ctxInterceptor(ctx) // SHADOWING + } + return ctx, wire + } + dataServer := dataconn.NewServer(dataServerClientIdentitySetter, loggers.Data, handler) + dataServerServe := func(ctx context.Context, dataListener transport.AuthenticatedListener, errOut chan<- error) { + dataServer.Serve(ctx, dataListener) + errOut <- nil // TODO bad design of dataServer? + } + + server := &Server{ + logger: loggers.General, + handler: handler, + controlServer: controlServer, + controlServerServe: controlServerServe, + dataServer: dataServer, + dataServerServe: dataServerServe, + } + + return server +} + +// The context is used for cancellation only. +// Serve never returns an error, it logs them to the Server's logger. +func (s *Server) Serve(ctx context.Context, l transport.AuthenticatedListener) { + ctx, cancel := context.WithCancel(ctx) + + l = versionhandshake.Listener(l, envconst.Duration("ZREPL_RPC_SERVER_VERSIONHANDSHAKE_TIMEOUT", 10*time.Second)) + + // it is important that demux's context is cancelled, + // it has background goroutines attached + demuxListener := demux(ctx, l) + + serveErrors := make(chan error, 2) + go s.controlServerServe(ctx, demuxListener.control, serveErrors) + go s.dataServerServe(ctx, demuxListener.data, serveErrors) + select { + case serveErr := <-serveErrors: + s.logger.WithError(serveErr).Error("serve error") + s.logger.Debug("wait for other server to shut down") + cancel() + secondServeErr := <-serveErrors + s.logger.WithError(secondServeErr).Error("serve error") + case <-ctx.Done(): + s.logger.Debug("context cancelled, wait for control and data servers") + cancel() + for i := 0; i < 2; i++ { + <-serveErrors + } + s.logger.Debug("control and data server shut down, returning from Serve") + } +} diff --git a/rpc/transportmux/transportmux.go b/rpc/transportmux/transportmux.go new file mode 100644 index 0000000..cb6f7ca --- /dev/null +++ b/rpc/transportmux/transportmux.go @@ -0,0 +1,205 @@ +// Package transportmux wraps a transport.{Connecter,AuthenticatedListener} +// to distinguish different connection types based on a label +// sent from client to server on connection establishment. +// +// Labels are plain text and fixed length. +package transportmux + +import ( + "context" + "io" + "net" + "time" + "fmt" + + "github.com/zrepl/zrepl/logger" + "github.com/zrepl/zrepl/transport" +) + +type contextKey int + +const ( + contextKeyLog contextKey = 1 + iota +) + +type Logger = logger.Logger + +func WithLogger(ctx context.Context, log Logger) context.Context { + return context.WithValue(ctx, contextKeyLog, log) +} + +func getLog(ctx context.Context) Logger { + if l, ok := ctx.Value(contextKeyLog).(Logger); ok { + return l + } + return logger.NewNullLogger() +} + +type acceptRes struct { + conn *transport.AuthConn + err error +} + +type demuxListener struct { + conns chan acceptRes +} + +func (l *demuxListener) Accept(ctx context.Context) (*transport.AuthConn, error) { + res := <-l.conns + return res.conn, res.err +} + +type demuxAddr struct {} + +func (demuxAddr) Network() string { return "demux" } +func (demuxAddr) String() string { return "demux" } + +func (l *demuxListener) Addr() net.Addr { + return demuxAddr{} +} + +func (l *demuxListener) Close() error { return nil } // TODO + +// Exact length of a label in bytes (0-byte padded if it is shorter). +// This is a protocol constant, changing it breaks the wire protocol. +const LabelLen = 64 + +func padLabel(out []byte, label string) (error) { + if len(label) > LabelLen { + return fmt.Errorf("label %q exceeds max length (is %d, max %d)", label, len(label), LabelLen) + } + if len(out) != LabelLen { + panic(fmt.Sprintf("implementation error: %d", out)) + } + labelBytes := []byte(label) + copy(out[:], labelBytes) + return nil +} + +func Demux(ctx context.Context, rawListener transport.AuthenticatedListener, labels []string, timeout time.Duration) (map[string]transport.AuthenticatedListener, error) { + + padded := make(map[[64]byte]*demuxListener, len(labels)) + ret := make(map[string]transport.AuthenticatedListener, len(labels)) + for _, label := range labels { + var labelPadded [LabelLen]byte + err := padLabel(labelPadded[:], label) + if err != nil { + return nil, err + } + if _, ok := padded[labelPadded]; ok { + return nil, fmt.Errorf("duplicate label %q", label) + } + dl := &demuxListener{make(chan acceptRes)} + padded[labelPadded] = dl + ret[label] = dl + } + + // invariant: padded contains same-length, non-duplicate labels + + go func() { + <-ctx.Done() + getLog(ctx).Debug("context cancelled, closing listener") + if err := rawListener.Close(); err != nil { + getLog(ctx).WithError(err).Error("error closing listener") + } + }() + + go func() { + for { + rawConn, err := rawListener.Accept(ctx) + if err != nil { + if ctx.Err() != nil { + return + } + getLog(ctx).WithError(err).Error("accept error") + continue + } + closeConn := func() { + if err := rawConn.Close(); err != nil { + getLog(ctx).WithError(err).Error("cannot close conn") + } + } + + if err := rawConn.SetDeadline(time.Now().Add(timeout)); err != nil { + getLog(ctx).WithError(err).Error("SetDeadline failed") + closeConn() + continue + } + + var labelBuf [LabelLen]byte + if _, err := io.ReadFull(rawConn, labelBuf[:]); err != nil { + getLog(ctx).WithError(err).Error("error reading label") + closeConn() + continue + } + + demuxListener, ok := padded[labelBuf] + if !ok { + getLog(ctx).WithError(err). + WithField("client_label", fmt.Sprintf("%q", labelBuf)). + Error("unknown client label") + closeConn() + continue + } + + rawConn.SetDeadline(time.Time{}) + // blocking is intentional + demuxListener.conns <- acceptRes{conn: rawConn, err: nil} + } + }() + + return ret, nil +} + +type labeledConnecter struct { + label []byte + transport.Connecter +} + +func (c labeledConnecter) Connect(ctx context.Context) (transport.Wire, error) { + conn, err := c.Connecter.Connect(ctx) + if err != nil { + return nil, err + } + closeConn := func(why error) { + getLog(ctx).WithField("reason", why.Error()).Debug("closing connection") + if err := conn.Close(); err != nil { + getLog(ctx).WithError(err).Error("error closing connection after label write error") + } + } + + if dl, ok := ctx.Deadline(); ok { + defer conn.SetDeadline(time.Time{}) + if err := conn.SetDeadline(dl); err != nil { + closeConn(err) + return nil, err + } + } + n, err := conn.Write(c.label) + if err != nil { + closeConn(err) + return nil, err + } + if n != len(c.label) { + closeConn(fmt.Errorf("short label write")) + return nil, io.ErrShortWrite + } + return conn, nil +} + +func MuxConnecter(rawConnecter transport.Connecter, labels []string, timeout time.Duration) (map[string]transport.Connecter, error) { + ret := make(map[string]transport.Connecter, len(labels)) + for _, label := range labels { + var paddedLabel [LabelLen]byte + if err := padLabel(paddedLabel[:], label); err != nil { + return nil, err + } + lc := &labeledConnecter{paddedLabel[:], rawConnecter} + if _, ok := ret[label]; ok { + return nil, fmt.Errorf("duplicate label %q", label) + } + ret[label] = lc + } + return ret, nil +} + diff --git a/rpc/versionhandshake/versionhandshake.go b/rpc/versionhandshake/versionhandshake.go new file mode 100644 index 0000000..3864868 --- /dev/null +++ b/rpc/versionhandshake/versionhandshake.go @@ -0,0 +1,181 @@ +// Package versionhandshake wraps a transport.{Connecter,AuthenticatedListener} +// to add an exchange of protocol version information on connection establishment. +// +// The protocol version information (banner) is plain text, thus making it +// easy to diagnose issues with standard tools. +package versionhandshake + +import ( + "bytes" + "fmt" + "io" + "net" + "strings" + "time" + "unicode/utf8" +) + +type HandshakeMessage struct { + ProtocolVersion int + Extensions []string +} + +// A HandshakeError describes what went wrong during the handshake. +// It implements net.Error and is always temporary. +type HandshakeError struct { + msg string + // If not nil, the underlying IO error that caused the handshake to fail. + IOError error +} + +var _ net.Error = &HandshakeError{} + +func (e HandshakeError) Error() string { return e.msg } + +// Always true to enable usage in a net.Listener. +func (e HandshakeError) Temporary() bool { return true } + +// If the underlying IOError was net.Error.Timeout(), Timeout() returns that value. +// Otherwise false. +func (e HandshakeError) Timeout() bool { + if neterr, ok := e.IOError.(net.Error); ok { + return neterr.Timeout() + } + return false +} + +func hsErr(format string, args... interface{}) *HandshakeError { + return &HandshakeError{msg: fmt.Sprintf(format, args...)} +} + +func hsIOErr(err error, format string, args... interface{}) *HandshakeError { + return &HandshakeError{IOError: err, msg: fmt.Sprintf(format, args...)} +} + +// MaxProtocolVersion is the maximum allowed protocol version. +// This is a protocol constant, changing it may break the wire format. +const MaxProtocolVersion = 9999 + +// Only returns *HandshakeError as error. +func (m *HandshakeMessage) Encode() ([]byte, error) { + if m.ProtocolVersion <= 0 || m.ProtocolVersion > MaxProtocolVersion { + return nil, hsErr(fmt.Sprintf("protocol version must be in [1, %d]", MaxProtocolVersion)) + } + if len(m.Extensions) >= MaxProtocolVersion { + return nil, hsErr(fmt.Sprintf("protocol only supports [0, %d] extensions", MaxProtocolVersion)) + } + // EXTENSIONS is a count of subsequent \n separated lines that contain protocol extensions + var extensions strings.Builder + for i, ext := range m.Extensions { + if strings.ContainsAny(ext, "\n") { + return nil, hsErr("Extension #%d contains forbidden newline character", i) + } + if !utf8.ValidString(ext) { + return nil, hsErr("Extension #%d is not valid UTF-8", i) + } + extensions.WriteString(ext) + extensions.WriteString("\n") + } + withoutLen := fmt.Sprintf("ZREPL_ZFS_REPLICATION PROTOVERSION=%04d EXTENSIONS=%04d\n%s", + m.ProtocolVersion, len(m.Extensions), extensions.String()) + withLen := fmt.Sprintf("%010d %s", len(withoutLen), withoutLen) + return []byte(withLen), nil +} + +func (m *HandshakeMessage) DecodeReader(r io.Reader, maxLen int) error { + var lenAndSpace [11]byte + if _, err := io.ReadFull(r, lenAndSpace[:]); err != nil { + return hsIOErr(err, "error reading protocol banner length: %s", err) + } + if !utf8.Valid(lenAndSpace[:]) { + return hsErr("invalid start of handshake message: not valid UTF-8") + } + var followLen int + n, err := fmt.Sscanf(string(lenAndSpace[:]), "%010d ", &followLen) + if n != 1 || err != nil { + return hsErr("could not parse handshake message length") + } + if followLen > maxLen { + return hsErr("handshake message length exceeds max length (%d vs %d)", + followLen, maxLen) + } + + var buf bytes.Buffer + _, err = io.Copy(&buf, io.LimitReader(r, int64(followLen))) + if err != nil { + return hsIOErr(err, "error reading protocol banner body: %s", err) + } + + var ( + protoVersion, extensionCount int + ) + n, err = fmt.Fscanf(&buf, "ZREPL_ZFS_REPLICATION PROTOVERSION=%04d EXTENSIONS=%4d\n", + &protoVersion, &extensionCount) + if n != 2 || err != nil { + return hsErr("could not parse handshake message: %s", err) + } + if protoVersion < 1 { + return hsErr("invalid protocol version %q", protoVersion) + } + m.ProtocolVersion = protoVersion + + if extensionCount < 0 { + return hsErr("invalid extension count %q", extensionCount) + } + if extensionCount == 0 { + if buf.Len() != 0 { + return hsErr("unexpected data trailing after header") + } + m.Extensions = nil + return nil + } + s := buf.String() + if strings.Count(s, "\n") != extensionCount { + return hsErr("inconsistent extension count: found %d, header says %d", len(m.Extensions), extensionCount) + } + exts := strings.Split(s, "\n") + if exts[len(exts)-1] != "" { + return hsErr("unexpected data trailing after last extension newline") + } + m.Extensions = exts[0:len(exts)-1] + + return nil +} + +func DoHandshakeCurrentVersion(conn net.Conn, deadline time.Time) error { + // current protocol version is hardcoded here + return DoHandshakeVersion(conn, deadline, 1) +} + +const HandshakeMessageMaxLen = 16 * 4096 + +func DoHandshakeVersion(conn net.Conn, deadline time.Time, version int) error { + ours := HandshakeMessage{ + ProtocolVersion: version, + Extensions: nil, + } + hsb, err := ours.Encode() + if err != nil { + return hsErr("could not encode protocol banner: %s", err) + } + + defer conn.SetDeadline(time.Time{}) + conn.SetDeadline(deadline) + _, err = io.Copy(conn, bytes.NewBuffer(hsb)) + if err != nil { + return hsErr("could not send protocol banner: %s", err) + } + + theirs := HandshakeMessage{} + if err := theirs.DecodeReader(conn, HandshakeMessageMaxLen); err != nil { + return hsErr("could not decode protocol banner: %s", err) + } + + if theirs.ProtocolVersion != ours.ProtocolVersion { + return hsErr("protocol versions do not match: ours is %d, theirs is %d", + ours.ProtocolVersion, theirs.ProtocolVersion) + } + // ignore extensions, we don't use them + + return nil +} diff --git a/daemon/transport/handshake_test.go b/rpc/versionhandshake/versionhandshake_test.go similarity index 99% rename from daemon/transport/handshake_test.go rename to rpc/versionhandshake/versionhandshake_test.go index d1c72b4..dd27c9d 100644 --- a/daemon/transport/handshake_test.go +++ b/rpc/versionhandshake/versionhandshake_test.go @@ -1,4 +1,4 @@ -package transport +package versionhandshake import ( "bytes" diff --git a/rpc/versionhandshake/versionhandshake_transport_wrappers.go b/rpc/versionhandshake/versionhandshake_transport_wrappers.go new file mode 100644 index 0000000..660215e --- /dev/null +++ b/rpc/versionhandshake/versionhandshake_transport_wrappers.go @@ -0,0 +1,66 @@ +package versionhandshake + +import ( + "context" + "net" + "time" + "github.com/zrepl/zrepl/transport" +) + +type HandshakeConnecter struct { + connecter transport.Connecter + timeout time.Duration +} + +func (c HandshakeConnecter) Connect(ctx context.Context) (transport.Wire, error) { + conn, err := c.connecter.Connect(ctx) + if err != nil { + return nil, err + } + dl, ok := ctx.Deadline() + if !ok { + dl = time.Now().Add(c.timeout) + } + if err := DoHandshakeCurrentVersion(conn, dl); err != nil { + conn.Close() + return nil, err + } + return conn, nil +} + +func Connecter(connecter transport.Connecter, timeout time.Duration) HandshakeConnecter { + return HandshakeConnecter{ + connecter: connecter, + timeout: timeout, + } +} + +// wrapper type that performs a a protocol version handshake before returning the connection +type HandshakeListener struct { + l transport.AuthenticatedListener + timeout time.Duration +} + +func (l HandshakeListener) Addr() (net.Addr) { return l.l.Addr() } + +func (l HandshakeListener) Close() error { return l.l.Close() } + +func (l HandshakeListener) Accept(ctx context.Context) (*transport.AuthConn, error) { + conn, err := l.l.Accept(ctx) + if err != nil { + return nil, err + } + dl, ok := ctx.Deadline() + if !ok { + dl = time.Now().Add(l.timeout) // shadowing + } + if err := DoHandshakeCurrentVersion(conn, dl); err != nil { + conn.Close() + return nil, err + } + return conn, nil +} + +func Listener(l transport.AuthenticatedListener, timeout time.Duration) transport.AuthenticatedListener { + return HandshakeListener{l, timeout} +} diff --git a/tlsconf/tlsconf.go b/tlsconf/tlsconf.go index 38f0734..a5a4ea5 100644 --- a/tlsconf/tlsconf.go +++ b/tlsconf/tlsconf.go @@ -4,8 +4,11 @@ import ( "crypto/tls" "crypto/x509" "errors" + "fmt" + "io" "io/ioutil" "net" + "os" "time" ) @@ -22,12 +25,13 @@ func ParseCAFile(certfile string) (*x509.CertPool, error) { } type ClientAuthListener struct { - l net.Listener + l *net.TCPListener + c *tls.Config handshakeTimeout time.Duration } func NewClientAuthListener( - l net.Listener, ca *x509.CertPool, serverCert tls.Certificate, + l *net.TCPListener, ca *x509.CertPool, serverCert tls.Certificate, handshakeTimeout time.Duration) *ClientAuthListener { if ca == nil { @@ -37,29 +41,35 @@ func NewClientAuthListener( panic(serverCert) } - tlsConf := tls.Config{ + tlsConf := &tls.Config{ Certificates: []tls.Certificate{serverCert}, ClientCAs: ca, ClientAuth: tls.RequireAndVerifyClientCert, PreferServerCipherSuites: true, + KeyLogWriter: keylogFromEnv(), } - l = tls.NewListener(l, &tlsConf) return &ClientAuthListener{ l, + tlsConf, handshakeTimeout, } } -func (l *ClientAuthListener) Accept() (c net.Conn, clientCN string, err error) { - c, err = l.l.Accept() +// Accept() accepts a connection from the *net.TCPListener passed to the constructor +// and sets up the TLS connection, including handshake and peer CommmonName validation +// within the specified handshakeTimeout. +// +// It returns both the raw TCP connection (tcpConn) and the TLS connection (tlsConn) on top of it. +// Access to the raw tcpConn might be necessary if CloseWrite semantics are desired: +// tlsConn.CloseWrite does NOT call tcpConn.CloseWrite, hence we provide access to tcpConn to +// allow the caller to do this by themselves. +func (l *ClientAuthListener) Accept() (tcpConn *net.TCPConn, tlsConn *tls.Conn, clientCN string, err error) { + tcpConn, err = l.l.AcceptTCP() if err != nil { - return nil, "", err - } - tlsConn, ok := c.(*tls.Conn) - if !ok { - return c, "", err + return nil, nil, "", err } + tlsConn = tls.Server(tcpConn, l.c) var ( cn string peerCerts []*x509.Certificate @@ -70,6 +80,7 @@ func (l *ClientAuthListener) Accept() (c net.Conn, clientCN string, err error) { if err = tlsConn.Handshake(); err != nil { goto CloseAndErr } + tlsConn.SetDeadline(time.Time{}) peerCerts = tlsConn.ConnectionState().PeerCertificates if len(peerCerts) < 1 { @@ -77,10 +88,11 @@ func (l *ClientAuthListener) Accept() (c net.Conn, clientCN string, err error) { goto CloseAndErr } cn = peerCerts[0].Subject.CommonName - return c, cn, nil + return tcpConn, tlsConn, cn, nil CloseAndErr: - c.Close() - return nil, "", err + // unlike CloseWrite, Close on *tls.Conn actually closes the underlying connection + tlsConn.Close() // TODO log error + return nil, nil, "", err } func (l *ClientAuthListener) Addr() net.Addr { @@ -105,7 +117,21 @@ func ClientAuthClient(serverName string, rootCA *x509.CertPool, clientCert tls.C Certificates: []tls.Certificate{clientCert}, RootCAs: rootCA, ServerName: serverName, + KeyLogWriter: keylogFromEnv(), } tlsConfig.BuildNameToCertificate() return tlsConfig, nil } + +func keylogFromEnv() io.Writer { + var keyLog io.Writer = nil + if outfile := os.Getenv("ZREPL_KEYLOG_FILE"); outfile != "" { + fmt.Fprintf(os.Stderr, "writing to key log %s\n", outfile) + var err error + keyLog, err = os.OpenFile(outfile, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600) + if err != nil { + panic(err) + } + } + return keyLog +} diff --git a/transport/fromconfig/transport_fromconfig.go b/transport/fromconfig/transport_fromconfig.go new file mode 100644 index 0000000..0aa1426 --- /dev/null +++ b/transport/fromconfig/transport_fromconfig.go @@ -0,0 +1,58 @@ +// Package fromconfig instantiates transports based on zrepl config structures +// (see package config). +package fromconfig + +import ( + "fmt" + "github.com/pkg/errors" + "github.com/zrepl/zrepl/config" + "github.com/zrepl/zrepl/transport" + "github.com/zrepl/zrepl/transport/local" + "github.com/zrepl/zrepl/transport/ssh" + "github.com/zrepl/zrepl/transport/tcp" + "github.com/zrepl/zrepl/transport/tls" +) + +func ListenerFactoryFromConfig(g *config.Global, in config.ServeEnum) (transport.AuthenticatedListenerFactory,error) { + + var ( + l transport.AuthenticatedListenerFactory + err error + ) + switch v := in.Ret.(type) { + case *config.TCPServe: + l, err = tcp.TCPListenerFactoryFromConfig(g, v) + case *config.TLSServe: + l, err = tls.TLSListenerFactoryFromConfig(g, v) + case *config.StdinserverServer: + l, err = ssh.MultiStdinserverListenerFactoryFromConfig(g, v) + case *config.LocalServe: + l, err = local.LocalListenerFactoryFromConfig(g, v) + default: + return nil, errors.Errorf("internal error: unknown serve type %T", v) + } + + return l, err +} + + +func ConnecterFromConfig(g *config.Global, in config.ConnectEnum) (transport.Connecter, error) { + var ( + connecter transport.Connecter + err error + ) + switch v := in.Ret.(type) { + case *config.SSHStdinserverConnect: + connecter, err = ssh.SSHStdinserverConnecterFromConfig(v) + case *config.TCPConnect: + connecter, err = tcp.TCPConnecterFromConfig(v) + case *config.TLSConnect: + connecter, err = tls.TLSConnecterFromConfig(v) + case *config.LocalConnect: + connecter, err = local.LocalConnecterFromConfig(v) + default: + panic(fmt.Sprintf("implementation error: unknown connecter type %T", v)) + } + + return connecter, err +} diff --git a/daemon/transport/connecter/connect_local.go b/transport/local/connect_local.go similarity index 72% rename from daemon/transport/connecter/connect_local.go rename to transport/local/connect_local.go index 45c3d68..ba390b8 100644 --- a/daemon/transport/connecter/connect_local.go +++ b/transport/local/connect_local.go @@ -1,11 +1,10 @@ -package connecter +package local import ( "context" "fmt" "github.com/zrepl/zrepl/config" - "github.com/zrepl/zrepl/daemon/transport/serve" - "net" + "github.com/zrepl/zrepl/transport" ) type LocalConnecter struct { @@ -23,8 +22,8 @@ func LocalConnecterFromConfig(in *config.LocalConnect) (*LocalConnecter, error) return &LocalConnecter{listenerName: in.ListenerName, clientIdentity: in.ClientIdentity}, nil } -func (c *LocalConnecter) Connect(dialCtx context.Context) (conn net.Conn, err error) { - l := serve.GetLocalListener(c.listenerName) +func (c *LocalConnecter) Connect(dialCtx context.Context) (transport.Wire, error) { + l := GetLocalListener(c.listenerName) return l.Connect(dialCtx, c.clientIdentity) } diff --git a/daemon/transport/serve/serve_local.go b/transport/local/serve_local.go similarity index 74% rename from daemon/transport/serve/serve_local.go rename to transport/local/serve_local.go index f71ba70..f7e42aa 100644 --- a/daemon/transport/serve/serve_local.go +++ b/transport/local/serve_local.go @@ -1,4 +1,4 @@ -package serve +package local import ( "context" @@ -7,6 +7,7 @@ import ( "github.com/zrepl/zrepl/util/socketpair" "net" "sync" + "github.com/zrepl/zrepl/transport" ) var localListeners struct { @@ -39,7 +40,7 @@ type connectRequest struct { } type connectResult struct { - conn net.Conn + conn transport.Wire err error } @@ -54,7 +55,7 @@ func newLocalListener() *LocalListener { } // Connect to the LocalListener from a client with identity clientIdentity -func (l *LocalListener) Connect(dialCtx context.Context, clientIdentity string) (conn net.Conn, err error) { +func (l *LocalListener) Connect(dialCtx context.Context, clientIdentity string) (conn transport.Wire, err error) { // place request req := connectRequest{ @@ -89,21 +90,14 @@ func (a localAddr) String() string { return a.S } func (l *LocalListener) Addr() (net.Addr) { return localAddr{""} } -type localConn struct { - net.Conn - clientIdentity string -} - -func (l localConn) ClientIdentity() string { return l.clientIdentity } - -func (l *LocalListener) Accept(ctx context.Context) (AuthenticatedConn, error) { +func (l *LocalListener) Accept(ctx context.Context) (*transport.AuthConn, error) { respondToRequest := func(req connectRequest, res connectResult) (err error) { - getLogger(ctx). + transport.GetLogger(ctx). WithField("res.conn", res.conn).WithField("res.err", res.err). Debug("responding to client request") defer func() { errv := recover() - getLogger(ctx).WithField("recover_err", errv). + transport.GetLogger(ctx).WithField("recover_err", errv). Debug("panic on send to client callback, likely a legitimate client-side timeout") }() select { @@ -116,7 +110,7 @@ func (l *LocalListener) Accept(ctx context.Context) (AuthenticatedConn, error) { return err } - getLogger(ctx).Debug("waiting for local client connect requests") + transport.GetLogger(ctx).Debug("waiting for local client connect requests") var req connectRequest select { case req = <-l.connects: @@ -124,7 +118,7 @@ func (l *LocalListener) Accept(ctx context.Context) (AuthenticatedConn, error) { return nil, ctx.Err() } - getLogger(ctx).WithField("client_identity", req.clientIdentity).Debug("got connect request") + transport.GetLogger(ctx).WithField("client_identity", req.clientIdentity).Debug("got connect request") if req.clientIdentity == "" { res := connectResult{nil, fmt.Errorf("client identity must not be empty")} if err := respondToRequest(req, res); err != nil { @@ -133,31 +127,31 @@ func (l *LocalListener) Accept(ctx context.Context) (AuthenticatedConn, error) { return nil, fmt.Errorf("client connected with empty client identity") } - getLogger(ctx).Debug("creating socketpair") + transport.GetLogger(ctx).Debug("creating socketpair") left, right, err := socketpair.SocketPair() if err != nil { res := connectResult{nil, fmt.Errorf("server error: %s", err)} if respErr := respondToRequest(req, res); respErr != nil { // returning the socketpair error properly is more important than the error sent to the client - getLogger(ctx).WithError(respErr).Error("error responding to client") + transport.GetLogger(ctx).WithError(respErr).Error("error responding to client") } return nil, err } - getLogger(ctx).Debug("responding with left side of socketpair") + transport.GetLogger(ctx).Debug("responding with left side of socketpair") res := connectResult{left, nil} if err := respondToRequest(req, res); err != nil { - getLogger(ctx).WithError(err).Error("error responding to client") + transport.GetLogger(ctx).WithError(err).Error("error responding to client") if err := left.Close(); err != nil { - getLogger(ctx).WithError(err).Error("cannot close left side of socketpair") + transport.GetLogger(ctx).WithError(err).Error("cannot close left side of socketpair") } if err := right.Close(); err != nil { - getLogger(ctx).WithError(err).Error("cannot close right side of socketpair") + transport.GetLogger(ctx).WithError(err).Error("cannot close right side of socketpair") } return nil, err } - return localConn{right, req.clientIdentity}, nil + return transport.NewAuthConn(right, req.clientIdentity), nil } func (l *LocalListener) Close() error { @@ -169,19 +163,13 @@ func (l *LocalListener) Close() error { return nil } -type LocalListenerFactory struct { - listenerName string -} - -func LocalListenerFactoryFromConfig(g *config.Global, in *config.LocalServe) (f *LocalListenerFactory, err error) { +func LocalListenerFactoryFromConfig(g *config.Global, in *config.LocalServe) (transport.AuthenticatedListenerFactory,error) { if in.ListenerName == "" { return nil, fmt.Errorf("ListenerName must not be empty") } - return &LocalListenerFactory{listenerName: in.ListenerName}, nil + listenerName := in.ListenerName + lf := func() (transport.AuthenticatedListener,error) { + return GetLocalListener(listenerName), nil + } + return lf, nil } - - -func (lf *LocalListenerFactory) Listen() (AuthenticatedListener, error) { - return GetLocalListener(lf.listenerName), nil -} - diff --git a/daemon/transport/connecter/connect_ssh.go b/transport/ssh/connect_ssh.go similarity index 72% rename from daemon/transport/connecter/connect_ssh.go rename to transport/ssh/connect_ssh.go index 7efeec5..d669b88 100644 --- a/daemon/transport/connecter/connect_ssh.go +++ b/transport/ssh/connect_ssh.go @@ -1,13 +1,12 @@ -package connecter +package ssh import ( "context" "github.com/jinzhu/copier" "github.com/pkg/errors" "github.com/problame/go-netssh" - "github.com/problame/go-streamrpc" "github.com/zrepl/zrepl/config" - "net" + "github.com/zrepl/zrepl/transport" "time" ) @@ -22,8 +21,6 @@ type SSHStdinserverConnecter struct { dialTimeout time.Duration } -var _ streamrpc.Connecter = &SSHStdinserverConnecter{} - func SSHStdinserverConnecterFromConfig(in *config.SSHStdinserverConnect) (c *SSHStdinserverConnecter, err error) { c = &SSHStdinserverConnecter{ @@ -39,15 +36,7 @@ func SSHStdinserverConnecterFromConfig(in *config.SSHStdinserverConnect) (c *SSH } -type netsshConnToConn struct{ *netssh.SSHConn } - -var _ net.Conn = netsshConnToConn{} - -func (netsshConnToConn) SetDeadline(dl time.Time) error { return nil } -func (netsshConnToConn) SetReadDeadline(dl time.Time) error { return nil } -func (netsshConnToConn) SetWriteDeadline(dl time.Time) error { return nil } - -func (c *SSHStdinserverConnecter) Connect(dialCtx context.Context) (net.Conn, error) { +func (c *SSHStdinserverConnecter) Connect(dialCtx context.Context) (transport.Wire, error) { var endpoint netssh.Endpoint if err := copier.Copy(&endpoint, c); err != nil { @@ -62,5 +51,5 @@ func (c *SSHStdinserverConnecter) Connect(dialCtx context.Context) (net.Conn, er } return nil, err } - return netsshConnToConn{nconn}, nil + return nconn, nil } diff --git a/daemon/transport/serve/serve_stdinserver.go b/transport/ssh/serve_stdinserver.go similarity index 57% rename from daemon/transport/serve/serve_stdinserver.go rename to transport/ssh/serve_stdinserver.go index f02bf20..39bfba8 100644 --- a/daemon/transport/serve/serve_stdinserver.go +++ b/transport/ssh/serve_stdinserver.go @@ -1,50 +1,38 @@ -package serve +package ssh import ( "github.com/problame/go-netssh" "github.com/zrepl/zrepl/config" "github.com/zrepl/zrepl/daemon/nethelpers" - "io" + "github.com/zrepl/zrepl/transport" + "fmt" "net" "path" - "time" "context" "github.com/pkg/errors" "sync/atomic" ) -type StdinserverListenerFactory struct { - ClientIdentities []string - Sockdir string -} - -func MultiStdinserverListenerFactoryFromConfig(g *config.Global, in *config.StdinserverServer) (f *multiStdinserverListenerFactory, err error) { +func MultiStdinserverListenerFactoryFromConfig(g *config.Global, in *config.StdinserverServer) (transport.AuthenticatedListenerFactory,error) { for _, ci := range in.ClientIdentities { - if err := ValidateClientIdentity(ci); err != nil { + if err := transport.ValidateClientIdentity(ci); err != nil { return nil, errors.Wrapf(err, "invalid client identity %q", ci) } } - f = &multiStdinserverListenerFactory{ - ClientIdentities: in.ClientIdentities, - Sockdir: g.Serve.StdinServer.SockDir, + clientIdentities := in.ClientIdentities + sockdir := g.Serve.StdinServer.SockDir + + lf := func() (transport.AuthenticatedListener,error) { + return multiStdinserverListenerFromClientIdentities(sockdir, clientIdentities) } - return -} - -type multiStdinserverListenerFactory struct { - ClientIdentities []string - Sockdir string -} - -func (f *multiStdinserverListenerFactory) Listen() (AuthenticatedListener, error) { - return multiStdinserverListenerFromClientIdentities(f.Sockdir, f.ClientIdentities) + return lf, nil } type multiStdinserverAcceptRes struct { - conn AuthenticatedConn + conn *transport.AuthConn err error } @@ -78,7 +66,7 @@ func multiStdinserverListenerFromClientIdentities(sockdir string, cis []string) return &MultiStdinserverListener{listeners: listeners}, nil } -func (m *MultiStdinserverListener) Accept(ctx context.Context) (AuthenticatedConn, error){ +func (m *MultiStdinserverListener) Accept(ctx context.Context) (*transport.AuthConn, error){ if m.accepts == nil { m.accepts = make(chan multiStdinserverAcceptRes, len(m.listeners)) @@ -97,8 +85,22 @@ func (m *MultiStdinserverListener) Accept(ctx context.Context) (AuthenticatedCon } -func (m *MultiStdinserverListener) Addr() (net.Addr) { - return netsshAddr{} +type multiListenerAddr struct { + clients []string +} + +func (multiListenerAddr) Network() string { return "netssh" } + +func (l multiListenerAddr) String() string { + return fmt.Sprintf("netssh:clients=%v", l.clients) +} + +func (m *MultiStdinserverListener) Addr() net.Addr { + cis := make([]string, len(m.listeners)) + for i := range cis { + cis[i] = m.listeners[i].clientIdentity + } + return multiListenerAddr{cis} } func (m *MultiStdinserverListener) Close() error { @@ -118,41 +120,28 @@ type stdinserverListener struct { clientIdentity string } -func (l stdinserverListener) Addr() net.Addr { - return netsshAddr{} +type listenerAddr struct { + clientIdentity string } -func (l stdinserverListener) Accept(ctx context.Context) (AuthenticatedConn, error) { +func (listenerAddr) Network() string { return "netssh" } + +func (a listenerAddr) String() string { + return fmt.Sprintf("netssh:client=%q", a.clientIdentity) +} + +func (l stdinserverListener) Addr() net.Addr { + return listenerAddr{l.clientIdentity} +} + +func (l stdinserverListener) Accept(ctx context.Context) (*transport.AuthConn, error) { c, err := l.l.Accept() if err != nil { return nil, err } - return netsshConnToNetConnAdatper{c, l.clientIdentity}, nil + return transport.NewAuthConn(c, l.clientIdentity), nil } func (l stdinserverListener) Close() (err error) { return l.l.Close() } - -type netsshAddr struct{} - -func (netsshAddr) Network() string { return "netssh" } -func (netsshAddr) String() string { return "???" } - -type netsshConnToNetConnAdatper struct { - io.ReadWriteCloser // works for both netssh.SSHConn and netssh.ServeConn - clientIdentity string -} - -func (a netsshConnToNetConnAdatper) ClientIdentity() string { return a.clientIdentity } - -func (netsshConnToNetConnAdatper) LocalAddr() net.Addr { return netsshAddr{} } - -func (netsshConnToNetConnAdatper) RemoteAddr() net.Addr { return netsshAddr{} } - -// FIXME log warning once! -func (netsshConnToNetConnAdatper) SetDeadline(t time.Time) error { return nil } - -func (netsshConnToNetConnAdatper) SetReadDeadline(t time.Time) error { return nil } - -func (netsshConnToNetConnAdatper) SetWriteDeadline(t time.Time) error { return nil } diff --git a/daemon/transport/connecter/connect_tcp.go b/transport/tcp/connect_tcp.go similarity index 54% rename from daemon/transport/connecter/connect_tcp.go rename to transport/tcp/connect_tcp.go index 3d8b77e..1176512 100644 --- a/daemon/transport/connecter/connect_tcp.go +++ b/transport/tcp/connect_tcp.go @@ -1,9 +1,11 @@ -package connecter +package tcp import ( "context" - "github.com/zrepl/zrepl/config" "net" + + "github.com/zrepl/zrepl/config" + "github.com/zrepl/zrepl/transport" ) type TCPConnecter struct { @@ -19,6 +21,10 @@ func TCPConnecterFromConfig(in *config.TCPConnect) (*TCPConnecter, error) { return &TCPConnecter{in.Address, dialer}, nil } -func (c *TCPConnecter) Connect(dialCtx context.Context) (conn net.Conn, err error) { - return c.dialer.DialContext(dialCtx, "tcp", c.Address) +func (c *TCPConnecter) Connect(dialCtx context.Context) (transport.Wire, error) { + conn, err := c.dialer.DialContext(dialCtx, "tcp", c.Address) + if err != nil { + return nil, err + } + return conn.(*net.TCPConn), nil } diff --git a/daemon/transport/serve/serve_tcp.go b/transport/tcp/serve_tcp.go similarity index 68% rename from daemon/transport/serve/serve_tcp.go rename to transport/tcp/serve_tcp.go index 957d3b9..a6b8107 100644 --- a/daemon/transport/serve/serve_tcp.go +++ b/transport/tcp/serve_tcp.go @@ -1,17 +1,13 @@ -package serve +package tcp import ( "github.com/zrepl/zrepl/config" "net" "github.com/pkg/errors" "context" + "github.com/zrepl/zrepl/transport" ) -type TCPListenerFactory struct { - address *net.TCPAddr - clientMap *ipMap -} - type ipMapEntry struct { ip net.IP ident string @@ -28,7 +24,7 @@ func ipMapFromConfig(clients map[string]string) (*ipMap, error) { if clientIP == nil { return nil, errors.Errorf("cannot parse client IP %q", clientIPString) } - if err := ValidateClientIdentity(clientIdent); err != nil { + if err := transport.ValidateClientIdentity(clientIdent); err != nil { return nil, errors.Wrapf(err,"invalid client identity for IP %q", clientIPString) } entries = append(entries, ipMapEntry{clientIP, clientIdent}) @@ -45,7 +41,7 @@ func (m *ipMap) Get(ip net.IP) (string, error) { return "", errors.Errorf("no identity mapping for client IP %s", ip) } -func TCPListenerFactoryFromConfig(c *config.Global, in *config.TCPServe) (*TCPListenerFactory, error) { +func TCPListenerFactoryFromConfig(c *config.Global, in *config.TCPServe) (transport.AuthenticatedListenerFactory, error) { addr, err := net.ResolveTCPAddr("tcp", in.Listen) if err != nil { return nil, errors.Wrap(err, "cannot parse listen address") @@ -54,38 +50,33 @@ func TCPListenerFactoryFromConfig(c *config.Global, in *config.TCPServe) (*TCPLi if err != nil { return nil, errors.Wrap(err, "cannot parse client IP map") } - lf := &TCPListenerFactory{ - address: addr, - clientMap: clientMap, + lf := func() (transport.AuthenticatedListener, error) { + l, err := net.ListenTCP("tcp", addr) + if err != nil { + return nil, err + } + return &TCPAuthListener{l, clientMap}, nil } return lf, nil } -func (f *TCPListenerFactory) Listen() (AuthenticatedListener, error) { - l, err := net.ListenTCP("tcp", f.address) - if err != nil { - return nil, err - } - return &TCPAuthListener{l, f.clientMap}, nil -} - type TCPAuthListener struct { *net.TCPListener clientMap *ipMap } -func (f *TCPAuthListener) Accept(ctx context.Context) (AuthenticatedConn, error) { - nc, err := f.TCPListener.Accept() +func (f *TCPAuthListener) Accept(ctx context.Context) (*transport.AuthConn, error) { + nc, err := f.TCPListener.AcceptTCP() if err != nil { return nil, err } clientIP := nc.RemoteAddr().(*net.TCPAddr).IP clientIdent, err := f.clientMap.Get(clientIP) if err != nil { - getLogger(ctx).WithField("ip", clientIP).Error("client IP not in client map") + transport.GetLogger(ctx).WithField("ip", clientIP).Error("client IP not in client map") nc.Close() return nil, err } - return authConn{nc, clientIdent}, nil + return transport.NewAuthConn(nc, clientIdent), nil } diff --git a/daemon/transport/connecter/connect_tls.go b/transport/tls/connect_tls.go similarity index 73% rename from daemon/transport/connecter/connect_tls.go rename to transport/tls/connect_tls.go index a60cb45..ea578d4 100644 --- a/daemon/transport/connecter/connect_tls.go +++ b/transport/tls/connect_tls.go @@ -1,12 +1,14 @@ -package connecter +package tls import ( "context" "crypto/tls" + "net" + "github.com/pkg/errors" "github.com/zrepl/zrepl/config" "github.com/zrepl/zrepl/tlsconf" - "net" + "github.com/zrepl/zrepl/transport" ) type TLSConnecter struct { @@ -38,10 +40,12 @@ func TLSConnecterFromConfig(in *config.TLSConnect) (*TLSConnecter, error) { return &TLSConnecter{in.Address, dialer, tlsConfig}, nil } -func (c *TLSConnecter) Connect(dialCtx context.Context) (conn net.Conn, err error) { - conn, err = c.dialer.DialContext(dialCtx, "tcp", c.Address) +func (c *TLSConnecter) Connect(dialCtx context.Context) (transport.Wire, error) { + conn, err := c.dialer.DialContext(dialCtx, "tcp", c.Address) if err != nil { return nil, err } - return tls.Client(conn, c.tlsConfig), nil + tcpConn := conn.(*net.TCPConn) + tlsConn := tls.Client(conn, c.tlsConfig) + return newWireAdaptor(tlsConn, tcpConn), nil } diff --git a/transport/tls/serve_tls.go b/transport/tls/serve_tls.go new file mode 100644 index 0000000..21aafe4 --- /dev/null +++ b/transport/tls/serve_tls.go @@ -0,0 +1,89 @@ +package tls + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "github.com/pkg/errors" + "github.com/zrepl/zrepl/config" + "github.com/zrepl/zrepl/transport" + "github.com/zrepl/zrepl/tlsconf" + "net" + "time" + "context" +) + +type TLSListenerFactory struct { + address string + clientCA *x509.CertPool + serverCert tls.Certificate + handshakeTimeout time.Duration + clientCNs map[string]struct{} +} + +func TLSListenerFactoryFromConfig(c *config.Global, in *config.TLSServe) (transport.AuthenticatedListenerFactory,error) { + + address := in.Listen + handshakeTimeout := in.HandshakeTimeout + + if in.Ca == "" || in.Cert == "" || in.Key == "" { + return nil, errors.New("fields 'ca', 'cert' and 'key'must be specified") + } + + clientCA, err := tlsconf.ParseCAFile(in.Ca) + if err != nil { + return nil, errors.Wrap(err, "cannot parse ca file") + } + + serverCert, err := tls.LoadX509KeyPair(in.Cert, in.Key) + if err != nil { + return nil, errors.Wrap(err, "cannot parse cer/key pair") + } + + clientCNs := make(map[string]struct{}, len(in.ClientCNs)) + for i, cn := range in.ClientCNs { + if err := transport.ValidateClientIdentity(cn); err != nil { + return nil, errors.Wrapf(err, "unsuitable client_cn #%d %q", i, cn) + } + // dupes are ok fr now + clientCNs[cn] = struct{}{} + } + + lf := func() (transport.AuthenticatedListener, error) { + l, err := net.Listen("tcp", address) + if err != nil { + return nil, err + } + tcpL := l.(*net.TCPListener) + tl := tlsconf.NewClientAuthListener(tcpL, clientCA, serverCert, handshakeTimeout) + return &tlsAuthListener{tl, clientCNs}, nil + } + + return lf, nil +} + +type tlsAuthListener struct { + *tlsconf.ClientAuthListener + clientCNs map[string]struct{} +} + +func (l tlsAuthListener) Accept(ctx context.Context) (*transport.AuthConn, error) { + tcpConn, tlsConn, cn, err := l.ClientAuthListener.Accept() + if err != nil { + return nil, err + } + if _, ok := l.clientCNs[cn]; !ok { + if dl, ok := ctx.Deadline(); ok { + defer tlsConn.SetDeadline(time.Time{}) + tlsConn.SetDeadline(dl) + } + if err := tlsConn.Close(); err != nil { + transport.GetLogger(ctx).WithError(err).Error("error closing connection with unauthorized common name") + } + return nil, fmt.Errorf("unauthorized client common name %q from %s", cn, tlsConn.RemoteAddr()) + } + adaptor := newWireAdaptor(tlsConn, tcpConn) + return transport.NewAuthConn(adaptor, cn), nil +} + + diff --git a/transport/tls/tls_wire_adaptor.go b/transport/tls/tls_wire_adaptor.go new file mode 100644 index 0000000..2f03cdd --- /dev/null +++ b/transport/tls/tls_wire_adaptor.go @@ -0,0 +1,47 @@ +package tls + +import ( + "crypto/tls" + "fmt" + "net" + "os" +) + +// adapts a tls.Conn and its underlying net.TCPConn into a valid transport.Wire +type transportWireAdaptor struct { + *tls.Conn + tcpConn *net.TCPConn +} + +func newWireAdaptor(tlsConn *tls.Conn, tcpConn *net.TCPConn) transportWireAdaptor { + return transportWireAdaptor{tlsConn, tcpConn} +} + +// CloseWrite implements transport.Wire.CloseWrite which is different from *tls.Conn.CloseWrite: +// the former requires that the other side observes io.EOF, but *tls.Conn.CloseWrite does not +// close the underlying connection so no io.EOF would be observed. +func (w transportWireAdaptor) CloseWrite() error { + if err := w.Conn.CloseWrite(); err != nil { + // TODO log error + fmt.Fprintf(os.Stderr, "transport/tls.CloseWrite() error: %s\n", err) + } + return w.tcpConn.CloseWrite() +} + +// Close implements transport.Wire.Close which is different from a *tls.Conn.Close: +// At the time of writing (Go 1.11), closing tls.Conn closes the TCP connection immediately, +// which results in io.ErrUnexpectedEOF on the other side. +// We assume that w.Conn has a deadline set for the close, so the CloseWrite will time out if it blocks, +// falling through to the actual Close() +func (w transportWireAdaptor) Close() error { + // var buf [1<<15]byte + // w.Conn.Write(buf[:]) + // CloseWrite will send a TLS alert record down the line which + // in the Go implementation acts like a flush...? + // if err := w.Conn.CloseWrite(); err != nil { + // // TODO log error + // fmt.Fprintf(os.Stderr, "transport/tls.Close() close write error: %s\n", err) + // } + // time.Sleep(1 * time.Second) + return w.Conn.Close() +} diff --git a/transport/transport.go b/transport/transport.go new file mode 100644 index 0000000..e520928 --- /dev/null +++ b/transport/transport.go @@ -0,0 +1,84 @@ +// Package transport defines a common interface for +// network connections that have an associated client identity. +package transport + +import ( + "context" + "errors" + "net" + "syscall" + + "github.com/zrepl/zrepl/logger" + "github.com/zrepl/zrepl/rpc/dataconn/timeoutconn" + "github.com/zrepl/zrepl/zfs" +) + +type AuthConn struct { + Wire + clientIdentity string +} + +var _ timeoutconn.SyscallConner = AuthConn{} + +func (a AuthConn) SyscallConn() (rawConn syscall.RawConn, err error) { + scc, ok := a.Wire.(timeoutconn.SyscallConner) + if !ok { + return nil, timeoutconn.SyscallConnNotSupported + } + return scc.SyscallConn() +} + +func NewAuthConn(conn Wire, clientIdentity string) *AuthConn { + return &AuthConn{conn, clientIdentity} +} + +func (c *AuthConn) ClientIdentity() string { + if err := ValidateClientIdentity(c.clientIdentity); err != nil { + panic(err) + } + return c.clientIdentity +} + +// like net.Listener, but with an AuthenticatedConn instead of net.Conn +type AuthenticatedListener interface { + Addr() net.Addr + Accept(ctx context.Context) (*AuthConn, error) + Close() error +} + +type AuthenticatedListenerFactory func() (AuthenticatedListener, error) + +type Wire = timeoutconn.Wire + +type Connecter interface { + Connect(ctx context.Context) (Wire, error) +} + +// A client identity must be a single component in a ZFS filesystem path +func ValidateClientIdentity(in string) (err error) { + path, err := zfs.NewDatasetPath(in) + if err != nil { + return err + } + if path.Length() != 1 { + return errors.New("client identity must be a single path comonent (not empty, no '/')") + } + return nil +} + +type contextKey int + +const contextKeyLog contextKey = 0 + +type Logger = logger.Logger + +func WithLogger(ctx context.Context, log Logger) context.Context { + return context.WithValue(ctx, contextKeyLog, log) +} + +func GetLogger(ctx context.Context) Logger { + if log, ok := ctx.Value(contextKeyLog).(Logger); ok { + return log + } + return logger.NewNullLogger() +} diff --git a/util/bytecounter/bytecounter_streamcopier.go b/util/bytecounter/bytecounter_streamcopier.go new file mode 100644 index 0000000..a268a91 --- /dev/null +++ b/util/bytecounter/bytecounter_streamcopier.go @@ -0,0 +1,71 @@ +package bytecounter + +import ( + "io" + "sync/atomic" + + "github.com/zrepl/zrepl/zfs" +) + +// StreamCopier wraps a zfs.StreamCopier, reimplemening +// its interface and counting the bytes written to during copying. +type StreamCopier interface { + zfs.StreamCopier + Count() int64 +} + +// NewStreamCopier wraps sc into a StreamCopier. +// If sc is io.Reader, it is guaranteed that the returned StreamCopier +// implements that interface, too. +func NewStreamCopier(sc zfs.StreamCopier) StreamCopier { + bsc := &streamCopier{sc, 0} + if scr, ok := sc.(io.Reader); ok { + return streamCopierAndReader{bsc, scr} + } else { + return bsc + } +} + +type streamCopier struct { + sc zfs.StreamCopier + count int64 +} + +// proxy writer used by streamCopier +type streamCopierWriter struct { + parent *streamCopier + w io.Writer +} + +func (w streamCopierWriter) Write(p []byte) (n int, err error) { + n, err = w.w.Write(p) + atomic.AddInt64(&w.parent.count, int64(n)) + return +} + +func (s *streamCopier) Count() int64 { + return atomic.LoadInt64(&s.count) +} + +var _ zfs.StreamCopier = &streamCopier{} + +func (s streamCopier) Close() error { + return s.sc.Close() +} + +func (s *streamCopier) WriteStreamTo(w io.Writer) zfs.StreamCopierError { + ww := streamCopierWriter{s, w} + return s.sc.WriteStreamTo(ww) +} + +// a streamCopier whose underlying sc is an io.Reader +type streamCopierAndReader struct { + *streamCopier + asReader io.Reader +} + +func (scr streamCopierAndReader) Read(p []byte) (int, error) { + n, err := scr.asReader.Read(p) + atomic.AddInt64(&scr.streamCopier.count, int64(n)) + return n, err +} diff --git a/util/bytecounter/bytecounter_streamcopier_test.go b/util/bytecounter/bytecounter_streamcopier_test.go new file mode 100644 index 0000000..29611e0 --- /dev/null +++ b/util/bytecounter/bytecounter_streamcopier_test.go @@ -0,0 +1,38 @@ +package bytecounter + +import ( + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/zrepl/zrepl/zfs" +) + +type mockStreamCopierAndReader struct { + zfs.StreamCopier // to satisfy interface + reads int +} + +func (r *mockStreamCopierAndReader) Read(p []byte) (int, error) { + r.reads++ + return len(p), nil +} + +var _ io.Reader = &mockStreamCopierAndReader{} + +func TestNewStreamCopierReexportsReader(t *testing.T) { + mock := &mockStreamCopierAndReader{} + x := NewStreamCopier(mock) + + r, ok := x.(io.Reader) + if !ok { + t.Fatalf("%T does not implement io.Reader, hence reader cannout have been wrapped", x) + } + + var buf [23]byte + n, err := r.Read(buf[:]) + assert.True(t, mock.reads == 1) + assert.True(t, n == len(buf)) + assert.NoError(t, err) + assert.True(t, x.Count() == 23) +} diff --git a/util/devnoop/devnoop.go b/util/devnoop/devnoop.go new file mode 100644 index 0000000..9c4074f --- /dev/null +++ b/util/devnoop/devnoop.go @@ -0,0 +1,14 @@ +// package devnoop provides an io.ReadWriteCloser that never errors +// and always reports reads / writes to / from buffers as complete. +// The buffers themselves are never touched. +package devnoop + +type Dev struct{} + +func Get() Dev { + return Dev{} +} + +func (Dev) Write(p []byte) (n int, err error) { return len(p), nil } +func (Dev) Read(p []byte) (n int, err error) { return len(p), nil } +func (Dev) Close() error { return nil } diff --git a/util/envconst/envconst.go b/util/envconst/envconst.go index 8159aae..8c13190 100644 --- a/util/envconst/envconst.go +++ b/util/envconst/envconst.go @@ -2,6 +2,7 @@ package envconst import ( "os" + "strconv" "sync" "time" ) @@ -23,3 +24,19 @@ func Duration(varname string, def time.Duration) time.Duration { cache.Store(varname, d) return d } + +func Int64(varname string, def int64) int64 { + if v, ok := cache.Load(varname); ok { + return v.(int64) + } + e := os.Getenv(varname) + if e == "" { + return def + } + d, err := strconv.ParseInt(e, 10, 64) + if err != nil { + panic(err) + } + cache.Store(varname, d) + return d +} diff --git a/util/iocommand.go b/util/iocommand.go index 1113f74..fe5f9a3 100644 --- a/util/iocommand.go +++ b/util/iocommand.go @@ -4,15 +4,18 @@ import ( "bytes" "context" "fmt" + "github.com/zrepl/zrepl/util/envconst" "io" "os" "os/exec" "syscall" + "time" ) // An IOCommand exposes a forked process's std(in|out|err) through the io.ReadWriteCloser interface. type IOCommand struct { Cmd *exec.Cmd + kill context.CancelFunc Stdin io.WriteCloser Stdout io.ReadCloser StderrBuf *bytes.Buffer @@ -52,6 +55,7 @@ func NewIOCommand(ctx context.Context, command string, args []string, stderrBufS c = &IOCommand{} + ctx, c.kill = context.WithCancel(ctx) c.Cmd = exec.CommandContext(ctx, command, args...) if c.Stdout, err = c.Cmd.StdoutPipe(); err != nil { @@ -81,14 +85,24 @@ func (c *IOCommand) Start() (err error) { func (c *IOCommand) Read(buf []byte) (n int, err error) { n, err = c.Stdout.Read(buf) if err == io.EOF { - if waitErr := c.doWait(); waitErr != nil { + if waitErr := c.doWait(context.Background()); waitErr != nil { err = waitErr } } return } -func (c *IOCommand) doWait() (err error) { +func (c *IOCommand) doWait(ctx context.Context) (err error) { + go func() { + dl, ok := ctx.Deadline() + if !ok { + return + } + time.Sleep(dl.Sub(time.Now())) + c.kill() + c.Stdout.Close() + c.Stdin.Close() + }() waitErr := c.Cmd.Wait() var wasUs bool = false var waitStatus syscall.WaitStatus @@ -133,10 +147,9 @@ func (c *IOCommand) Close() (err error) { if c.Cmd.ProcessState == nil { // racy... err = syscall.Kill(c.Cmd.Process.Pid, syscall.SIGTERM) - if err != nil { - return - } - return c.doWait() + ctx, cancel := context.WithTimeout(context.Background(), envconst.Duration("IOCOMMAND_TIMEOUT", 10*time.Second)) + defer cancel() + return c.doWait(ctx) } else { return c.ExitResult.Error } diff --git a/util/socketpair/socketpair.go b/util/socketpair/socketpair.go index 615c8f2..c1da9e3 100644 --- a/util/socketpair/socketpair.go +++ b/util/socketpair/socketpair.go @@ -1,42 +1,32 @@ package socketpair import ( - "golang.org/x/sys/unix" "net" "os" + + "golang.org/x/sys/unix" ) -type fileConn struct { - net.Conn // net.FileConn - f *os.File -} -func (c fileConn) Close() error { - if err := c.Conn.Close(); err != nil { - return err - } - if err := c.f.Close(); err != nil { - return err - } - return nil -} - -func SocketPair() (a, b net.Conn, err error) { +func SocketPair() (a, b *net.UnixConn, err error) { // don't use net.Pipe, as it doesn't implement things like lingering, which our code relies on sockpair, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_STREAM, 0) if err != nil { return nil, nil, err } - toConn := func(fd int) (net.Conn, error) { + toConn := func(fd int) (*net.UnixConn, error) { f := os.NewFile(uintptr(fd), "fileconn") if f == nil { panic(fd) } c, err := net.FileConn(f) + f.Close() // net.FileConn uses dup under the hood if err != nil { - f.Close() return nil, err } - return fileConn{Conn: c, f: f}, nil + // strictly, the following type assertion is an implementation detail + // however, will be caught by test TestSocketPairWorks + fileConnIsUnixConn := c.(*net.UnixConn) + return fileConnIsUnixConn, nil } if a, err = toConn(sockpair[0]); err != nil { // shadowing return nil, nil, err diff --git a/util/socketpair/socketpair_test.go b/util/socketpair/socketpair_test.go new file mode 100644 index 0000000..95cbdc5 --- /dev/null +++ b/util/socketpair/socketpair_test.go @@ -0,0 +1,18 @@ +package socketpair + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// This is test is mostly to verify that the assumption about +// net.FileConn returning *net.UnixConn for AF_UNIX FDs works. +func TestSocketPairWorks(t *testing.T) { + assert.NotPanics(t, func() { + a, b, err := SocketPair() + assert.NoError(t, err) + a.Close() + b.Close() + }) +} diff --git a/zfs/diff.go b/zfs/diff.go index 2b37f6d..52eb84f 100644 --- a/zfs/diff.go +++ b/zfs/diff.go @@ -274,7 +274,7 @@ func ZFSCreatePlaceholderFilesystem(p *DatasetPath) (err error) { } if err = cmd.Wait(); err != nil { - err = ZFSError{ + err = &ZFSError{ Stderr: stderr.Bytes(), WaitErr: err, } diff --git a/zfs/mapping.go b/zfs/mapping.go index 56a85b3..43bb5bb 100644 --- a/zfs/mapping.go +++ b/zfs/mapping.go @@ -9,8 +9,8 @@ type DatasetFilter interface { Filter(p *DatasetPath) (pass bool, err error) } -func ZFSListMapping(filter DatasetFilter) (datasets []*DatasetPath, err error) { - res, err := ZFSListMappingProperties(filter, nil) +func ZFSListMapping(ctx context.Context, filter DatasetFilter) (datasets []*DatasetPath, err error) { + res, err := ZFSListMappingProperties(ctx, filter, nil) if err != nil { return nil, err } @@ -28,7 +28,7 @@ type ZFSListMappingPropertiesResult struct { } // properties must not contain 'name' -func ZFSListMappingProperties(filter DatasetFilter, properties []string) (datasets []ZFSListMappingPropertiesResult, err error) { +func ZFSListMappingProperties(ctx context.Context, filter DatasetFilter, properties []string) (datasets []ZFSListMappingPropertiesResult, err error) { if filter == nil { panic("filter must not be nil") @@ -44,7 +44,7 @@ func ZFSListMappingProperties(filter DatasetFilter, properties []string) (datase copy(newProps[1:], properties) properties = newProps - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(ctx) defer cancel() rchan := make(chan ZFSListResult) diff --git a/zfs/zfs.go b/zfs/zfs.go index 56c9c77..8f08a66 100644 --- a/zfs/zfs.go +++ b/zfs/zfs.go @@ -7,15 +7,22 @@ import ( "errors" "fmt" "io" + "os" "os/exec" "strings" + "sync" + "time" "context" - "github.com/problame/go-rwccmd" "github.com/prometheus/client_golang/prometheus" - "github.com/zrepl/zrepl/util" "regexp" "strconv" + "github.com/zrepl/zrepl/util/envconst" +) + +var ( + ZFSSendPipeCapacityHint = int(envconst.Int64("ZFS_SEND_PIPE_CAPACITY_HINT", 1<<25)) + ZFSRecvPipeCapacityHint = int(envconst.Int64("ZFS_RECV_PIPE_CAPACITY_HINT", 1<<25)) ) type DatasetPath struct { @@ -141,7 +148,7 @@ type ZFSError struct { WaitErr error } -func (e ZFSError) Error() string { +func (e *ZFSError) Error() string { return fmt.Sprintf("zfs exited with error: %s\nstderr:\n%s", e.WaitErr.Error(), e.Stderr) } @@ -187,7 +194,7 @@ func ZFSList(properties []string, zfsArgs ...string) (res [][]string, err error) } if waitErr := cmd.Wait(); waitErr != nil { - err := ZFSError{ + err := &ZFSError{ Stderr: stderr.Bytes(), WaitErr: waitErr, } @@ -227,18 +234,24 @@ func ZFSListChan(ctx context.Context, out chan ZFSListResult, properties []strin } } - cmd, err := rwccmd.CommandContext(ctx, ZFS_BINARY, args, []string{}) + cmd := exec.CommandContext(ctx, ZFS_BINARY, args...) + stdout, err := cmd.StdoutPipe() if err != nil { sendResult(nil, err) return } + // TODO bounded buffer + stderr := bytes.NewBuffer(make([]byte, 0, 1024)) + cmd.Stderr = stderr if err = cmd.Start(); err != nil { sendResult(nil, err) return } - defer cmd.Close() + defer func() { + cmd.Wait() + }() - s := bufio.NewScanner(cmd) + s := bufio.NewScanner(stdout) buf := make([]byte, 1024) // max line length s.Buffer(buf, 0) @@ -252,8 +265,20 @@ func ZFSListChan(ctx context.Context, out chan ZFSListResult, properties []strin return } } + if err := cmd.Wait(); err != nil { + if err, ok := err.(*exec.ExitError); ok { + sendResult(nil, &ZFSError{ + Stderr: stderr.Bytes(), + WaitErr: err, + }) + } else { + sendResult(nil, &ZFSError{WaitErr: err}) + } + return + } if s.Err() != nil { sendResult(nil, s.Err()) + return } return } @@ -314,10 +339,180 @@ func buildCommonSendArgs(fs string, from, to string, token string) ([]string, er return args, nil } +type sendStreamCopier struct { + recorder readErrRecorder +} + +type readErrRecorder struct { + io.ReadCloser + readErr error +} + +type sendStreamCopierError struct { + isReadErr bool // if false, it's a write error + err error +} + +func (e sendStreamCopierError) Error() string { + if e.isReadErr { + return fmt.Sprintf("stream: read error: %s", e.err) + } else { + return fmt.Sprintf("stream: writer error: %s", e.err) + } +} + +func (e sendStreamCopierError) IsReadError() bool { return e.isReadErr } +func (e sendStreamCopierError) IsWriteError() bool { return !e.isReadErr } + +func (r *readErrRecorder) Read(p []byte) (n int, err error) { + n, err = r.ReadCloser.Read(p) + r.readErr = err + return n, err +} + +func newSendStreamCopier(stream io.ReadCloser) *sendStreamCopier { + return &sendStreamCopier{recorder: readErrRecorder{stream, nil}} +} + +func (c *sendStreamCopier) WriteStreamTo(w io.Writer) StreamCopierError { + debug("sendStreamCopier.WriteStreamTo: begin") + _, err := io.Copy(w, &c.recorder) + debug("sendStreamCopier.WriteStreamTo: copy done") + if err != nil { + if c.recorder.readErr != nil { + return sendStreamCopierError{isReadErr: true, err: c.recorder.readErr} + } else { + return sendStreamCopierError{isReadErr: false, err: err} + } + } + return nil +} + +func (c *sendStreamCopier) Read(p []byte) (n int, err error) { + return c.recorder.Read(p) +} + +func (c *sendStreamCopier) Close() error { + return c.recorder.ReadCloser.Close() +} + +func pipeWithCapacityHint(capacity int) (r, w *os.File, err error) { + if capacity <= 0 { + panic(fmt.Sprintf("capacity must be positive %v", capacity)) + } + stdoutReader, stdoutWriter, err := os.Pipe() + if err != nil { + return nil, nil, err + } + trySetPipeCapacity(stdoutWriter, capacity) + return stdoutReader, stdoutWriter, nil +} + +type sendStream struct { + cmd *exec.Cmd + kill context.CancelFunc + + closeMtx sync.Mutex + stdoutReader *os.File + opErr error + +} + +func (s *sendStream) Read(p []byte) (n int, err error) { + s.closeMtx.Lock() + opErr := s.opErr + s.closeMtx.Unlock() + if opErr != nil { + return 0, opErr + } + + n, err = s.stdoutReader.Read(p) + if err != nil { + debug("sendStream: read err: %T %s", err, err) + // TODO we assume here that any read error is permanent + // which is most likely the case for a local zfs send + kwerr := s.killAndWait(err) + debug("sendStream: killAndWait n=%v err= %T %s", n, kwerr, kwerr) + // TODO we assume here that any read error is permanent + return n, kwerr + } + return n, err +} + +func (s *sendStream) Close() error { + debug("sendStream: close called") + return s.killAndWait(nil) +} + +func (s *sendStream) killAndWait(precedingReadErr error) error { + + debug("sendStream: killAndWait enter") + defer debug("sendStream: killAndWait leave") + if precedingReadErr == io.EOF { + // give the zfs process a little bit of time to terminate itself + // if it holds this deadline, exitErr will be nil + time.AfterFunc(200*time.Millisecond, s.kill) + } else { + s.kill() + } + + // allow async kills from Close(), that's why we only take the mutex here + s.closeMtx.Lock() + defer s.closeMtx.Unlock() + + if s.opErr != nil { + return s.opErr + } + + waitErr := s.cmd.Wait() + // distinguish between ExitError (which is actually a non-problem for us) + // vs failed wait syscall (for which we give upper layers the chance to retyr) + var exitErr *exec.ExitError + if waitErr != nil { + if ee, ok := waitErr.(*exec.ExitError); ok { + exitErr = ee + } else { + return waitErr + } + } + + // now, after we know the program exited do we close the pipe + var closePipeErr error + if s.stdoutReader != nil { + closePipeErr = s.stdoutReader.Close() + if closePipeErr == nil { + // avoid double-closes in case anything below doesn't work + // and someone calls Close again + s.stdoutReader = nil + } else { + return closePipeErr + } + } + + // we managed to tear things down, no let's give the user some pretty *ZFSError + if exitErr != nil { + s.opErr = &ZFSError{ + Stderr: exitErr.Stderr, + WaitErr: exitErr, + } + } else { + s.opErr = fmt.Errorf("zfs send exited with status code 0") + } + + // detect the edge where we're called from s.Read + // after the pipe EOFed and zfs send exited without errors + // this is actullay the "hot" / nice path + if exitErr == nil && precedingReadErr == io.EOF { + return precedingReadErr + } + + return s.opErr +} + // if token != "", then send -t token is used // otherwise send [-i from] to is used // (if from is "" a full ZFS send is done) -func ZFSSend(ctx context.Context, fs string, from, to string, token string) (stream io.ReadCloser, err error) { +func ZFSSend(ctx context.Context, fs string, from, to string, token string) (streamCopier StreamCopier, err error) { args := make([]string, 0) args = append(args, "send") @@ -328,9 +523,33 @@ func ZFSSend(ctx context.Context, fs string, from, to string, token string) (str } args = append(args, sargs...) - stream, err = util.RunIOCommand(ctx, ZFS_BINARY, args...) + ctx, cancel := context.WithCancel(ctx) + cmd := exec.CommandContext(ctx, ZFS_BINARY, args...) - return + // setup stdout with an os.Pipe to control pipe buffer size + stdoutReader, stdoutWriter, err := pipeWithCapacityHint(ZFSSendPipeCapacityHint) + if err != nil { + cancel() + return nil, err + } + + cmd.Stdout = stdoutWriter + + if err := cmd.Start(); err != nil { + cancel() + stdoutWriter.Close() + stdoutReader.Close() + return nil, err + } + stdoutWriter.Close() + + stream := &sendStream{ + cmd: cmd, + kill: cancel, + stdoutReader: stdoutReader, + } + + return newSendStreamCopier(stream), err } @@ -454,8 +673,26 @@ func ZFSSendDry(fs string, from, to string, token string) (_ *DrySendInfo, err e return &si, nil } +type StreamCopierError interface { + error + IsReadError() bool + IsWriteError() bool +} -func ZFSRecv(ctx context.Context, fs string, stream io.Reader, additionalArgs ...string) (err error) { +type StreamCopier interface { + // WriteStreamTo writes the stream represented by this StreamCopier + // to the given io.Writer. + WriteStreamTo(w io.Writer) StreamCopierError + // Close must be called as soon as it is clear that no more data will + // be read from the StreamCopier. + // If StreamCopier gets its data from a connection, it might hold + // a lock on the connection until Close is called. Only closing ensures + // that the connection can be used afterwards. + Close() error +} + + +func ZFSRecv(ctx context.Context, fs string, streamCopier StreamCopier, additionalArgs ...string) (err error) { if err := validateZFSFilesystem(fs); err != nil { return err @@ -468,6 +705,8 @@ func ZFSRecv(ctx context.Context, fs string, stream io.Reader, additionalArgs .. } args = append(args, fs) + ctx, cancelCmd := context.WithCancel(ctx) + defer cancelCmd() cmd := exec.CommandContext(ctx, ZFS_BINARY, args...) stderr := bytes.NewBuffer(make([]byte, 0, 1024)) @@ -480,21 +719,60 @@ func ZFSRecv(ctx context.Context, fs string, stream io.Reader, additionalArgs .. stdout := bytes.NewBuffer(make([]byte, 0, 1024)) cmd.Stdout = stdout - cmd.Stdin = stream + stdin, stdinWriter, err := pipeWithCapacityHint(ZFSRecvPipeCapacityHint) + if err != nil { + return err + } + + cmd.Stdin = stdin if err = cmd.Start(); err != nil { - return + stdinWriter.Close() + stdin.Close() + return err + } + stdin.Close() + defer stdinWriter.Close() + + pid := cmd.Process.Pid + debug := func(format string, args ...interface{}) { + debug("recv: pid=%v: %s", pid, fmt.Sprintf(format, args...)) } - if err = cmd.Wait(); err != nil { - err = ZFSError{ - Stderr: stderr.Bytes(), - WaitErr: err, + debug("started") + + copierErrChan := make(chan StreamCopierError) + go func() { + copierErrChan <- streamCopier.WriteStreamTo(stdinWriter) + }() + waitErrChan := make(chan *ZFSError) + go func() { + defer close(waitErrChan) + if err = cmd.Wait(); err != nil { + waitErrChan <- &ZFSError{ + Stderr: stderr.Bytes(), + WaitErr: err, + } + return } - return + }() + + // streamCopier always fails before or simultaneously with Wait + // thus receive from it first + copierErr := <-copierErrChan + debug("copierErr: %T %s", copierErr, copierErr) + if copierErr != nil { + cancelCmd() } - return nil + waitErr := <-waitErrChan + debug("waitErr: %T %s", waitErr, waitErr) + if copierErr == nil && waitErr == nil { + return nil + } else if waitErr != nil && (copierErr == nil || copierErr.IsWriteError()) { + return waitErr // has more interesting info in that case + } + return copierErr // if it's not a write error, the copier error is more interesting } type ClearResumeTokenError struct { @@ -572,7 +850,7 @@ func zfsSet(path string, props *ZFSProperties) (err error) { } if err = cmd.Wait(); err != nil { - err = ZFSError{ + err = &ZFSError{ Stderr: stderr.Bytes(), WaitErr: err, } @@ -689,7 +967,7 @@ func ZFSDestroy(dataset string) (err error) { } if err = cmd.Wait(); err != nil { - err = ZFSError{ + err = &ZFSError{ Stderr: stderr.Bytes(), WaitErr: err, } @@ -723,7 +1001,7 @@ func ZFSSnapshot(fs *DatasetPath, name string, recursive bool) (err error) { } if err = cmd.Wait(); err != nil { - err = ZFSError{ + err = &ZFSError{ Stderr: stderr.Bytes(), WaitErr: err, } @@ -751,7 +1029,7 @@ func ZFSBookmark(fs *DatasetPath, snapshot, bookmark string) (err error) { } if err = cmd.Wait(); err != nil { - err = ZFSError{ + err = &ZFSError{ Stderr: stderr.Bytes(), WaitErr: err, } diff --git a/zfs/zfs_debug.go b/zfs/zfs_debug.go new file mode 100644 index 0000000..32846e4 --- /dev/null +++ b/zfs/zfs_debug.go @@ -0,0 +1,20 @@ +package zfs + +import ( + "fmt" + "os" +) + +var debugEnabled bool = false + +func init() { + if os.Getenv("ZREPL_ZFS_DEBUG") != "" { + debugEnabled = true + } +} + +func debug(format string, args ...interface{}) { + if debugEnabled { + fmt.Fprintf(os.Stderr, "zfs: %s\n", fmt.Sprintf(format, args...)) + } +} diff --git a/zfs/zfs_pipe.go b/zfs/zfs_pipe.go new file mode 100644 index 0000000..4816441 --- /dev/null +++ b/zfs/zfs_pipe.go @@ -0,0 +1,18 @@ +// +build !linux + +package zfs + +import ( + "os" + "sync" +) + +var zfsPipeCapacityNotSupported sync.Once + +func trySetPipeCapacity(p *os.File, capacity int) { + if debugEnabled { + zfsPipeCapacityNotSupported.Do(func() { + debug("trySetPipeCapacity error: OS does not support setting pipe capacity") + }) + } +} diff --git a/zfs/zfs_pipe_linux.go b/zfs/zfs_pipe_linux.go new file mode 100644 index 0000000..ab50020 --- /dev/null +++ b/zfs/zfs_pipe_linux.go @@ -0,0 +1,21 @@ +package zfs + +import ( + "errors" + "fmt" + "os" + + "golang.org/x/sys/unix" +) + +func trySetPipeCapacity(p *os.File, capacity int) { + res, err := unix.FcntlInt(p.Fd(), unix.F_SETPIPE_SZ, capacity) + if err != nil { + err = fmt.Errorf("cannot set pipe capacity to %v", capacity) + } else if res == -1 { + err = errors.New("cannot set pipe capacity: fcntl returned -1") + } + if debugEnabled && err != nil { + debug("trySetPipeCapacity error: %s\n", err) + } +} diff --git a/zfs/zfs_test.go b/zfs/zfs_test.go index 6ffdfdf..9126bb0 100644 --- a/zfs/zfs_test.go +++ b/zfs/zfs_test.go @@ -15,7 +15,7 @@ func TestZFSListHandlesProducesZFSErrorOnNonZeroExit(t *testing.T) { _, err = ZFSList([]string{"fictionalprop"}, "nonexistent/dataset") assert.Error(t, err) - zfsError, ok := err.(ZFSError) + zfsError, ok := err.(*ZFSError) assert.True(t, ok) assert.Equal(t, "error: this is a mock\n", string(zfsError.Stderr)) }